ResilientBackpropagation.java
package libai.nn.supervised.backpropagation;
import libai.common.Shuffler;
import libai.common.matrix.Column;
import libai.common.matrix.Matrix;
import libai.nn.NeuralNetwork;
/**
* Created by kronenthaler on 18/01/2017.
*/
public class ResilientBackpropagation extends StandardBackpropagation {
protected double nPlus,
nMinus,
minUpdate,
maxUpdate,
initialUpdate;
public ResilientBackpropagation() {
this(1.2, 0.5, 1e-6, 50, 0.1);
}
public ResilientBackpropagation(double nPlus, double nMinus, double minUpdate, double maxUpdate, double initialUpdate) {
this.nPlus = nPlus;
this.nMinus = nMinus;
this.minUpdate = minUpdate;
this.maxUpdate = maxUpdate;
this.initialUpdate = initialUpdate;
}
@Override
public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror) {
Shuffler shuffler = new Shuffler(length, NeuralNetwork.getDefaultRandomGenerator());
double error = nn.error(patterns, answers, offset, length);
Matrix e = new Matrix(answers[0].getRows(), answers[0].getColumns());
Matrix dacum[] = new Matrix[layers];
Matrix dacumPrev[] = new Matrix[layers];
Matrix updates[] = new Matrix[layers];
Matrix dacumb[] = new Matrix[layers];
Matrix dacumbPrev[] = new Matrix[layers];
Matrix updatesb[] = new Matrix[layers];
for (int i = 1; i < layers; i++) {
dacum[i] = new Matrix(u[i].getRows(), Y[i - 1].getRows());
dacumPrev[i] = new Matrix(u[i].getRows(), Y[i - 1].getRows());
updates[i] = new Matrix(u[i].getRows(), Y[i - 1].getRows());
updates[i].setValue(initialUpdate);
dacumb[i] = new Column(nperlayer[i]);
dacumbPrev[i] = new Column(nperlayer[i]);
updatesb[i] = new Column(nperlayer[i]);
updatesb[i].setValue(0.1);
}
for (int currentEpoch = 0; currentEpoch < epochs && error > minerror; currentEpoch++) {
//shuffle patterns
int[] sort = shuffler.shuffle();
error = 0;
for (int i = 0; i < length; i++) {
//Y[i]=Fi(<W[i],Y[i-1]>+b)
nn.simulate(patterns[sort[i] + offset]);
//e=-2(t-Y[n-1])
answers[sort[i] + offset].subtract(Y[layers - 1], e);
//calculate the error
for (int m = 0; m < nperlayer[layers - 1]; m++) {
error += (e.position(m, 0) * e.position(m, 0));
}
//d[0] = F0'(<W[i],Y[i-1]>).e
for (int j = 0; j < u[layers - 1].getRows(); j++) {
d[layers - 1].position(j, 0, -2 * func[layers - 1].getDerivate().eval(u[layers - 1].position(j, 0)) * e.position(j, 0));
}
//d[i]=Fi'(<W[i],Y[i-1]>).W[i+1]^t.d[i+1]
for (int k = layers - 2; k > 0; k--) {
for (int j = 0; j < u[k].getRows(); j++) {
double acum = 0;
for (int t = 0; t < W[k + 1].getRows(); t++) {
acum += W[k + 1].position(t, j) * d[k + 1].position(t, 0);
}
d[k].position(j, 0, acum * func[k].getDerivate().eval(u[k].position(j, 0)));
}
}
for (int l = 1; l < layers; l++) {
Y[l - 1].transpose(Yt[l - 1]);
d[l].multiply(Yt[l - 1], M[l]);
dacum[l].add(M[l], dacum[l]);
dacumb[l].add(d[l], dacumb[l]);
}
}
//update weights and thresholds
for (int l = 1; l < layers; l++) {
for (int i = 0; i < W[l].getRows(); i++) {
for (int j = 0; j < W[l].getColumns(); j++) {
double change = dacum[l].position(i, j) * dacumPrev[l].position(i, j);
double sign = dacum[l].position(i, j) > 0 ? 1 : -1;
if (change > 0) {
updates[l].position(i, j, Math.min(updates[l].position(i, j) * nPlus, maxUpdate));
W[l].increment(i, j, (-sign * updates[l].position(i, j)));
dacumPrev[l].position(i, j, dacum[l].position(i, j));
} else if (change < 0) {
updates[l].position(i, j, Math.max(updates[l].position(i, j) * nMinus, minUpdate));
dacumPrev[l].position(i, j, 0);
} else {
W[l].increment(i, j, (-sign * updates[l].position(i, j)));
dacumPrev[l].position(i, j, dacum[l].position(i, j));
}
dacum[l].position(i, j, 0);
}
for (int j = 0; j < b[l].getColumns(); j++) {
double change = dacumb[l].position(i, j) * dacumbPrev[l].position(i, j);
double sign = dacumb[l].position(i, j) > 0 ? 1 : -1;
if (change > 0) {
updatesb[l].position(i, j, Math.min(updatesb[l].position(i, j) * nPlus, maxUpdate));
b[l].increment(i, j, (-sign * updatesb[l].position(i, j)));
dacumbPrev[l].position(i, j, dacumb[l].position(i, j));
} else if (change < 0) {
updatesb[l].position(i, j, Math.max(updatesb[l].position(i, j) * nMinus, minUpdate));
dacumbPrev[l].position(i, j, 0);
} else {
b[l].increment(i, j, (-sign * updatesb[l].position(i, j)));
dacumbPrev[l].position(i, j, dacumb[l].position(i, j));
}
dacumb[l].position(i, j, 0);
}
}
}
error /= length;
if (nn.getPlotter() != null) {
nn.getPlotter().setError(currentEpoch, error);
}
if (nn.getProgressBar() != null) {
nn.getProgressBar().setValue(currentEpoch);
}
}
}
}