// NeuralNetwork.java
/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.nn;

import libai.common.Plotter;
import libai.common.Precondition;
import libai.common.ProgressDisplay;
import libai.common.matrix.Column;

import java.io.*;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
/**
* Neural network abstraction. Provides the methods to train, simulate and
* calculate the error.
*
* @author kronenthaler
*/
public abstract class NeuralNetwork implements Serializable {
	private static final long serialVersionUID = 2851521924022998819L;

	/** Source of randomness used by subclasses for initialization and training. */
	protected final Random random;

	/** Optional plotter for visualizing training; not part of the serialized state. */
	protected transient Plotter plotter;

	/** Optional progress display updated while training; not part of the serialized state. */
	protected transient ProgressDisplay progress;

	/**
	 * Creates a neural network using the default random generator.
	 *
	 * @see #getDefaultRandomGenerator()
	 */
	public NeuralNetwork() {
		this(getDefaultRandomGenerator());
	}

	/**
	 * Creates a neural network using the given random generator.
	 *
	 * @param rand The random generator to use; must not be {@code null}.
	 * @throws NullPointerException if {@code rand} is {@code null}.
	 */
	public NeuralNetwork(Random rand) {
		random = Objects.requireNonNull(rand, "rand must not be null");
	}

	/**
	 * Returns the default random generator used by the no-arg constructor.
	 *
	 * @return The current thread's {@link ThreadLocalRandom}.
	 */
	public static final Random getDefaultRandomGenerator() {
		return ThreadLocalRandom.current();
	}

	/**
	 * Deserializes a previously saved neural network from the file at {@code path}.
	 *
	 * @param <NN> The expected concrete network type.
	 * @param path Path to the serialized network file.
	 * @return The deserialized network.
	 * @throws IOException if the file cannot be read.
	 * @throws ClassNotFoundException if the serialized class is not on the classpath.
	 * @see #open(InputStream)
	 */
	public static final <NN extends NeuralNetwork> NN open(String path) throws IOException, ClassNotFoundException {
		// No cast needed: NN is inferred from the call site through open(File).
		return open(new File(path));
	}

	/**
	 * Deserializes a previously saved neural network from {@code file}.
	 *
	 * @param <NN> The expected concrete network type.
	 * @param file The serialized network file.
	 * @return The deserialized network.
	 * @throws IOException if the file cannot be read.
	 * @throws ClassNotFoundException if the serialized class is not on the classpath.
	 * @see #open(InputStream)
	 */
	public static final <NN extends NeuralNetwork> NN open(File file) throws IOException, ClassNotFoundException {
		try (FileInputStream fis = new FileInputStream(file)) {
			return open(fis);
		}
	}

	/**
	 * Deserializes a neural network from {@code input}. The stream is closed
	 * when this method returns.
	 * <p>
	 * <i>Security note:</i> Java native deserialization of untrusted data is
	 * unsafe; only open networks saved by a trusted source.</p>
	 *
	 * @param <NN> The expected concrete network type.
	 * @param input Stream containing a serialized network.
	 * @return The deserialized network.
	 * @throws IOException if the stream cannot be read.
	 * @throws ClassNotFoundException if the serialized class is not on the classpath.
	 */
	@SuppressWarnings("unchecked") // the caller chooses NN; a mismatch surfaces as ClassCastException at the call site
	public static final <NN extends NeuralNetwork> NN open(InputStream input) throws IOException, ClassNotFoundException {
		try (ObjectInputStream in = new ObjectInputStream(input)) {
			return (NN) in.readObject();
		}
	}

	/**
	 * Calculates the square Euclidean distance between two vectors.
	 * <p>
	 * The dimensions are validated with {@link Precondition}; both arrays must
	 * have the same length.</p>
	 *
	 * @param a Vector a.
	 * @param b Vector b.
	 * @return The square Euclidean distance.
	 */
	public static double euclideanDistance2(double[] a, double[] b) {
		Precondition.check(a.length == b.length, "a & b must have the same length");
		double sum = 0;
		for (int i = 0; i < a.length; i++) {
			double diff = (a[i] - b[i]);
			sum += diff * diff;
		}
		return sum;
	}

