StandardBackpropagation.java

package libai.nn.supervised.backpropagation;

import libai.common.Shuffler;
import libai.common.matrix.Column;
import libai.common.matrix.Matrix;
import libai.common.functions.Function;
import libai.common.matrix.Row;
import libai.nn.NeuralNetwork;

/**
 * Standard (vanilla) on-line backpropagation trainer. After every simulated
 * pattern, the output error is propagated backwards through the layers and
 * the weights and biases are updated by gradient descent:
 *
 * <pre>
 *   W[l] = W[l] - d[l].Y[l-1]^T
 *   b[l] = b[l] - d[l]
 * </pre>
 *
 * where the deltas d[l] already carry the learning rate alpha.
 *
 * Created by kronenthaler on 08/01/2017.
 */
public class StandardBackpropagation implements Backpropagation {

    protected NeuralNetwork nn;    // network being trained
    protected int[] nperlayer;     // number of neurons on each layer
    protected int layers;          // total number of layers, nperlayer.length
    protected Function[] func;     // activation function of each layer
    protected Matrix[] W;          // weights, W[l] connects layer l-1 to layer l
    protected Column[] Y, b, u;    // per-layer outputs, biases and net inputs
    // auxiliary buffers reused on every training step
    protected Column[] d;          // back-propagated deltas (index 0 unused)
    protected Row[] Yt;            // transposed layer outputs
    protected Matrix[] M;          // weight-update terms, d[l].Y[l-1]^T

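    /**
     * Binds this trainer to the network and to the network's internal state
     * (weights, biases, per-layer outputs and net inputs), then allocates the
     * auxiliary buffers used during training.
     *
     * @param nn        the network to train
     * @param nperlayer number of neurons on each layer
     * @param functions activation function of each layer
     * @param W         weight matrices, W[l] connects layer l-1 to layer l
     * @param Y         output of each layer
     * @param b         bias of each layer
     * @param u         net input of each layer, u[l] = W[l].Y[l-1] + b[l]
     */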
    @Override
    public void initialize(NeuralNetwork nn, int[] nperlayer, Function[] functions, Matrix[] W, Column[] Y, Column[] b, Column[] u) {
        this.nn = nn;
        this.nperlayer = nperlayer;
        this.layers = nperlayer.length;
        this.func = functions;
        this.W = W;
        this.b = b;
        this.Y = Y;
        this.u = u;

        initialize();
    }

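    /**
     * Allocates the per-layer deltas, transposed-output rows and weight-update
     * buffers once, so the training loop creates no temporaries.
     */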
    private void initialize() {
        d = new Column[layers]; // index 0 (input layer) has no delta
        Yt = new Row[layers];
        M = new Matrix[layers];

        Yt[0] = new Row(nperlayer[0]);
        for (int i = 1; i < layers; i++) {
            Yt[i] = new Row(u[i].getRows());
            M[i] = new Matrix(u[i].getRows(), Y[i - 1].getRows());
        }
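        // e.g. for nperlayer = {2, 3, 1}: Yt = {1x2, 1x3, 1x1},
        // M[1] is 3x2 and M[2] is 1x3, matching W[1] and W[2].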

        for (int k = 1; k < layers; k++) {
            d[k] = new Column(u[k].getRows());
        }
    }

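    /**
     * Trains the network with standard on-line backpropagation. Patterns are
     * reshuffled on every epoch, and the weights and biases are updated after
     * each pattern. Training stops after {@code epochs} epochs or once the
     * mean squared error falls to {@code minerror}.
     *
     * @param patterns the input patterns
     * @param answers  the expected output for each pattern
     * @param alpha    the learning rate
     * @param epochs   the maximum number of epochs
     * @param offset   the index of the first pattern to use
     * @param length   how many patterns to use, starting at {@code offset}
     * @param minerror the acceptable mean squared error to stop early
     */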
    @Override
    public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror) {
        Shuffler shuffler = new Shuffler(length, NeuralNetwork.getDefaultRandomGenerator());

        double error = nn.error(patterns, answers, offset, length);
        Column e = new Column(answers[0].getRows()); // per-pattern output error, t - Y[layers-1]

        for (int currentEpoch = 0; currentEpoch < epochs && error > minerror; currentEpoch++) {
            //shuffle patterns
            int[] sort = shuffler.shuffle();

            error = 0;
            for (int i = 0; i < length; i++) {
                // forward pass: Y[l] = F_l(W[l].Y[l-1] + b[l]) for every layer
                nn.simulate(patterns[sort[i] + offset]);

                // e = t - Y[layers-1]; the -2 factor of the gradient is folded into d below
                answers[sort[i] + offset].subtract(Y[layers - 1], e);

                // accumulate the squared error of this pattern
                for (int m = 0; m < nperlayer[layers - 1]; m++) {
                    error += (e.position(m, 0) * e.position(m, 0));
                }

                // output delta: d[layers-1] = -2.alpha.F'(u[layers-1]) (*) e (element-wise)
                for (int j = 0; j < u[layers - 1].getRows(); j++) {
                    d[layers - 1].position(j, 0, -2 * alpha * func[layers - 1].getDerivate().eval(u[layers - 1].position(j, 0)) * e.position(j, 0));
                }

                // hidden deltas: d[k] = alpha.F_k'(u[k]) (*) (W[k+1]^T.d[k+1])
                for (int k = layers - 2; k > 0; k--) {
                    for (int j = 0; j < u[k].getRows(); j++) {
                        double acum = 0; // j-th component of W[k+1]^T.d[k+1]
                        for (int t = 0; t < W[k + 1].getRows(); t++) {
                            acum += W[k + 1].position(t, j) * d[k + 1].position(t, 0);
                        }
                        d[k].position(j, 0, alpha * acum * func[k].getDerivate().eval(u[k].position(j, 0)));
                    }
                }

                // update weights and biases: W[l] -= d[l].Y[l-1]^T, b[l] -= d[l]
                for (int l = 1; l < layers; l++) {
                    Y[l - 1].transpose(Yt[l - 1]); // Yt[l-1] = Y[l-1]^T
                    d[l].multiply(Yt[l - 1], M[l]); // M[l] = d[l].Y[l-1]^T
                    W[l].subtract(M[l], W[l]);
                    b[l].subtract(d[l], b[l]);
                }
            }

            error /= length; // mean squared error over this epoch

            if (nn.getPlotter() != null) {
                nn.getPlotter().setError(currentEpoch, error);
            }
            if (nn.getProgressBar() != null) {
                nn.getProgressBar().setValue(currentEpoch);
            }
        }
    }
}
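
/*
 * Usage sketch (illustrative only; constructor signatures may differ between
 * libai versions): a MultiLayerPerceptron wires this trainer to its own
 * weights and buffers through initialize(...), then train(...) drives it.
 *
 *   NeuralNetwork net = new MultiLayerPerceptron(
 *           new int[]{2, 3, 1},                       // neurons per layer
 *           new Function[]{new Identity(), new Sigmoid(), new Identity()},
 *           new StandardBackpropagation());
 *   net.train(patterns, answers, 0.1, 50000, 0, patterns.length, 1.e-5);
 */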