Class NeuralNetwork
java.lang.Object
  libai.nn.NeuralNetwork
- All Implemented Interfaces:
java.io.Serializable
- Direct Known Subclasses:
SupervisedLearning, UnsupervisedLearning
public abstract class NeuralNetwork extends java.lang.Object implements java.io.Serializable
Neural network abstraction. Provides the methods to train, simulate, and calculate the error.
- See Also:
Serialized Form
Field Summary
Modifier and Type              Field
protected Plotter              plotter
protected ProgressDisplay      progress
protected java.util.Random     random
Constructor Summary
Constructor
NeuralNetwork()
NeuralNetwork(java.util.Random rand)
Method Summary
Modifier and Type                      Method and Description
double                                 error(Column[] patterns, Column[] answers)
                                       Calculates the mean quadratic error over the whole set of patterns.
double                                 error(Column[] patterns, Column[] answers, int offset, int length)
                                       Calculates the mean quadratic error.
static double                          euclideanDistance2(double[] a, double[] b)
                                       Calculates the squared Euclidean distance between two vectors.
static double                          euclideanDistance2(Column a, Column b)
                                       Calculates the squared Euclidean distance between two column matrices.
static double                          gaussian(double u2, double sigma)
                                       Calculates the Gaussian function with standard deviation sigma and input parameter u^2.
static java.util.Random                getDefaultRandomGenerator()
Plotter                                getPlotter()
ProgressDisplay                        getProgressBar()
protected void                         initializeProgressBar(int maximum)
static <NN extends NeuralNetwork> NN   open(java.io.File file)
static <NN extends NeuralNetwork> NN   open(java.io.InputStream input)
static <NN extends NeuralNetwork> NN   open(java.lang.String path)
boolean                                save(java.lang.String path)
                                       Saves the neural network to the file in the given path.
void                                   setPlotter(Plotter plotter)
void                                   setProgressBar(ProgressDisplay pb)
                                       Sets a ProgressDisplay for the NeuralNetwork.
abstract Column                        simulate(Column pattern)
                                       Calculates the output for the pattern.
abstract void                          simulate(Column pattern, Column result)
                                       Calculates the output for the pattern and leaves the result in result.
void                                   train(Column[] patterns, Column[] answers, double alpha, int epochs)
                                       Alias of train(patterns, answers, alpha, epochs, 0, patterns.length, 1.e-5).
void                                   train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length)
                                       Alias of train(patterns, answers, alpha, epochs, offset, length, 1.e-5).
abstract void                          train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror)
                                       Trains this neural network with the list of patterns and the expected answers.
Field Detail
-
random
protected final java.util.Random random
-
plotter
protected transient Plotter plotter
-
progress
protected transient ProgressDisplay progress
Method Detail
-
getDefaultRandomGenerator
public static final java.util.Random getDefaultRandomGenerator()
-
open
public static final <NN extends NeuralNetwork> NN open(java.lang.String path) throws java.io.IOException, java.lang.ClassNotFoundException
- Throws:
java.io.IOException
java.lang.ClassNotFoundException
-
open
public static final <NN extends NeuralNetwork> NN open(java.io.File file) throws java.io.IOException, java.lang.ClassNotFoundException
- Throws:
java.io.IOException
java.lang.ClassNotFoundException
-
open
public static final <NN extends NeuralNetwork> NN open(java.io.InputStream input) throws java.io.IOException, java.lang.ClassNotFoundException
- Throws:
java.io.IOException
java.lang.ClassNotFoundException
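The three open overloads return a generic type NN that the caller chooses. The following is a minimal sketch of how such a method can be built on standard Java serialization. It assumes that open reads a single object written with ObjectOutputStream, the counterpart of save. TinyNet and OpenSketch are hypothetical names. The bound is Serializable rather than NeuralNetwork, so the example compiles without libai.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical serializable model used only for this example (not a libai class).
class TinyNet implements Serializable {
    private static final long serialVersionUID = 1L;
    final double[] weights;

    TinyNet(double[] weights) {
        this.weights = weights;
    }
}

class OpenSketch {
    // Reads one serialized object and returns it as whatever type the caller
    // assigns it to. The cast is unchecked: a wrong target type fails with a
    // ClassCastException at the caller's assignment, not inside this method.
    // Closing the ObjectInputStream also closes `input` (an ownership assumption).
    @SuppressWarnings("unchecked")
    static <NN extends Serializable> NN open(InputStream input)
            throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(input)) {
            return (NN) in.readObject();
        }
    }

    // Helper for the example: serializes an object into a byte array.
    static byte[] toBytes(Serializable obj) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        return bytes.toByteArray();
    }
}
```

Usage: writing TinyNet net = OpenSketch.open(stream); lets the compiler infer NN as TinyNet, so the call site needs no explicit cast.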
-
euclideanDistance2
public static double euclideanDistance2(double[] a, double[] b)
Calculates the squared Euclidean distance between two vectors.
NOTE: The dimensions are checked with an assert statement, so the check only takes effect when assertions are enabled at run time (for example, java -ea).
- Parameters:
a - Vector a.
b - Vector b.
- Returns:
The squared Euclidean distance.
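The following sketches what this method computes; it is not libai's source. It works on plain double[] arrays and uses the same assert-based dimension check. Skipping the square root keeps the value cheap to compute, and the result is the u^2 form that gaussian takes as its u2 argument. The class name Distances is hypothetical.

```java
class Distances {
    // Squared Euclidean distance: the sum over i of (a[i] - b[i])^2.
    // The dimension check uses assert, so it only runs with `java -ea`.
    static double euclideanDistance2(double[] a, double[] b) {
        assert a.length == b.length : "vectors must have the same dimension";
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }
}
```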
-
euclideanDistance2
public static double euclideanDistance2(Column a, Column b)
Calculates the squared Euclidean distance between two column matrices.
NOTE: The dimensions are checked with an assert statement, so the check only takes effect when assertions are enabled at run time (for example, java -ea).
- Parameters:
a - Column matrix a.
b - Column matrix b.
- Returns:
The squared Euclidean distance.
-
gaussian
public static double gaussian(double u2, double sigma)
Calculates the Gaussian function with standard deviation sigma and input parameter u^2.
- Parameters:
u2 - The squared input, u^2.
sigma - The standard deviation, sigma.
- Returns:
e^(-u^2 / (2 * sigma^2))
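The following is a minimal sketch. It assumes the standard form exp(-u^2 / (2 * sigma^2)) for a Gaussian with standard deviation sigma. The hypothetical rbf helper shows the typical pairing with a squared distance, as in a radial basis unit centred at center. GaussianSketch and rbf are illustrative names, not libai API.

```java
class GaussianSketch {
    // Gaussian with standard deviation sigma, evaluated on the squared input u2:
    // exp(-u2 / (2 * sigma^2)). Taking u2 avoids a square root in the caller.
    static double gaussian(double u2, double sigma) {
        return Math.exp(-u2 / (2.0 * sigma * sigma));
    }

    // Hypothetical radial basis activation: Gaussian of the squared distance
    // between input x and a centre vector.
    static double rbf(double[] x, double[] center, double sigma) {
        double u2 = 0;
        for (int i = 0; i < x.length; i++) {
            double d = x[i] - center[i];
            u2 += d * d;
        }
        return gaussian(u2, sigma);
    }
}
```

At the centre the activation is 1. It decays as the squared distance grows relative to sigma^2.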
-
getPlotter
public Plotter getPlotter()
-
setPlotter
public void setPlotter(Plotter plotter)
-
getProgressBar
public ProgressDisplay getProgressBar()
-
setProgressBar
public void setProgressBar(ProgressDisplay pb)
Sets a ProgressDisplay for this NeuralNetwork. The displayed value goes from -epochs to 0 and is updated every training epoch.
Note: Classes that implement train(Column[], Column[], double, int, int, int, double) are responsible for this behavior.
- Parameters:
pb - The ProgressDisplay to use.
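The -epochs to 0 contract can be sketched as a training loop that reports one value before the first epoch and one after each epoch. ProgressSink is a hypothetical stand-in for ProgressDisplay, whose actual methods are not listed on this page. Allowing a null display is also an assumption.

```java
// Hypothetical display interface; libai's real ProgressDisplay may differ.
interface ProgressSink {
    void setValue(int value);
}

class ProgressSketch {
    // Runs `epochs` (empty) epochs and reports progress as documented:
    // -epochs before training starts, then one step closer to 0 per epoch.
    static void train(int epochs, ProgressSink progress) {
        if (progress != null) {
            progress.setValue(-epochs);
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            // ... one pass over the training patterns would go here ...
            if (progress != null) {
                progress.setValue(epoch + 1 - epochs); // 0 after the last epoch
            }
        }
    }
}
```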
-
train
public abstract void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror)
Trains this neural network with the list of patterns and the expected answers.
Uses the learning rate alpha for at most epochs iterations, taking length patterns starting at position offset, and stops as soon as the error reaches minerror.
patterns and answers must be arrays of non-null column matrices.
- Parameters:
patterns - The patterns to be learned.
answers - The expected answers.
alpha - The learning rate.
epochs - The maximum number of iterations.
offset - The first pattern position.
length - How many patterns will be used.
minerror - The minimal error expected.
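The following sketches how an implementation might honour epochs, offset, length and minerror. It uses a single linear neuron trained with the delta (LMS) rule over double[] inputs instead of Column. The learning rule is a swapped-in example, not the algorithm of any libai subclass. Checking the error once at the start of each epoch is also an assumption.

```java
class LinearNeuronSketch {
    final double[] w; // one weight per input
    double b;         // bias

    LinearNeuronSketch(int inputs) {
        w = new double[inputs];
    }

    double simulate(double[] x) {
        double y = b;
        for (int i = 0; i < w.length; i++) y += w[i] * x[i];
        return y;
    }

    // Mean squared error over patterns[offset .. offset+length-1].
    double error(double[][] patterns, double[] answers, int offset, int length) {
        double sum = 0;
        for (int p = offset; p < offset + length; p++) {
            double d = answers[p] - simulate(patterns[p]);
            sum += d * d;
        }
        return sum / length;
    }

    // Delta-rule training: at most `epochs` passes over the selected range,
    // returning early once the error is at or below `minerror`.
    void train(double[][] patterns, double[] answers, double alpha,
               int epochs, int offset, int length, double minerror) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (error(patterns, answers, offset, length) <= minerror) return;
            for (int p = offset; p < offset + length; p++) {
                double delta = answers[p] - simulate(patterns[p]);
                for (int i = 0; i < w.length; i++) w[i] += alpha * delta * patterns[p][i];
                b += alpha * delta;
            }
        }
    }
}
```

Usage: training on the points of y = 2x + 1 for x = 0..3 with alpha = 0.05 and minerror = 1e-5 stops well before a few thousand epochs.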
-
train
public void train(Column[] patterns, Column[] answers, double alpha, int epochs)
Alias of train(patterns, answers, alpha, epochs, 0, patterns.length, 1.e-5).
patterns and answers must be arrays of non-null column matrices.
- Parameters:
patterns - The patterns to be learned.
answers - The expected answers.
alpha - The learning rate.
epochs - The maximum number of iterations.
- See Also:
train(Column[], Column[], double, int, int, int, double)
-
train
public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length)
Alias of train(patterns, answers, alpha, epochs, offset, length, 1.e-5).
patterns and answers must be arrays of non-null column matrices.
- Parameters:
patterns - The patterns to be learned.
answers - The expected answers.
alpha - The learning rate.
epochs - The maximum number of iterations.
offset - The first pattern position.
length - How many patterns will be used.
- See Also:
train(Column[], Column[], double, int, int, int, double)
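The two aliases only fill in defaults: the whole array (0, patterns.length) and minerror = 1.e-5. Below is a minimal sketch of that overload chain, with double[][] in place of Column[]. TrainerSketch and RecordingTrainer are hypothetical, and RecordingTrainer only records what the aliases forward.

```java
abstract class TrainerSketch {
    // Default target error used by the shorter overloads (1.e-5 per the docs).
    static final double DEFAULT_MIN_ERROR = 1.e-5;

    // Full form: subclasses implement the actual learning algorithm.
    abstract void train(double[][] patterns, double[][] answers, double alpha,
                        int epochs, int offset, int length, double minerror);

    // Alias: whole array, default minimum error.
    void train(double[][] patterns, double[][] answers, double alpha, int epochs) {
        train(patterns, answers, alpha, epochs, 0, patterns.length, DEFAULT_MIN_ERROR);
    }

    // Alias: explicit range, default minimum error.
    void train(double[][] patterns, double[][] answers, double alpha, int epochs,
               int offset, int length) {
        train(patterns, answers, alpha, epochs, offset, length, DEFAULT_MIN_ERROR);
    }
}

// Stores the arguments that reach the full form, to show what the aliases pass.
class RecordingTrainer extends TrainerSketch {
    int offset;
    int length;
    double minerror;

    @Override
    void train(double[][] patterns, double[][] answers, double alpha,
               int epochs, int offset, int length, double minerror) {
        this.offset = offset;
        this.length = length;
        this.minerror = minerror;
    }
}
```

With this layout, subclasses implement only the full form. The aliases stay in the base class, where the defaults are defined once.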
-
simulate
public abstract Column simulate(Column pattern)
Calculates the output for the pattern.
- Parameters:
pattern - Pattern to use as input.
- Returns:
The output of the neural network.
-
simulate
public abstract void simulate(Column pattern, Column result)
Calculates the output for the pattern and leaves the result in result.
- Parameters:
pattern - Pattern to use as input.
result - Receives the output for the input.
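A common way to relate the two simulate overloads is to implement the in-place form and have the allocating form delegate to it. Both methods are abstract here, so this pairing is an assumption. The sketch uses a hypothetical linear layer y = W x over double[] instead of Column.

```java
class LinearLayerSketch {
    private final double[][] weights; // rows = outputs, columns = inputs

    LinearLayerSketch(double[][] weights) {
        this.weights = weights;
    }

    // In-place form: writes the output into a caller-supplied buffer, so
    // repeated calls (for example inside an error loop) allocate nothing.
    void simulate(double[] pattern, double[] result) {
        for (int r = 0; r < weights.length; r++) {
            double sum = 0;
            for (int c = 0; c < pattern.length; c++) sum += weights[r][c] * pattern[c];
            result[r] = sum;
        }
    }

    // Allocating form: convenient for one-off calls; delegates to the in-place form.
    double[] simulate(double[] pattern) {
        double[] result = new double[weights.length];
        simulate(pattern, result);
        return result;
    }
}
```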
-
save
public boolean save(java.lang.String path)
Saves the neural network to the file in the given path.
- Parameters:
path - The path for the output file.
- Returns:
true if the file can be created and written, false otherwise.
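The following sketches the documented contract. It assumes that save uses standard Java serialization, so that open can read the file back. I/O failures are caught and reported as false rather than thrown. SaveSketch is hypothetical and takes the object explicitly because it is not a NeuralNetwork.

```java
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;

class SaveSketch {
    // Serializes obj to the file at path. Any IOException (missing directory,
    // no permission, disk full) is turned into a false return value.
    static boolean save(Serializable obj, String path) {
        try (FileOutputStream file = new FileOutputStream(path);
             ObjectOutputStream out = new ObjectOutputStream(file)) {
            out.writeObject(obj);
            return true;
        } catch (IOException e) {
            return false;
        }
    }
}
```

Returning a boolean keeps call sites simple. The trade-off is that the cause of the failure is lost; a caller that needs it would have to use an exception-throwing variant instead.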
-
error
public double error(Column[] patterns, Column[] answers)
Calculates the error over a whole set of patterns. Alias of error(patterns, answers, 0, patterns.length).
patterns and answers must be arrays of non-null column matrices.
- Parameters:
patterns - The array with the patterns to test.
answers - The array with the expected answers for the patterns.
- Returns:
The error calculated for the patterns.
- See Also:
error(Column[], Column[], int, int)
-
error
public double error(Column[] patterns, Column[] answers, int offset, int length)
Calculates the mean quadratic error, the standard error metric for neural networks; only a few networks need a different metric.
patterns and answers must be arrays of non-null column matrices.
NOTE: The dimensions are checked with an assert statement, so the check only takes effect when assertions are enabled at run time (for example, java -ea).
- Parameters:
patterns - The array with the patterns to test.
answers - The array with the expected answers for the patterns.
offset - The initial position inside the array.
length - How many patterns must be taken from the offset.
- Returns:
The mean quadratic error.
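The following sketches the mean quadratic error over patterns[offset .. offset+length-1]. The network is abstracted as a hypothetical UnaryOperator<double[]>. The sketch sums the squared Euclidean distance between each answer and its output, then divides by length. That normalization is an assumption; other choices would also divide by the number of outputs or halve the result.

```java
import java.util.function.UnaryOperator;

class ErrorSketch {
    // Mean quadratic error: average over the selected patterns of the squared
    // Euclidean distance between the expected answer and the network output.
    static double error(UnaryOperator<double[]> simulate, double[][] patterns,
                        double[][] answers, int offset, int length) {
        // Dimension check; only active when assertions are enabled (java -ea).
        assert patterns.length == answers.length : "one answer per pattern";
        double total = 0;
        for (int p = offset; p < offset + length; p++) {
            double[] output = simulate.apply(patterns[p]);
            for (int i = 0; i < output.length; i++) {
                double d = answers[p][i] - output[i];
                total += d * d;
            }
        }
        return total / length;
    }
}
```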
-
initializeProgressBar
protected void initializeProgressBar(int maximum)