LVQ.java
/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.nn.supervised;
import libai.common.Precondition;
import libai.common.Shuffler;
import libai.common.matrix.Column;
import libai.common.matrix.Matrix;
import libai.common.matrix.Row;
import java.util.Random;
/**
 * Learning Vector Quantization (LVQ). It is a hybrid neural network with 3
 * layers (1-input, 1-hidden, 1-output). The first set of weights ({@code W},
 * one prototype per hidden neuron) is trained with a competitive
 * "winner takes all" approach. The supervised answer decides whether the
 * winning prototype is moved towards or away from the pattern. The second set
 * of weights ({@code W2}) is a fixed mapping, built in the constructor, that
 * assigns {@code subclass} hidden neurons to each output class. This network
 * was proposed by Teuvo Kohonen as an alternative to the standard competitive
 * networks.
 *
 * @author kronenthaler
 */
public class LVQ extends SupervisedLearning {
    private static final long serialVersionUID = 6603129562167746698L;
    /** Prototype vectors: {@code subclasses * outs} rows by {@code ins} columns. */
    protected Matrix W;
    /** Hidden-to-output mapping: {@code outs} rows by {@code subclasses * outs} columns, one 1 per column. */
    protected Matrix W2;
    protected int ins, outs;
    protected int subclasses;

    /**
     * Constructor. Number of inputs, number of subclasses and number of
     * outputs. Uses the default random generator.
     *
     * @param in Number of inputs to the network.
     * @param subclass Number of subclasses per output class. Greater
     * subdivision provides better classification.
     * @param out Number of outputs for the network.
     */
    public LVQ(int in, int subclass, int out) {
        this(in, subclass, out, getDefaultRandomGenerator());
    }

    /**
     * Constructor. Number of inputs, number of subclasses and number of
     * outputs.
     *
     * @param in Number of inputs to the network.
     * @param subclass Number of subclasses per output class. Greater
     * subdivision provides better classification.
     * @param out Number of outputs for the network.
     * @param rand Random generator used for initializing the prototype matrix.
     */
    public LVQ(int in, int subclass, int out, Random rand) {
        super(rand);
        ins = in;
        subclasses = subclass;
        outs = out;
        W = new Matrix(subclasses * outs, ins);
        W2 = new Matrix(outs, subclasses * outs);
        W.fill(true, random);
        W2.setValue(0);
        // Hidden neurons [k*subclasses, (k+1)*subclasses) belong to output class k,
        // so column i of W2 has its single 1 in row i / subclasses.
        for (int i = 0; i < W2.getColumns(); i++) {
            W2.position(i / subclasses, i, 1);
        }
    }