	/**
	 * Calculates the square Euclidean distance between two column matrices.
	 * <p>
	 * The dimensions are validated with {@link Precondition}; both columns must
	 * have the same number of rows.</p>
	 *
	 * @param a Column matrix a.
	 * @param b Column matrix b.
	 * @return The square Euclidean distance.
	 */
	public static double euclideanDistance2(Column a, Column b) {
		Precondition.check(a.getRows() == b.getRows(), "a & b must have the same length");
		double sum = 0;
		for (int i = 0; i < a.getRows(); i++) {
			double diff = (a.position(i, 0) - b.position(i, 0));
			sum += diff * diff;
		}
		return sum;
	}

	/**
	 * Calculates the Gaussian function with standard deviation {@code sigma}
	 * and input parameter {@code u^2}.
	 *
	 * @param u2 {@code u2}
	 * @param sigma {@code sigma}
	 * @return {@code e^(-u^2/2.sigma)}
	 */
	public static double gaussian(double u2, double sigma) {
		return Math.exp((-u2) / (sigma * 2.0));
	}

	/**
	 * @return The plotter attached to this network, or {@code null} if none.
	 */
	public Plotter getPlotter() {
		return plotter;
	}

	/**
	 * @param plotter Plotter used to visualize training; may be {@code null} to detach.
	 */
	public void setPlotter(Plotter plotter) {
		this.plotter = plotter;
	}

	/**
	 * @return The progress display attached to this network, or {@code null} if none.
	 */
	public ProgressDisplay getProgressBar() {
		return progress;
	}

	/**
	 * Sets a {@link ProgressDisplay} to the {@code NeuralNetwork}. The value
	 * will go from {@code -epochs} to {@code 0}, and updated every training
	 * epoch.
	 * <p>
	 * <i>Note: </i> Classes that implement {@link
	 * NeuralNetwork#train(Column[], Column[], double, int, int, int, double)}
	 * are responsible for this behavior.</p>
	 *
	 * @param pb ProgressDisplay
	 */
	public void setProgressBar(ProgressDisplay pb) {
		progress = pb;
	}

