Hebb.java

/*
 * MIT License
 *
 * Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package libai.nn.supervised;

import libai.common.Shuffler;
import libai.common.matrix.Column;
import libai.common.matrix.Matrix;
import libai.common.functions.SymmetricSign;

import java.util.Random;

/**
 * Hebbian supervised networks are good for pattern retrieval and
 * reconstructions. These networks are only able to learn binary patterns
 * because its output function (symmetric sign). However, they can deal with
 * partially corrupted patterns and retrieve the original one without noise. The
 * Hebbian networks uses the Hebb's rule for training. The Hebb's rule is one of
 * the most important training rules in unsupervised networks. Other algorithms
 * like Kohonen uses this rule as base.
 *
 * @author kronenthaler
 */
public class Hebb extends SupervisedLearning {

    private static final long serialVersionUID = 7754681003525186940L;
    /** Output function applied element-wise to {@code W * pattern}. */
    protected static final SymmetricSign sign = new SymmetricSign();
    /**
     * Retention factor applied to the weights on every update. NOTE: despite
     * its name, this field stores {@code 1 - phi}, where {@code phi} is the
     * decay constant given to the constructor (precomputed once).
     */
    protected double phi;
    /** Weight matrix with {@code outputs} rows and {@code inputs} columns. */
    protected Matrix W;

    /**
     * Constructor. Creates a Hebbian network with the given number of inputs
     * and outputs and no decay (decay constant set to zero). Alias of
     * {@code this(inputs, outputs, 0)}.
     *
     * @param inputs Number of inputs for the network.
     * @param outputs Number of outputs for the network.
     */
    public Hebb(int inputs, int outputs) {
        this(inputs, outputs, 0);
    }

    /**
     * Constructor. Creates a Hebbian network with the given number of inputs
     * and outputs and the decay constant <code>phi</code>. If phi = 0 the
     * network doesn't forget anything; if phi = 1 the network only remembers
     * the last pattern.
     *
     * @param inputs Number of inputs for the network.
     * @param outputs Number of outputs for the network.
     * @param phi Decay constant.
     */
    public Hebb(int inputs, int outputs, double phi) {
        this(inputs, outputs, phi, getDefaultRandomGenerator());
    }

    /**
     * Constructor. Creates a Hebbian network with the given number of inputs
     * and outputs and the decay constant <code>phi</code>. If phi = 0 the
     * network doesn't forget anything; if phi = 1 the network only remembers
     * the last pattern.
     *
     * @param inputs Number of inputs for the network.
     * @param outputs Number of outputs for the network.
     * @param phi Decay constant.
     * @param rand Random generator used for creating matrices (passed to the
     * superclass; also used to shuffle patterns during training).
     */
    public Hebb(int inputs, int outputs, double phi, Random rand) {
        super(rand);
        this.phi = 1 - phi; // precalculation for the decay: store (1 - phi)
        W = new Matrix(outputs, inputs);
        W.setValue(0); // important!! the network should be initialized with 0
    }

    /**
     * Train the network using Hebb's rule with decay. For every pattern x with
     * expected answer y the weights are updated as
     * {@code W = (1 - phi) * W + alpha * y * x^T}, which reinforces the
     * connections between co-active inputs and outputs. The decay term
     * controls how much of the previous knowledge is kept on each update.
     * Patterns are visited in a shuffled order every epoch, and training
     * stops after {@code epochs} epochs or once the error is no longer greater
     * than {@code minerror}.
     *
     * @param patterns The patterns to be learned.
     * @param answers The expected answers.
     * @param alpha The learning rate.
     * @param epochs The maximum number of iterations
     * @param offset The first pattern position
     * @param length How many patterns will be used.
     * @param minerror The minimal error expected.
     */
    @Override
    public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror) {
        validatePreconditions(patterns, answers, epochs, offset, length, minerror);

        // Transpose every input once up front; the update needs the row vector x^T.
        Matrix[] patternsT = new Matrix[length];
        for (int i = 0; i < length; i++) {
            patternsT[i] = patterns[i + offset].transpose();
        }

        double error = 1;
        Shuffler shuffler = new Shuffler(length, this.random);
        initializeProgressBar(epochs);

        // Scratch buffer holding alpha * y * x^T, reused across all updates.
        Matrix temp = new Matrix(W.getRows(), W.getColumns());
        for (int currentEpoch = 0; currentEpoch < epochs && error > minerror; currentEpoch++) {
            int[] sort = shuffler.shuffle();
            for (int i = 0; i < length; i++) {
                // Supervised variant: the expected answer is used as the output Y.
                // (An unsupervised variant would compute Y = F(W x) instead.)
                Matrix Y = answers[sort[i] + offset];

                // W = (1 - phi) * W + alpha * Y * x^T   (this.phi holds 1 - phi)
                W.multiply(phi, W);
                Y.multiply(patternsT[sort[i]], temp);
                temp.multiply(alpha, temp);
                W.add(temp, W);
            }

            error = error(patterns, answers, offset, length);

            if (plotter != null) {
                plotter.setError(currentEpoch, error);
            }
            if (progress != null) {
                // Report the epoch just completed so the bar advances incrementally.
                progress.setValue(currentEpoch);
            }
        }

        if (progress != null) {
            progress.setValue(progress.getMaximum());
        }
    }

    /**
     * Calculate the output for the given pattern in a freshly allocated column
     * sized to the number of outputs (rows of {@code W}).
     *
     * @param pattern The input pattern
     * @return sign(W * pattern)
     */
    @Override
    public Column simulate(Column pattern) {
        Column ret = new Column(W.getRows()); // must match the output size
        simulate(pattern, ret);
        return ret;
    }

    /**
     * Calculate the output for the pattern and leave the result in
     * {@code result}: result = sign(W * pattern).
     *
     * @param pattern The input pattern
     * @param result The output result.
     */
    @Override
    public void simulate(Column pattern, Column result) {
        W.multiply(pattern, result);
        result.apply(sign, result);
    }

    /**
     * Return the network's weights.
     *
     * @return a new {@link Matrix} constructed from the internal weight
     * matrix, so callers do not receive the internal reference.
     */
    public Matrix getWeights() {
        return new Matrix(W);
    }
}