/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.classifiers.trees;
import libai.classifiers.Attribute;
import libai.classifiers.ContinuousAttribute;
import libai.classifiers.DiscreteAttribute;
import libai.classifiers.dataset.DataSet;
import libai.classifiers.dataset.MetaData;
import libai.common.Pair;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import java.io.*;
import java.util.*;
/**
* TODO: handle missing values.
*
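* <p>Usage sketch (a minimal example; the dataset and record wiring are
* assumed to exist elsewhere):
* <pre>{@code
* DataSet ds = ...; //any libai DataSet implementation
* C45 tree = C45.getInstancePrune(ds, C45.QUINLANS_PRUNE);
* List<Attribute> record = ...; //a record in the dataset's attribute order
* Attribute prediction = tree.eval(record, ds);
* }</pre>
*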
* @author kronenthaler
*/
public class C45 implements Comparable<C45> {
public static final int NO_PRUNE = 0;
public static final int QUINLANS_PRUNE = 1;
public static final int LAPLACE_PRUNE = 2;
protected Attribute output;
protected Pair<Attribute, C45>[] childs;
protected double error;
protected double backedUpError;
//prune variables
protected Attribute mostCommonLeaf;
//Laplace's error pruning
protected int mostCommonLeafFreq = Integer.MIN_VALUE;
protected int samplesCount; //how many samples pass through this node during the pruning process.
protected HashMap<Attribute, Integer> samplesFreq = new HashMap<>(); //used in the pruning process.
//Quinlan's pruning using confidence
protected double confidence = 0.25;
protected double z;
protected int good, bad;
//constructors
public C45() {
setConfidence(confidence);
}
protected C45(Attribute root) {
this();
output = root;
}
protected C45(Pair<Attribute, C45>[] c) {
this();
childs = c;
}
protected C45(ArrayList<Pair<Attribute, C45>> c) {
this();
childs = new Pair[c.size()];
for (int i = 0, n = childs.length; i < n; i++) {
childs[i] = c.get(i);
}
}
//Factories
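/**
 * Load a previously saved tree from an XML file.
 *
 * @param path the XML file created by {@link #save(File)}
 * @return the loaded tree, or null if the file cannot be parsed
 */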
public static C45 getInstance(File path) {
try (FileInputStream fis = new FileInputStream(path)){
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
DocumentBuilder db = dbf.newDocumentBuilder();
Document doc = db.parse(fis);
NodeList root = doc.getElementsByTagName("C45").item(0).getChildNodes();
for (int i = 0; i < root.getLength(); i++) {
Node current = root.item(i);
if (current.getNodeName().equals("node") ||
current.getNodeName().equals("leaf")) {
return new C45().load(current);
}
}
return null;
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
/**
 * Return an unpruned tree trained on the given dataset.
 *
 * @param ds the dataset to train on
 * @return an unpruned tree trained on the given dataset
 */
public static C45 getInstance(DataSet ds) {
return new C45().train(ds);
}
/**
 * Return a pruned tree from the given dataset, using the default confidence
 * of 25% for Quinlan's pruning.
 *
 * @param ds the dataset to train on
 * @param type the pruning type: NO_PRUNE, QUINLANS_PRUNE or LAPLACE_PRUNE
 * @return a pruned tree built from the given dataset
 */
public static C45 getInstancePrune(DataSet ds, int type) {
return new C45().train(ds).prune(ds, type);
}
/**
 * Return a tree from the given dataset, pruned with Quinlan's method using
 * the specified confidence.
 *
 * @param ds the dataset to train on
 * @param confidence the confidence level for Quinlan's pruning
 * @return a tree pruned with the specified confidence
 */
public static C45 getInstancePrune(DataSet ds, double confidence) {
C45 ret = new C45();
ret = ret.train(ds);
ret.setConfidence(confidence);
return ret.prune(ds, QUINLANS_PRUNE);
}
//Tree related
public boolean isLeaf() {
return (childs == null || childs.length == 0) && output != null;
}
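/**
 * Evaluate a record against this tree and return the predicted class.
 *
 * @param record the record to classify, in the dataset's attribute order
 * @param ds the dataset whose metadata is used to locate continuous attributes
 * @return the predicted output attribute, or null if no branch matches
 */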
public Attribute eval(List<Attribute> record, DataSet ds) {
return eval(record, false, null, ds);
}
private Attribute eval(List<Attribute> record, boolean keeptrack, Attribute expected, DataSet ds) {
if (keeptrack) {
//laplace pruning
if (samplesFreq.get(expected) == null) {
for (Attribute att : ds.getMetaData().getClasses()) {
samplesFreq.put(att, 0);
}
}
samplesFreq.put(expected, samplesFreq.get(expected) + 1);
if (mostCommonLeafFreq < samplesFreq.get(expected)) {
mostCommonLeafFreq = samplesFreq.get(expected);
mostCommonLeaf = expected;
}
samplesCount++;
}
if (isLeaf()) {
if (keeptrack) {
//quinlan pruning
if (output.compareTo(expected) == 0) {
good++;
} else {
bad++;
}
}
return output;
} else {
if (childs[0].first.isCategorical()) {
for (Pair<Attribute, C45> p : childs) {
if (record.contains(p.first)) {
return p.second.eval(record, keeptrack, expected, ds);
}
}
} else {
try {
for (int i = 0; i < ds.getMetaData().getAttributeCount(); i++) {
if (ds.getMetaData().getAttributeName(i).equals(childs[0].first.getName())) {
if (record.get(i).compareTo(childs[0].first) <= 0) {
return childs[0].second.eval(record, keeptrack, expected, ds);
} else {
return childs[1].second.eval(record, keeptrack, expected, ds);
}
}
}
} catch (Exception ex) {
ex.printStackTrace();
}
}
}
return null; //no prediction
}
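/**
 * Train an unpruned C4.5 tree over the whole dataset. The output index is
 * marked as visited so the class attribute is never used as a split candidate.
 *
 * @param ds the dataset to train on
 * @return the root of the trained tree
 */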
public C45 train(DataSet ds) {
final HashSet<Integer> visited = new HashSet<>();
visited.add(ds.getOutputIndex());
return train(ds, visited, "");
}
private C45 train(DataSet ds, HashSet<Integer> visited, String deep) {
final MetaData metadata = ds.getMetaData();
final int attributeCount = metadata.getAttributeCount();
final int outputIndex = ds.getOutputIndex();
final int itemsCount = ds.getItemsCount();
//no more attributes to visit or no more items to check
if (itemsCount == 0 || visited.size() == attributeCount) {
return null;
}
//base case: all the output are the same.
if (ds.allTheSameOutput()) {
return new C45(ds.iterator().next().get(outputIndex));
}
//base case: all the attributes are the same.
final Attribute att = ds.allTheSame();
if (att != null) {
return new C45(att);
}
//else
double max = -Double.MAX_VALUE; //smallest possible value; -Double.MIN_VALUE is only a tiny negative number.
int index = -1;
int indexOfValue = -1;
double splitValue = Double.MIN_VALUE;
for (int i = 0; i < attributeCount; i++) {
if (!visited.contains(i)) {
GainInformation gain = gain(ds, 0, itemsCount, i); //get the maximum gain ratio.
if (gain.ratio > max) {
max = gain.ratio;
index = i; //split attribute
if (!metadata.isCategorical(i)) {
splitValue = gain.splitValue;
indexOfValue = gain.indexOfSplitValue;
}
}
}
}
final Iterable<List<Attribute>> sortedRecords = ds.sortOver(index);
final ArrayList<Pair<Attribute, C45>> children = new ArrayList<>();
if (metadata.isCategorical(index)) {
visited.add(index); //mark as ready, avoid revisiting a nominal attribute.
Iterator<List<Attribute>> records = sortedRecords.iterator();
Attribute prev = records.next().get(index); //itemsCount > 0, so the first record always exists.
Attribute current;
int nlo, i = 0;
while (prev != null) {
current = null;
nlo = i;
while (records.hasNext()) {
current = records.next().get(index);
if (!prev.equals(current)) {
break;
}
i++;
}
i++;
DataSet section = ds.getSubset(nlo, i);
children.add(new Pair<>(prev, train(section, visited, deep + "\t")));
section.close();
prev = current;
}
} else {
DataSet l = ds.getSubset(0, indexOfValue);
DataSet r = ds.getSubset(indexOfValue, itemsCount);
String fieldName = ds.getMetaData().getAttributeName(index);
C45 left = train(l, visited, deep + "\t");
children.add(new Pair<>(Attribute.getInstance(splitValue, fieldName), left));
C45 right = train(r, visited, deep + "\t");
children.add(new Pair<>(Attribute.getInstance(splitValue, fieldName), right));
l.close();
r.close();
}
return new C45(children);
}
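/**
 * Compute the information gain and gain ratio of splitting over the given
 * field: gain = info(S) - infoAvg(S, field) and ratio = gain / splitInfo,
 * following the standard C4.5 criterion.
 */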
private GainInformation gain(DataSet ds, int lo, int hi, int fieldIndex) {
final DiscreteEntropyInformation info = (DiscreteEntropyInformation) infoAvg(ds, lo, hi, fieldIndex);
final double gain = info(ds, lo, hi, ds.getOutputIndex()) - info.maxInfo;
final double gainRatio = gain / info.maxSplitInfo;
final GainInformation result = new GainInformation();
result.gain = gain;
result.ratio = gainRatio;
result.maxInfo = info.maxInfo;
result.maxSplitInfo = info.maxSplitInfo;
if (info instanceof ContinuousEntropyInformation) {
result.splitValue = ((ContinuousEntropyInformation) info).splitValue;
result.indexOfSplitValue = ((ContinuousEntropyInformation) info).indexOfSplitValue;
}
return result;
}
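/** Dispatch to the discrete or continuous conditional entropy, depending on the field type. */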
private EntropyInformation infoAvg(DataSet ds, int lo, int hi, int fieldIndex) {
if (ds.getMetaData().isCategorical(fieldIndex)) {
return infoAvgDiscrete(ds, lo, hi, fieldIndex);
} else {
return infoAvgContinuous(ds, lo, hi, fieldIndex);
}
}
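/**
 * Conditional entropy of a discrete split: records are sorted over the field,
 * and for each run of equal values the weighted entropy of the output is
 * accumulated, along with the split information -p*log2(p) of the chunk.
 */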
private DiscreteEntropyInformation infoAvgDiscrete(DataSet ds, int lo, int hi, int fieldIndex) {
final DiscreteEntropyInformation info = new DiscreteEntropyInformation();
final double total = hi - lo;
Attribute prev = null, current = null;
final Iterable<List<Attribute>> sortedRecords = ds.sortOver(lo, hi, fieldIndex);
final Iterator<List<Attribute>> records = sortedRecords.iterator();
prev = records.next().get(fieldIndex);
int nlo;
int i = lo;
while (prev != null) {
current = null;
nlo = i;
while (records.hasNext()) {
current = records.next().get(fieldIndex);
if (!prev.equals(current)) {
break;
}
i++;
}
i++;
prev = current;
final double res = info(ds, nlo, i, ds.getOutputIndex());
final double chunkSize = i - nlo;
info.maxInfo += res * (chunkSize / total);
info.maxSplitInfo += -(chunkSize / total) * (Math.log10(chunkSize / total) / Math.log10(2));
}
return info;
}
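/**
 * Conditional entropy of a binary split over a continuous field. Every record
 * defines a candidate threshold; the left/right class frequencies come from
 * the accumulated frequency table, and the chosen split value is the midpoint
 * between the threshold and the next distinct value in sorted order.
 */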
private ContinuousEntropyInformation infoAvgContinuous(DataSet ds, int lo, int hi, int fieldIndex) {
final HashMap<Attribute, Integer> totalFreq = ds.getFrequencies(lo, hi, ds.getOutputIndex());
final HashMap<Double, HashMap<Attribute, Integer>> freqAcum = getAccumulatedFrequencies(ds, lo, hi, fieldIndex);
double splitInfo = 0;
double total = hi - lo;
double maxInfo = -Double.MAX_VALUE;
double maxSplitInfo = -Double.MAX_VALUE;
double bestSplitValue = Integer.MAX_VALUE;
int bestIndex = Integer.MIN_VALUE;
for (Attribute key : ds.getMetaData().getClasses()) {
if (totalFreq.get(key) == null) {
totalFreq.put(key, 0);
}
}
final Iterable<List<Attribute>> records = ds.sortOver(lo, hi, fieldIndex);
int i = lo;
for (List<Attribute> record : records) {
final double value = ((ContinuousAttribute) record.get(fieldIndex)).getValue();
final HashMap<Attribute, Integer> freq = freqAcum.get(value);
final double total2 = i - lo + 1;
double acum2 = 0;
final double total3 = total - total2;
double acum3 = 0;
for (Map.Entry<Attribute, Integer> en : freq.entrySet()) {
int f = en.getValue();
if (f > 0) {
double p = (f / total2);
acum2 += -p * (Math.log10(p) / Math.log10(2));
}
f = totalFreq.get(en.getKey()) - en.getValue();
if (f > 0) {
double p = f / total3;
acum3 += -p * (Math.log10(p) / Math.log10(2));
}
}
final double infoA = (total2 / total) * acum2;
final double infoB = (total3 / total) * acum3;
splitInfo = 0;
if ((int) total2 != 0) {
splitInfo += -(total2 / total) * (Math.log10(total2 / total) / Math.log10(2));
}
if ((int) total3 != 0) {
splitInfo += -(total3 / total) * (Math.log10(total3 / total) / Math.log10(2));
}
if (splitInfo > maxSplitInfo) {
maxInfo = infoA + infoB;
int k = i + 1;
Iterable<List<Attribute>> subSet = ds.sortOver(i + 1, hi, fieldIndex);
for (List<Attribute> subSetRecord : subSet) {
double nextValue = ((ContinuousAttribute) subSetRecord.get(fieldIndex)).getValue();
if (value != nextValue) {
bestSplitValue = (value + nextValue) / 2;
bestIndex = k;
break;
}
k++;
}
if (k == hi) {
bestSplitValue = value;
bestIndex = hi - 1;
}
maxSplitInfo = splitInfo;
}
i++;
}
ContinuousEntropyInformation result = new ContinuousEntropyInformation();
result.maxInfo = maxInfo;
result.maxSplitInfo = maxSplitInfo;
result.splitValue = bestSplitValue;
result.indexOfSplitValue = bestIndex;
return result;
}
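/**
 * Entropy of the given field over the range [lo, hi):
 * info = -sum(p_i * log2(p_i)) over the frequencies of each value.
 */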
private double info(DataSet ds, int lo, int hi, int fieldIndex) {
final HashMap<Attribute, Integer> freq = ds.getFrequencies(lo, hi, fieldIndex);
final double total = hi - lo;
double acum = 0;
for (Integer f : freq.values()) {
if (f != 0) {
double p = (f / total);
acum += -p * (Math.log10(p) / Math.log10(2));
}
}
return acum;
}
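/**
 * For each distinct value v of the continuous field, build the frequencies of
 * each output class among the records whose field value is less than or equal
 * to v, accumulating over the records in sorted order.
 */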
private HashMap<Double, HashMap<Attribute, Integer>> getAccumulatedFrequencies(DataSet ds, int lo, int hi, int fieldIndex) {
final Iterable<List<Attribute>> records = ds.sortOver(lo, hi, fieldIndex);
final DataSet aux = ds.getSubset(lo, hi);
List<Attribute> prev = null;
final HashMap<Double, HashMap<Attribute, Integer>> freqAcum = new HashMap<>();
for (List<Attribute> record : records) {
double va = ((ContinuousAttribute) record.get(fieldIndex)).getValue();
Attribute v = record.get(ds.getOutputIndex());
if (freqAcum.get(va) == null) {
freqAcum.put(va, new HashMap<>());
for (Attribute c : aux.getMetaData().getClasses()) {
if (prev == null) {
freqAcum.get(va).put(c, 0);
} else {
double pva = ((ContinuousAttribute) prev.get(fieldIndex)).getValue();
freqAcum.get(va).put(c, freqAcum.get(pva).get(c));
}
}
}
prev = record;
HashMap<Attribute, Integer> m = freqAcum.get(va);
m.put(v, m.get(v) + 1);
}
aux.close();
return freqAcum;
}
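/**
 * Resubstitution error: the fraction of records in the dataset that this tree
 * misclassifies.
 *
 * @param ds the dataset to evaluate
 * @return the misclassification rate over the dataset
 */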
public double error(DataSet ds) {
int errorCount = 0;
for (List<Attribute> record : ds) {
if ((eval(record, ds).compareTo(record.get(ds.getOutputIndex())) != 0)) {
errorCount++;
}
}
return errorCount / (double) ds.getItemsCount();
}
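/**
 * Prune this tree in place. The whole dataset is first evaluated over the
 * tree to collect per-node statistics, then the pruning criterion is applied
 * bottom-up.
 *
 * @param ds the dataset used to collect the statistics
 * @param type NO_PRUNE, QUINLANS_PRUNE or LAPLACE_PRUNE
 * @return this tree, pruned
 */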
public C45 prune(DataSet ds, int type) {
//first of all, evaluate all the data set over the tree, and keep track of the results.
final int outputIndex = ds.getOutputIndex();
for (List<Attribute> record : ds) {
eval(record, true, record.get(outputIndex), ds);
}
prune(type);
return this;
}
private void prune(int pruningType) {
if (isLeaf()) {
if (pruningType == QUINLANS_PRUNE) {
error = confidenceError(1.0 / (double) (bad + good), good / (double) (bad + good));
} else if (pruningType == LAPLACE_PRUNE) {
error = laplaceError(samplesCount, mostCommonLeafFreq, samplesFreq.size());
}
} else {
backedUpError = 0;
for (Pair<Attribute, C45> c : childs) {
c.second.prune(pruningType);
if (pruningType == QUINLANS_PRUNE) {
good += c.second.good;
bad += c.second.bad;
backedUpError += c.second.error * (c.second.good + c.second.bad);
} else if (pruningType == LAPLACE_PRUNE) {
backedUpError += c.second.error * c.second.samplesCount;
}
}
if (pruningType == QUINLANS_PRUNE) {
error = confidenceError(1.0 / (double) (bad + good), good / (double) (bad + good));
backedUpError /= (double) (good + bad);
} else if (pruningType == LAPLACE_PRUNE) {
error = laplaceError(samplesCount, mostCommonLeafFreq, samplesFreq.size());
backedUpError /= (double) samplesCount;
}
if (error < backedUpError) {
childs = null;
output = mostCommonLeaf;
}
error = Math.min(error, backedUpError);
}
}
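/**
 * Laplace error estimate: (N - n + k - 1) / (N + k), where N is the number of
 * samples reaching the node, n the frequency of its majority class, and k the
 * number of classes.
 */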
private double laplaceError(int N, int n, int k) {
return (double) (N - n + k - 1) / (double) (N + k);
}
// Quinlan's pruning functions
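/**
 * Set the confidence level for Quinlan's pruning and derive the matching z
 * value, i.e. the (1 - c) quantile of the standard normal distribution, by
 * scanning the CDF in steps of 0.01. The z value is propagated to the subtree.
 */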
public void setConfidence(double c) {
confidence = c;
final double b = 99;
final double upperLimit = doLeft(b);
for (double a = 0; a <= 3; a += 0.01) {
double sum = upperLimit - doLeft(a); //upper-tail probability P(Z > a).
sum = 1.0 - sum; //Phi(a)
if (sum >= 1.0 - c) { //z is the (1 - c) quantile of the standard normal.
z = a;
break;
}
}
setZ(z);
}
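/**
 * Standard normal CDF via its Taylor series:
 * Phi(z) = 1/2 + (1/sqrt(2*pi)) * sum((-1)^k * z^(2k+1) / (k! * 2^k * (2k+1))),
 * clamped to 0 below z = -6.5 and to 1 above z = 6.5.
 */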
private double doLeft(double z) {
if (z < -6.5) {
return 0;
}
if (z > 6.5) {
return 1;
}
long factK = 1;
double sum = 0;
double term = 1;
int k = 0;
while (Math.abs(term) > Math.exp(-23)) {
term = 0.3989422804 * Math.pow(-1, k) * Math.pow(z, k) / (2 * k + 1) / Math.pow(2, k) * Math.pow(z, k + 1) / factK;
sum += term;
k++;
factK *= k;
}
sum += 0.5; //1 / 2 is integer division and would evaluate to 0.
if (sum < 1e-9) {
sum = 0;
}
return sum;
}
private void setZ(double z) {
this.z = z;
if (!isLeaf() && childs != null) {
for (Pair<Attribute, C45> c : childs) {
c.second.z = z;
c.second.setZ(z);
}
}
}
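/**
 * Upper bound of the confidence interval for the observed proportion, as used
 * by Quinlan's pessimistic pruning:
 * e = (f + z^2/2N + z*sqrt(f/N - f^2/N + z^2/4N^2)) / (1 + z^2/N),
 * where invN = 1/N and f is the observed proportion.
 */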
private double confidenceError(double invN, double f) {
final double z2 = z * z;
final double e = (f + (z2 * invN * 0.5) + z * Math.sqrt((f * invN) - (f * f * invN) + (z2 * invN * invN * 0.25))) / (1 + (z2 * invN));
return e;
}
/**
 * Load a new C45 tree from the given XML node.
 *
 * @param root the XML node to load the tree from
 * @return a new C45 tree loaded from the XML node
 */
protected C45 load(Node root) {
if (root.getNodeName().equals("node")) {
Pair<Attribute, C45>[] childs = new Pair[Integer.parseInt(root.getAttributes().getNamedItem("splits").getTextContent())];
NodeList aux = root.getChildNodes();
int currentChild = 0;
for (int i = 0, n = aux.getLength(); i < n; i++) {
Node current = aux.item(i);
if (!current.getNodeName().equals("split")) {
continue;
}
Attribute att = Attribute.load(current);
for (; i < n; i++) {
if ((current = aux.item(i)).getNodeName().equals("leaf")
|| current.getNodeName().equals("node")) {
break;
}
}
childs[currentChild++] = new Pair<>(att, load(current));
}
return new C45(childs);
} else if (root.getNodeName().equals("leaf")) {
return new C45(Attribute.load(root));
}
return null;
}
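/**
 * Save this tree to the given file as XML, with a root element named after
 * the class.
 *
 * @param path the file to write
 * @return true if the tree was written successfully, false otherwise
 */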
public boolean save(File path) {
try (FileOutputStream fos = new FileOutputStream(path);
PrintStream out = new PrintStream(fos)) {
out.println("<?xml version=\"1.0\" encoding=\"utf-8\"?>");
out.println("<" + getClass().getSimpleName() + ">");
save(out, "\t");
out.println("</" + getClass().getSimpleName() + ">");
//the tree is saved in XML format; the streams are closed by try-with-resources.
return true;
} catch (Exception e) {
e.printStackTrace();
return false;
}
}
private void save(PrintStream out, String indent) throws IOException {
if (isLeaf()) {
out.println(indent + "<leaf type=\"" + output.getClass().getName() + "\" name=\"" + output.getName() + "\"><![CDATA[" + output.getValue() + "]]></leaf>");
} else {
out.println(indent + "<node splits=\"" + childs.length + "\">");
for (Pair<Attribute, C45> p : childs) {
out.println(indent + "\t<split type=\"" + p.first.getClass().getName() + "\" name=\"" + p.first.getName() + "\"><![CDATA[" + p.first.getValue() + "]]></split>");
p.second.save(out, indent + "\t");
}
out.println(indent + "</node>");
}
}
/**
* Print the tree over the standard output. Alias for <code>print("")</code>
*/
public void print() {
print("");
}
/**
 * Print the tree to the standard output using an initial indent string.
 * With each new level, a tab ({@code \t}) is appended to the indent string.
 *
 * @param indent Initial string for indentation.
 */
private void print(String indent) {
if (isLeaf()) {
System.out.println(indent + "[" + output + " " + samplesFreq + " e: " + error + "]");
} else {
for (Pair<Attribute, C45> p : childs) {
if (p.first.isCategorical()) {
System.out.println(indent + "[" + p.first.getName() + " = " + ((DiscreteAttribute) p.first).getValue() + " " + samplesFreq + " e: " + error + " be: " + backedUpError + "]");
} else {
System.out.println(indent + "[" + p.first.getName() + (childs[0] == p ? " <= " : " > ") + ((ContinuousAttribute) p.first).getValue() + " " + samplesFreq + " be: " + backedUpError + "]");
}
p.second.print(indent + "\t");
}
}
}
/**
 * Dummy comparison, needed only so this class can be used inside the
 * {@link Pair} structure.
 *
 * @param o the other object
 * @return always 0.
 */
@Override
public int compareTo(C45 o) {
return 0;
}
public static class EntropyInformation {
}
public static class DiscreteEntropyInformation extends EntropyInformation {
public double maxInfo;
public double maxSplitInfo;
}
public static class ContinuousEntropyInformation extends DiscreteEntropyInformation {
public double splitValue;
public int indexOfSplitValue;
}
public static class GainInformation extends ContinuousEntropyInformation {
public double gain;
public double ratio;
}
}