Perceptron.java
/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.nn.supervised;
import libai.common.Shuffler;
import libai.common.matrix.Column;
import libai.common.matrix.Matrix;
import libai.common.functions.Sign;
import java.util.Random;
/**
 * Single-layer perceptron, one of the earliest trainable neural network
 * models. Its trainable state is a weight matrix {@code W} ({@code outs} rows
 * by {@code ins} columns) and a bias column {@code b} ({@code outs} entries).
 * The response to an input pattern is {@code O = sign(W * pattern + b)}.
 *
 * @author kronenthaler
 */
public class Perceptron extends SupervisedLearning {
    private static final long serialVersionUID = 2795822735956649552L;

    /** Function applied element-wise to the net input to threshold it. */
    protected static final Sign signum = new Sign();

    /** Weight matrix: one row per output, one column per input. */
    protected Matrix W;

    /** Bias column: one entry per output. */
    protected Column b;

    /** Number of inputs (elements per pattern). */
    protected final int ins;

    /** Number of outputs. */
    protected final int outs;

    /**
     * Builds a perceptron whose parameters are initialised from the generator
     * returned by {@code getDefaultRandomGenerator()}.
     *
     * @param in  number of inputs, i.e. the number of elements in each pattern
     * @param out number of outputs
     */
    public Perceptron(int in, int out) {
        this(in, out, getDefaultRandomGenerator());
    }

    /**
     * Builds a perceptron whose weights and biases are filled using the given
     * random generator.
     *
     * @param in   number of inputs, i.e. the number of elements in each pattern
     * @param out  number of outputs
     * @param rand random generator used when filling the weight matrix and bias
     */
    public Perceptron(int in, int out, Random rand) {
        super(rand);
        this.ins = in;
        this.outs = out;
        this.W = new Matrix(this.outs, this.ins);
        this.b = new Column(this.outs);
        // Both fills draw from the same generator: W is filled before b so a
        // given seed always yields the same initial parameters.
        this.W.fill(true, this.random);
        this.b.fill(true, this.random);
    }

    /**
     * Trains the perceptron with the classic error-correction rule, applied
     * after every single pattern:
     * <pre>
     *   W = W + alpha * e * pattern^T
     *   b = b + alpha * e          (e = answer - output)
     * </pre>
     * Each epoch visits {@code patterns[offset .. offset + length - 1]} in a
     * freshly shuffled order. Training stops after {@code epochs} epochs, or
     * earlier once the value returned by {@code error(...)} is no longer greater
     * than {@code minerror}. When present, the plotter receives the error and
     * the progress bar the epoch index after every epoch.
     *
     * @param patterns the input patterns to learn
     * @param answers  the expected output for each pattern
     * @param alpha    the learning rate
     * @param epochs   the maximum number of passes over the data
     * @param offset   index of the first pattern to use
     * @param length   number of patterns to use, starting at {@code offset}
     * @param minerror error threshold at which training stops early
     */
    @Override
    public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror) {
        validatePreconditions(patterns, answers, epochs, offset, length, minerror);

        // Transpose every pattern once; the transposes are reused each epoch
        // to build the outer product delta * pattern^T.
        final Matrix[] transposed = new Matrix[length];
        for (int k = 0; k < length; k++) {
            transposed[k] = patterns[offset + k].transpose();
        }

        // Scratch buffers shared by every update step to avoid reallocation.
        final Column output = new Column(outs);
        final Column delta = new Column(outs);
        final Matrix deltaW = new Matrix(outs, ins);

        double currentError = 1;
        final Shuffler shuffler = new Shuffler(length, this.random);
        initializeProgressBar(epochs);

        int epoch = 0;
        while (epoch < epochs && currentError > minerror) {
            final int[] order = shuffler.shuffle();
            for (int k = 0; k < length; k++) {
                final int idx = order[k];
                final Column pattern = patterns[offset + idx];
                final Column target = answers[offset + idx];

                simulate(pattern, output);               // network response y
                target.subtract(output, delta);          // e = t - y
                delta.multiply(alpha, delta);            // alpha * e
                delta.multiply(transposed[idx], deltaW); // alpha * e * p^T
                W.add(deltaW, W);                        // W += alpha * e * p^T
                b.add(delta, b);                         // b += alpha * e
            }

            currentError = error(patterns, answers, offset, length);
            if (plotter != null) {
                plotter.setError(epoch, currentError);
            }
            if (progress != null) {
                progress.setValue(epoch);
            }
            epoch++;
        }

        if (progress != null) {
            progress.setValue(progress.getMaximum());
        }
    }

    /**
     * Computes the network output for {@code p} into a newly allocated column.
     *
     * @param p the input pattern
     * @return a new column with {@code outs} entries holding the output
     */
    @Override
    public Column simulate(Column p) {
        final Column response = new Column(outs);
        simulate(p, response);
        return response;
    }

    /**
     * Computes {@code signum(W * pattern + b)} and stores it in {@code result}.
     *
     * @param pattern the input pattern
     * @param result  column that receives the output; overwritten in place
     */
    @Override
    public void simulate(Column pattern, Column result) {
        this.W.multiply(pattern, result);  // result = W * pattern
        result.add(this.b, result);        // result += b
        result.apply(signum, result);      // element-wise threshold
    }

    /**
     * Returns the current weights as a new {@code Matrix} built from the
     * internal weight matrix, so the caller gets a separate object rather than
     * the field itself.
     *
     * @return a new matrix constructed from {@code W}
     */
    public Matrix getWeights() {
        final Matrix snapshot = new Matrix(this.W);
        return snapshot;
    }
}