	/**
	 * Trains this neural network with the list of {@code patterns} and the
	 * expected {@code answers}.
	 * <p>
	 * Use the learning rate {@code alpha} for many {@code epochs}. Take
	 * {@code length} patterns from the position {@code offset} until the
	 * {@code minerror} is reached.</p>
	 * <p>
	 * {@code patterns} and {@code answers} must be arrays of non-{@code null}
	 * <b>column</b> matrices</p>
	 *
	 * @param patterns The patterns to be learned.
	 * @param answers The expected answers.
	 * @param alpha The learning rate.
	 * @param epochs The maximum number of iterations
	 * @param offset The first pattern position
	 * @param length How many patterns will be used.
	 * @param minerror The minimal error expected.
	 */
	public abstract void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror);

	/**
	 * Alias of {@code train(patterns, answers, alpha, epochs, 0, patterns.length, 1.e-5)}.
	 * <p>
	 * {@code patterns} and {@code answers} must be arrays of non-{@code null}
	 * <b>column</b> matrices</p>
	 *
	 * @param patterns The patterns to be learned.
	 * @param answers The expected answers.
	 * @param alpha The learning rate.
	 * @param epochs The maximum number of iterations
	 * @see NeuralNetwork#train(Column[], Column[], double, int, int, int, double)
	 */
	public void train(Column[] patterns, Column[] answers, double alpha, int epochs) {
		train(patterns, answers, alpha, epochs, 0, patterns.length, 1.e-5);
	}

	/**
	 * Alias of {@code train(patterns, answers, alpha, epochs, offset, length, 1.e-5)}.
	 * <p>
	 * {@code patterns} and {@code answers} must be arrays of non-{@code null}
	 * <b>column</b> matrices</p>
	 *
	 * @param patterns The patterns to be learned.
	 * @param answers The expected answers.
	 * @param alpha The learning rate.
	 * @param epochs The maximum number of iterations
	 * @param offset The first pattern position
	 * @param length How many patterns will be used.
	 * @see NeuralNetwork#train(Column[], Column[], double, int, int, int, double)
	 */
	public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length) {
		train(patterns, answers, alpha, epochs, offset, length, 1.e-5);
	}

	/**
	 * Calculates the output for the {@code pattern}.
	 *
	 * @param pattern Pattern to use as input.
	 * @return The output for the neural network.
	 */
	public abstract Column simulate(Column pattern);

	/**
	 * Calculates the output for the {@code pattern} and leaves the result in
	 * {@code result}.
	 *
	 * @param pattern Pattern to use as input.
	 * @param result The output for the input.
	 */
	public abstract void simulate(Column pattern, Column result);

	/**
	 * Saves the neural network to the file in the given {@code path}.
	 *
	 * @param path The path for the output file.
	 * @return {@code true} if the file can be created and written,
	 * {@code false} otherwise.
	 */
	public boolean save(String path) {
		try (FileOutputStream fos = new FileOutputStream(path);
			 ObjectOutputStream oos = new ObjectOutputStream(fos)) {
			oos.writeObject(this);
			return true;
		} catch (IOException e) { // serialization failures surface as IOException (e.g. NotSerializableException)
			return false;
		}
	}

	/**
	 * Calculates the error for a set of patterns. Alias of
	 * {@code error(patterns, answers, 0, patterns.length)}.
	 * <p>
	 * {@code patterns} and {@code answers} must be arrays of non-{@code null}
	 * <b>column</b> matrices</p>
	 *
	 * @param patterns The array with the patterns to test
	 * @param answers The array with the expected answers for the patterns.
	 * @return The error calculated for the patterns.
	 * @see NeuralNetwork#error(Column[], Column[], int, int)
	 */
	public double error(Column[] patterns, Column[] answers) {
		return error(patterns, answers, 0, patterns.length);
	}

	/**
	 * Calculates the mean quadratic error. It is the standard error metric for
	 * neural networks. Just a few networks need a different type of error
	 * metric.
	 * <p>
	 * {@code patterns} and {@code answers} must be arrays of non-{@code null}
	 * <b>column</b> matrices</p>
	 * <p>
	 * The array bounds are validated with {@link Precondition}.</p>
	 *
	 * @param patterns The array with the patterns to test
	 * @param answers The array with the expected answers for the patterns.
	 * @param offset The initial position inside the array.
	 * @param length How many patterns must be taken from the offset.
	 * @return The mean quadratic error.
	 */
	public double error(Column[] patterns, Column[] answers, int offset, int length) {
		Precondition.check(patterns.length == answers.length, "There must be the same amount of patterns and answers");
		Precondition.check(offset >= 0 && offset < patterns.length, "offset must be in the interval [0, %d), found %d", patterns.length, offset);
		// Message matches the actual check: length may be 0 up to the remaining patterns.
		Precondition.check(length >= 0 && length <= patterns.length - offset, "length must be in the interval [0, %d], found %d", patterns.length - offset, length);
		double error = 0.0;
		Column Y = new Column(answers[0].getRows());
		for (int i = 0; i < length; i++) {
			simulate(patterns[i + offset], Y); //inner product
			error += euclideanDistance2(answers[i + offset], Y);
		}
		return error / (double) length;
	}

	/**
	 * Initializes the attached progress display (if any) over {@code [0, maximum]},
	 * starting at 0. A no-op when no progress display has been set.
	 *
	 * @param maximum Upper bound of the progress range, typically the epoch count.
	 */
	protected void initializeProgressBar(int maximum) {
		if (progress != null) {
			progress.setMaximum(maximum);
			progress.setMinimum(0);
			progress.setValue(0);
		}
	}
}