/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.nn.supervised;
import libai.common.Pair;
import libai.common.Precondition;
import libai.common.matrix.Column;
import libai.common.matrix.Matrix;
import libai.common.functions.SymmetricSign;
import libai.common.kernels.Kernel;
import java.util.Random;
/**
* Implementation of the SVM using the SMO algorithm. Based on the original
* implementation of:<br>
* X. Jiang and H. Yu. SVM-JAVA: A Java implementation of the SMO (Sequential
* Minimal Optimization) for training SVM.<br>
* Department of Computer Science and Engineering, Pohang University of Science
* and Technology (POSTECH), http://iis.hwanjoyu.org/svm-java, 2008. The code
 * was adapted to the data structures and architecture of libai, and some
 * minor optimizations were made.
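 * <p>
 * A minimal usage sketch (the kernel instance, data arrays and parameter
 * values below are illustrative, not part of this class):
 * <pre>{@code
 * SVM svm = new SVM(kernel); // any libai Kernel implementation
 * svm.setTrainingParam(SVM.PARAM_C, 1.0);
 * // answers must be 1-row columns; their sign encodes the class.
 * // The learning-rate argument is ignored by SMO.
 * svm.train(patterns, answers, 0, 500, 0, patterns.length, 0);
 * Column out = svm.simulate(patterns[0]); // position (0, 0) holds the predicted sign
 * }</pre>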
*
* @author kronenthaler
*/
public class SVM extends SupervisedLearning {
//static defs.
public static final int PARAM_C = 0;
public static final int PARAM_EPSILON = 1;
public static final int PARAM_TOLERANCE = 2;
private static final long serialVersionUID = 5875835056527034341L;
protected static final SymmetricSign ssign = new SymmetricSign();
// learning constants
private double C = 0.05;
private double tolerance = 0.001;
private double epsilon = 0.01;
// state of the neural network
private final Kernel kernel;
	private double[] lambda; // Lagrange multipliers
	private double b = 0; // threshold
	private int[] target; // class labels (+1/-1), derived from the answers
	private Matrix[] densePoints; // equivalent to W, the matrix of prototype vectors
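	/**
	 * Constructs a SVM that uses the given kernel and the default random
	 * generator for the SMO heuristics.
	 *
	 * @param _kernel kernel function used to compare patterns.
	 */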
public SVM(Kernel _kernel) {
this(_kernel, getDefaultRandomGenerator());
}
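	/**
	 * Constructs a SVM that uses the given kernel and random generator. The
	 * random generator only picks the starting positions of the second-choice
	 * loops in SMO, so a fixed seed makes training reproducible.
	 *
	 * @param _kernel kernel function used to compare patterns.
	 * @param rand    random generator used during training.
	 */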
public SVM(Kernel _kernel, Random rand) {
super(rand);
kernel = _kernel;
}
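	/**
	 * Sets one of the training parameters: {@link #PARAM_C} (the soft-margin
	 * penalty, i.e. the upper bound of the multipliers), {@link #PARAM_EPSILON}
	 * (the numerical tolerance used when updating a pair of multipliers) or
	 * {@link #PARAM_TOLERANCE} (the tolerance used when checking the KKT
	 * conditions). Unknown parameter ids are ignored.
	 *
	 * @param param      one of PARAM_C, PARAM_EPSILON or PARAM_TOLERANCE.
	 * @param paramValue the new value for the parameter.
	 */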
public void setTrainingParam(int param, double paramValue) {
switch (param) {
case PARAM_C:
C = paramValue;
break;
//other params...
case PARAM_EPSILON:
epsilon = paramValue;
break;
case PARAM_TOLERANCE:
tolerance = paramValue;
break;
default:
break;
}
}
@Override
protected void validatePreconditions(Column[] patterns, Column[] answers, int epochs, int offset, int length, double minerror) {
super.validatePreconditions(patterns, answers, epochs, offset, length, minerror);
Precondition.check(answers[0].getRows() == 1, "Answers can only be one-dimensional elements but %d-dimensions found", answers[0].getRows());
}
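	/**
	 * Trains the SVM using the SMO algorithm. The patterns in the range
	 * [offset, offset + length) are copied internally, their answers are mapped
	 * to +1/-1 labels, and pairs of Lagrange multipliers are optimized until no
	 * multiplier changes or the number of epochs is exhausted. The learning
	 * rate {@code alpha} is not used by SMO.
	 */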
@Override
public void train(Column[] patterns, Column[] answers, double alpha, int epochs, int offset, int length, double minerror) {
validatePreconditions(patterns, answers, epochs, offset, length, minerror);
initializeProgressBar(epochs);
b = 0;
lambda = new double[length];
densePoints = new Matrix[length];
target = new int[length];
for (int i = 0; i < target.length; i++) {
densePoints[i] = new Matrix(patterns[i + offset]); //copy
target[i] = (int) ssign.eval(answers[i + offset].position(0, 0));
}
// pre-calculate the kernel values
final double[] errorCache = new double[length];
final double[][] precomputeKernels = precomputeKernels();
boolean changed = false;
boolean examineAll = true;
for (int currentEpoch = 0; currentEpoch < epochs && (changed || examineAll); currentEpoch++) {
changed = false;
for (int k = 0; k < length; k++) {
if (examineAll || (lambda[k] != 0 && lambda[k] != C)) {
changed |= examineExample(k, precomputeKernels, errorCache);
}
}
examineAll = !examineAll && !changed;
			if (plotter != null) {
				plotter.setError(currentEpoch, error(patterns, answers, offset, length));
			}
if (progress != null) {
progress.setValue(currentEpoch);
}
}
if (progress != null) {
progress.setValue(progress.getMaximum());
}
}
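	/**
	 * Classifies a pattern. The result is a 1x1 column holding the symmetric
	 * sign of the learned decision function.
	 */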
@Override
public Column simulate(Column pattern) {
final Column temp = new Column(1); //always returns a single class
simulate(pattern, temp);
return temp;
}
@Override
public void simulate(Column pattern, Column result) {
result.position(0, 0, ssign.eval(learnedFunction(pattern)));
}
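	/**
	 * Computes the classification error as the fraction of patterns, in the
	 * range [offset, offset + length), whose predicted sign disagrees with the
	 * sign of the expected answer.
	 */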
@Override
public double error(Column[] patterns, Column[] answers, int offset, int length) {
Precondition.check(patterns.length == answers.length, "There must be the same amount of patterns and answers");
		Precondition.check(offset >= 0 && offset < patterns.length, "offset must be in the interval [0, %d), found %d", patterns.length, offset);
		Precondition.check(length > 0 && length <= patterns.length - offset, "length must be in the interval (0, %d], found %d", patterns.length - offset, length);
int error = 0;
for (int i = 0; i < length; i++) {
if (simulate(patterns[i + offset]).position(0, 0) * answers[i + offset].position(0, 0) < 0) {
error++;
}
}
return error / (double) length;
}
// internal methods
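	/**
	 * Evaluates the kernel for every pair of training patterns, so the SMO
	 * iterations can reuse the values instead of recomputing them.
	 */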
	private double[][] precomputeKernels() {
		double[][] precomputedKernels = new double[densePoints.length][densePoints.length];
		for (int i = 0; i < precomputedKernels.length; i++) {
			for (int j = 0; j < precomputedKernels.length; j++) {
				precomputedKernels[i][j] = kernel.eval(densePoints[i], densePoints[j]);
			}
		}
		return precomputedKernels;
	}
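	/**
	 * Second-choice heuristic of SMO: among the non-bound multipliers
	 * (0 &lt; lambda &lt; C) picks the index whose cached error maximizes
	 * |E1 - E2|, or -1 if no suitable candidate is found.
	 */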
private int findMaxDifference(double E1, double[] errorCache) {
int i2 = -1;
double tmax = 0;
for (int k = 0; k < errorCache.length; k++) {
if (!(0 < lambda[k] && lambda[k] < C)) {
continue;
}
double E2 = errorCache[k];
double temp = Math.abs(E1 - E2);
if (temp > tmax) {
tmax = temp;
i2 = k;
}
}
return i2;
}
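	/**
	 * Checks whether example i1 violates the KKT conditions (within the
	 * configured tolerance) and, if so, tries to optimize it jointly with a
	 * second example: first the one returned by the second-choice heuristic,
	 * then the non-bound multipliers, and finally any multiplier, both scans
	 * starting at a random position.
	 *
	 * @return true if a pair of multipliers was updated.
	 */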
private boolean examineExample(int i1, double[][] precomputedKernels, double[] errorCache) {
final int y1 = target[i1];
final double alph1 = lambda[i1];
final double E1 = partialError(i1, y1, alph1, errorCache);
final double r1 = y1 * E1;
if (!(r1 < -tolerance && alph1 < C) && !(r1 > tolerance && alph1 > 0)) {
return false;
}
final int i2 = findMaxDifference(E1, errorCache);
if (i2 >= 0 && takeStep(i1, i2, precomputedKernels, errorCache)) {
return true;
}
		// then try the non-bound multipliers, starting from a random position
int k = random.nextInt(lambda.length);
for (int t = 0; t < lambda.length; k = (k + 1) % lambda.length, t++) {
if (0 < lambda[k] && lambda[k] < C && takeStep(i1, k, precomputedKernels, errorCache)) {
return true;
}
}
		// if no step was possible among the non-bound multipliers, try all of them
k = random.nextInt(lambda.length);
for (int t = 0; t < lambda.length; k = (k + 1) % lambda.length, t++) {
if (takeStep(i1, k, precomputedKernels, errorCache)) {
return true;
}
}
return false;
}
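	/**
	 * Returns the prediction error for example i: the cached value when its
	 * multiplier is non-bound, otherwise the decision function evaluated on
	 * the pattern minus its label.
	 */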
	private double partialError(int i, int y, double lambdaI, double[] errorCache) {
		if (0 < lambdaI && lambdaI < C) {
			return errorCache[i];
		}
		return learnedFunction(densePoints[i]) - y;
	}
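	/**
	 * Jointly optimizes the multipliers of examples i1 and i2 following the
	 * analytic solution of the two-variable sub-problem, clipping the new
	 * values to the feasible range, and updates the threshold and the error
	 * cache accordingly.
	 *
	 * @return true if the multipliers changed significantly.
	 */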
private boolean takeStep(int i1, int i2, double[][] precomputedKernels, double[] errorCache) {
if (i1 == i2) {
return false;
}
final double lambda1 = lambda[i1];
final int y1 = target[i1];
final double E1 = partialError(i1, y1, lambda1, errorCache);
final double lambda2 = lambda[i2];
final int y2 = target[i2];
final double E2 = partialError(i2, y2, lambda2, errorCache);
final Pair<Double, Double> range = getRange(y1, lambda1, y2, lambda2);
final double L = range.first;
final double H = range.second;
if (L == H) {
return false;
}
final double k11 = precomputedKernels[i1][i1];
final double k12 = precomputedKernels[i1][i2];
final double k22 = precomputedKernels[i2][i2];
final double eta = 2 * k12 - k11 - k22;
		// new values of lambda1 and lambda2
		double l1 = 0;
		double l2 = 0;
if (eta < 0) {
l2 = lambda2 + y2 * (E2 - E1) / eta;
if (l2 < L) {
l2 = L;
} else if (l2 > H) {
l2 = H;
}
} else {
final double c1 = eta / 2;
final double c2 = y2 * (E1 - E2) - eta * lambda2;
final double Lobj = c1 * L * L + c2 * L;
final double Hobj = c1 * H * H + c2 * H;
l2 = lambda2;
if (Lobj > Hobj + epsilon) {
l2 = L;
} else if (Lobj < Hobj - epsilon) {
l2 = H;
}
}
if (Math.abs(l2 - lambda2) < epsilon * (l2 + lambda2 + epsilon)) {
return false;
}
final double s = y1 * y2;
l1 = lambda1 - s * (l2 - lambda2);
if (l1 < 0) {
l2 += s * l1;
l1 = 0;
} else if (l1 > C) {
l2 += s * (l1 - C);
l1 = C;
}
// update threshold and multipliers
b += getDeltaB(precomputedKernels, errorCache, i1, lambda1, E1, l1, i2, lambda2, E2, l2);
lambda[i1] = l1;
lambda[i2] = l2;
return true;
}
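	/**
	 * Computes the bounds [L, H] that keep the new value of lambda2 inside the
	 * box constraint 0 &lt;= lambda &lt;= C while preserving the linear
	 * constraint between the pair of multipliers.
	 */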
private Pair<Double, Double> getRange(double y1, double lambda1, double y2, double lambda2) {
double L = 0;
double H = 0;
if (y1 == y2) {
H = lambda1 + lambda2;
L = 0;
if (H > C) {
L = H - C;
H = C;
}
} else {
			L = lambda2 - lambda1;
			H = C;
			if (L < 0) {
				H = C + L;
				L = 0;
			}
}
return new Pair<>(L, H);
}
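	/**
	 * Computes the change of the threshold b after updating the pair of
	 * multipliers, preferring the estimate associated with a non-bound
	 * multiplier, and refreshes the error cache of the non-bound examples;
	 * the cached errors of i1 and i2 are reset to zero.
	 */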
private double getDeltaB(double[][] precomputedKernels, double[] errorCache,
int i1, double lambda1, double e1, double l1,
int i2, double lambda2, double e2, double l2) {
final int y1 = target[i1];
final int y2 = target[i2];
final double k11 = precomputedKernels[i1][i1];
final double k12 = precomputedKernels[i1][i2];
final double k22 = precomputedKernels[i2][i2];
final double deltaB;
final double t1 = y1 * (l1 - lambda1);
final double t2 = y2 * (l2 - lambda2);
final double b1 = e1 + t1 * k11 + t2 * k12;
final double b2 = e2 + t1 * k12 + t2 * k22;
if (l1 > 0 && l1 < C) {
deltaB = b1;
} else if (l2 > 0 && l2 < C) {
deltaB = b2;
} else {
deltaB = (b1 + b2) / 2;
}
// update error cache
for (int i = 0; i < lambda.length; i++) {
if (0 < lambda[i] && lambda[i] < C) {
double tmp = errorCache[i];
tmp += (t1 * precomputedKernels[i1][i]) + (t2 * precomputedKernels[i2][i]) - deltaB;
errorCache[i] = tmp;
}
}
		errorCache[i1] = 0;
		errorCache[i2] = 0;
return deltaB;
}
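	/**
	 * Evaluates the learned decision function
	 * f(x) = sum_i lambda_i * y_i * K(x_i, x) - b
	 * over the patterns with a non-zero multiplier (the support vectors).
	 */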
private double learnedFunction(Matrix pattern) {
double s = 0;
for (int i = 0; i < lambda.length; i++) {
if (lambda[i] > 0) {
s += lambda[i] * target[i] * kernel.eval(densePoints[i], pattern);
}
}
s -= b;
return s;
}
}