    /**
     * Train the network using a hybrid scheme. Uses the "winner takes all"
     * rule: for each pattern only the closest prototype is updated, moving
     * towards the pattern when its class matches the expected answer and
     * away from it otherwise.
     *
     * @param patterns The patterns to be learned.
     * @param answers The expected answers.
     * @param alpha The learning rate.
     * @param epochs The maximum number of iterations.
     * @param offset The first pattern position.
     * @param length How many patterns will be used.
     * @param minerror The minimal error expected; training stops once the
     * error (fraction of misclassified patterns) is not above it.
     */
    @Override
    public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror) {
        validatePreconditions(patterns, answers, epochs, offset, length, minerror);

        // Transpose once up-front: the prototype update below works on row vectors.
        Matrix[] patternsT = new Matrix[length];
        for (int i = 0; i < length; i++) {
            patternsT[i] = patterns[i + offset].transpose();
        }

        double error = 1;
        Shuffler shuffler = new Shuffler(length, this.random);
        initializeProgressBar(epochs);
        // Scratch buffers reused for every update to avoid per-pattern allocation.
        Row r = new Row(ins);
        Row row = new Row(W.getColumns());

        for (int currentEpoch = 0; currentEpoch < epochs && error > minerror; currentEpoch++) {
            int[] sort = shuffler.shuffle();
            for (int i = 0; i < length; i++) {
                int index = sort[i];
                Column pattern = patterns[index + offset];
                Column answer = answers[index + offset];

                // Winning hidden neuron: the row of W closest to the pattern.
                int winner = simulateNoChange(pattern);

                // winnerOut: class the winner maps to (row of W2 with a 1 in column winner).
                // winnerT: class expected by the answer (row of the answer holding a 1).
                int winnerOut = -1;
                int winnerT = -1;
                for (int j = 0; j < W2.getRows(); j++) {
                    if (W2.position(j, winner) == 1) {
                        winnerOut = j;
                    }
                    if (answer.position(j, 0) == 1) {
                        winnerT = j;
                    }
                }

                // w = w +/- alpha * (p - w): '+' when the winner's class is the
                // expected one, '-' otherwise.
                patternsT[index].copy(r);
                row.setRow(0, W.getRow(winner));
                r.subtract(row, r);
                r.multiply((winnerT == winnerOut) ? alpha : -alpha, r);
                row.add(r, r);
                W.setRow(winner, r.getRow(0));
            }

            error = error(patterns, answers, offset, length);
            if (plotter != null) {
                plotter.setError(currentEpoch, error);
            }
            if (progress != null) {
                progress.setValue(currentEpoch);
            }
        }

        if (progress != null) {
            progress.setValue(progress.getMaximum());
        }
    }

    /**
     * Calculates the network output for {@code pattern}.
     *
     * @param pattern Pattern to use as input.
     * @return A new column of {@code outs} entries equal to column
     * {@code winner} of {@code W2}. With the mapping built by the constructor,
     * this is 1 at the winner's class and 0 elsewhere.
     */
    @Override
    public Column simulate(Column pattern) {
        Column ret = new Column(outs);
        Column layer1 = new Column(W.getRows());
        simulate(pattern, layer1);
        W2.multiply(layer1, ret);
        return ret;
    }

    /**
     * Calculates the hidden-layer output for the {@code pattern} and leaves
     * the result in {@code result}. The result is a column filled with 0
     * except for the winner position, which is set to 1.
     *
     * @param pattern Pattern to use as input.
     * @param result Output buffer; expected to have one entry per row of
     * {@code W}.
     */
    @Override
    public void simulate(Column pattern, Column result) {
        int winner = simulateNoChange(pattern);
        result.setValue(0);
        result.position(winner, 0, 1);
    }

    /**
     * Finds the winning hidden neuron for {@code pattern} without modifying
     * the network.
     *
     * @param pattern Input column (only column 0 is read).
     * @return Index of the row of {@code W} with the smallest value of
     * {@code euclideanDistance2} to the pattern, or -1 if no row yields a
     * distance below {@link Double#MAX_VALUE} (e.g. {@code W} has no rows).
     */
    protected int simulateNoChange(Matrix pattern) {
        double[] input = pattern.getCol(0); // loop-invariant, extract once
        double best = Double.MAX_VALUE;
        int winner = -1;
        for (int j = 0; j < W.getRows(); j++) {
            double dist = euclideanDistance2(input, W.getRow(j));
            if (dist < best) {
                best = dist;
                winner = j;
            }
        }
        return winner;
    }

    /**
     * Calculates the number of incorrect answers over the total.
     *
     * @param patterns The array with the patterns to test.
     * @param answers The array with the expected answers for the patterns.
     * @param offset The initial position inside the array.
     * @param length How many patterns must be taken from the offset.
     * @return The relation between the incorrect answers and the total number
     * of answers, or 0 when {@code length} is 0 (nothing was evaluated).
     */
    @Override
    public double error(Column[] patterns, Column[] answers, int offset, int length) {
        Precondition.check(patterns.length == answers.length, "There must be the same amount of patterns and answers");
        Precondition.check(offset >= 0 && offset < patterns.length, "offset must be in the interval [0, %d), found, %d", patterns.length, offset);
        Precondition.check(length >= 0 && length <= patterns.length - offset, "length must be in the interval [0, %d], found, %d", patterns.length - offset, length);
        if (length == 0) {
            // An empty range has no misclassifications; avoid 0 / 0.0 = NaN.
            return 0;
        }

        int correct = 0;
        // Buffers reused across patterns: hidden one-hot layer and class output.
        Column ret1 = new Column(W2.getColumns());
        Column ret = new Column(outs);
        for (int i = 0; i < length; i++) {
            simulate(patterns[i + offset], ret1);
            W2.multiply(ret1, ret);
            if (ret.equals(answers[i + offset])) {
                correct++;
            }
        }
        return (length - correct) / (double) length;
    }

    /**
     * Returns the network weights.
     *
     * @return A two-element array holding new {@code Matrix} instances built
     * from {@code W} (prototypes) and {@code W2} (class mapping), in that
     * order, so the internal fields themselves are not handed out.
     */
    public Matrix[] getWeights() {
        return new Matrix[]{
            new Matrix(W),
            new Matrix(W2)
        };
    }
}