NaiveBayes.java
/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.classifiers.bayes;
import libai.classifiers.Attribute;
import libai.classifiers.ContinuousAttribute;
import libai.classifiers.DiscreteAttribute;
import libai.classifiers.dataset.DataSet;
import libai.classifiers.dataset.MetaData;
import libai.common.Pair;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.xml.parsers.ParserConfigurationException;
import org.xml.sax.SAXException;
/**
 * Naive Bayes classifier for records that mix categorical and continuous
 * attributes.
 *
 * <p>For every class value the model keeps one lookup table ({@code Object[]})
 * indexed by attribute position:
 * <ul>
 * <li>at {@code outputIndex}: an {@code Integer} with the number of training
 * records of that class;</li>
 * <li>for a categorical attribute: a {@code HashMap<String, Integer>} mapping
 * each observed value to its frequency within the class;</li>
 * <li>for a continuous attribute: a {@code Pair<Double, Double>} holding the
 * mean and the sample <em>variance</em> of the attribute within the class
 * (the XML attribute that stores the variance is named {@code sd}).</li>
 * </ul>
 *
 * <p>A model restored with {@link #getInstance(File)} has no {@code metadata};
 * evaluation and saving rely only on the lookup tables.
 *
 * @author kronenthaler
 */
public class NaiveBayes {
    /**
     * Variance used by the Gaussian density when the stored variance is zero,
     * negative or NaN (constant attribute, or a class with fewer than two
     * records). Positive variances are always used unchanged.
     */
    private static final double MIN_VARIANCE = 1e-9;

    protected int outputIndex;
    protected int totalCount;
    protected MetaData metadata;
    protected HashMap<Attribute, Object[]> params;

    //Factories

    /**
     * Restores a classifier previously written with {@link #save(File)}.
     *
     * @param path XML file to read.
     * @return the restored classifier, or {@code null} if the file cannot be
     * read, is not well-formed XML, or has no {@code NaiveBayes} element.
     */
    public static NaiveBayes getInstance(File path) {
        try {
            DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
            // Model files are plain element/attribute XML; enable the JAXP
            // secure-processing limits before parsing a file from disk.
            dbf.setFeature(javax.xml.XMLConstants.FEATURE_SECURE_PROCESSING, true);
            DocumentBuilder db = dbf.newDocumentBuilder();
            Document doc;
            try (FileInputStream fis = new FileInputStream(path)) {
                doc = db.parse(fis);
            }
            Node root = doc.getElementsByTagName("NaiveBayes").item(0);
            if (root == null) {
                System.err.println("No <NaiveBayes> element found in " + path);
                return null;
            }
            return new NaiveBayes().load(root);
        } catch (IOException | ParserConfigurationException | SAXException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Builds a classifier trained on the given data set.
     *
     * @param ds training data.
     * @return a new trained classifier.
     */
    public static NaiveBayes getInstance(DataSet ds) {
        return new NaiveBayes().train(ds);
    }

    /**
     * Trains this instance on {@code ds}, discarding any previous state.
     *
     * @param ds training data.
     * @return this instance.
     */
    public NaiveBayes train(DataSet ds) {
        outputIndex = ds.getOutputIndex();
        totalCount = ds.getItemsCount();
        metadata = ds.getMetaData();
        params = new HashMap<>();
        initialize();
        precalculate(ds);
        return this;
    }

    /** Creates one empty lookup table per class value reported by the metadata. */
    private void initialize() {
        int attributeCount = metadata.getAttributeCount();
        for (Attribute c : metadata.getClasses()) {
            Object[] table = new Object[attributeCount];
            for (int j = 0; j < attributeCount; j++) {
                if (j == outputIndex) {
                    table[j] = Integer.valueOf(0); // class count
                } else if (metadata.isCategorical(j)) {
                    table[j] = new HashMap<String, Integer>(); // value, count
                } else {
                    table[j] = new Pair<>(0.0, 0.0); // mean, variance
                }
            }
            params.put(c, table);
        }
    }

    /**
     * Fills the lookup tables from the training records: class counts, value
     * frequencies for categorical attributes, and mean / sample variance for
     * continuous attributes.
     */
    private void precalculate(DataSet ds) {
        for (List<Attribute> record : ds) {
            Object[] table = params.get(record.get(outputIndex));
            int j = 0;
            for (Attribute attr : record) {
                Object value = attr.getValue();
                if (j == outputIndex) {
                    // count simple frequencies of the output class
                    table[j] = ((Integer) table[j]) + 1;
                } else if (metadata.isCategorical(j)) {
                    // count frequencies of each different value of this attribute
                    @SuppressWarnings("unchecked")
                    HashMap<String, Integer> freq = (HashMap<String, Integer>) table[j];
                    String key = (String) value;
                    Integer seen = freq.get(key);
                    freq.put(key, seen == null ? 1 : seen + 1);
                } else {
                    // accumulate the sum and the sum of squares; they are
                    // turned into mean and variance once all records are read
                    @SuppressWarnings("unchecked")
                    Pair<Double, Double> acum = (Pair<Double, Double>) table[j];
                    acum.first = acum.first + (Double) value;
                    acum.second = acum.second + Math.pow((Double) value, 2);
                }
                j++;
            }
        }

        // just for the continuous attributes, finish the calculation of the
        // gaussian parameters, for each class value
        for (Object[] data : params.values()) {
            int length = (Integer) data[outputIndex];
            for (Object o : data) {
                // continuous attributes are the only entries stored as a Pair
                if (o instanceof Pair) {
                    @SuppressWarnings("unchecked")
                    Pair<Double, Double> acum = (Pair<Double, Double>) o;
                    double sum = acum.first;
                    double sumOfSquares = acum.second;
                    if (length == 0) {
                        // class without records: keep finite placeholders
                        // instead of the NaN that 0/0 would produce
                        acum.first = 0.0;
                        acum.second = 0.0;
                    } else {
                        acum.first = sum / length;
                        // the sample variance needs at least two records;
                        // clamp at 0 to absorb floating point cancellation
                        acum.second = length > 1
                                ? Math.max(0.0, (sumOfSquares - (sum * sum) / length) / (length - 1))
                                : 0.0;
                    }
                }
            }
        }
    }

    /**
     * Returns the class with the maximum (unnormalized) posterior probability
     * for the record {@code x}: P(Ci|x) > P(Cj|x), 1 <= j < m, i != j.
     *
     * @param x record to classify, indexed like the training records; the
     * entry at the output index, if present, is ignored.
     * @return the winning class attribute, or {@code null} if no class scored
     * above the initial threshold (e.g. the model has no classes).
     */
    public Attribute eval(List<Attribute> x) {
        Attribute winner = null;
        double max = -Double.MAX_VALUE;
        for (Attribute c : params.keySet()) {
            double tmp = probability(c, x);
            if (tmp > max) {
                max = tmp;
                winner = c;
            }
        }
        return winner;
    }

    //P(H|x) = P(x|H)P(H) / P(x)
    //relaxed calculation of P(H|x). the exact value is not necessary, just to know which class
    //has the highest value.
    private double probability(Attribute h, List<Attribute> x) {
        return probability(x, h) * probability(h);
    }

    /**
     * P(x|H): product of the per-attribute likelihoods of {@code x} given class
     * {@code h}. The kind of each attribute is taken from the lookup table
     * itself, so this also works on models restored from XML (which carry no
     * metadata).
     */
    private double probability(List<Attribute> x, Attribute h) {
        Object[] table = params.get(h);
        int classCount = (Integer) table[outputIndex];
        double p = 1;
        for (int k = 0, n = x.size(); k < n; k++) {
            if (k == outputIndex) {
                // this slot holds the class counter, not attribute statistics
                continue;
            }
            Attribute attr = x.get(k);
            if (table[k] instanceof Map) {
                p *= (count((DiscreteAttribute) attr, k, h) + 1) / (double) (classCount + 1);
            } else {
                p *= gaussian((ContinuousAttribute) attr, k, h);
            }
        }
        return p;
    }

    //laplace's correction. x+1 / |d|+|c|
    private double probability(Attribute h) {
        return (((Integer) params.get(h)[outputIndex]) + 1) / (double) (totalCount + params.size());
    }

    /**
     * Number of training records of class {@code h} whose k-th attribute had
     * the value of {@code xk}; 0 when that value was never seen for the class.
     */
    private int count(DiscreteAttribute xk, int k, Attribute h) {
        @SuppressWarnings("unchecked")
        HashMap<String, Integer> freq = (HashMap<String, Integer>) params.get(h)[k];
        Integer seen = freq.get(xk.getValue());
        return seen == null ? 0 : seen;
    }

    /**
     * Gaussian density of {@code xk} under the mean and variance stored for
     * attribute {@code k} and class {@code h}.
     */
    private double gaussian(ContinuousAttribute xk, int k, Attribute h) {
        @SuppressWarnings("unchecked")
        Pair<Double, Double> ps = (Pair<Double, Double>) params.get(h)[k];
        double mean = ps.first;
        double variance = ps.second;
        if (!(variance > 0)) {
            // zero, negative or NaN variance would make the density NaN
            variance = MIN_VARIANCE;
        }
        double x = xk.getValue();
        return Math.exp(-(Math.pow(x - mean, 2) / (2 * variance))) * (1 / (Math.sqrt(2 * Math.PI * variance)));
    }

    //IO functions

    /**
     * Writes the model as an XML file readable by {@link #getInstance(File)}.
     *
     * @param path destination file.
     * @return {@code true} if the model was written without errors,
     * {@code false} otherwise.
     */
    public boolean save(File path) {
        try (FileOutputStream fos = new FileOutputStream(path);
                PrintStream out = new PrintStream(fos, true, "UTF-8")) {
            String tag = getClass().getSimpleName();
            out.println("<?xml version=\"1.0\" encoding=\"utf-8\"?>");
            out.println("<" + tag + " "
                    + "outputIndex=\"" + outputIndex + "\" "
                    + "totalCount=\"" + totalCount + "\" "
                    + "attributes=\"" + attributeCount() + "\">");
            save(out, "\t");
            out.println("</" + tag + ">");
            // PrintStream never throws on write failures; ask it explicitly.
            return !out.checkError();
        } catch (IOException | RuntimeException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Number of attribute slots per class: taken from the metadata when the
     * model was trained in this process, otherwise from the lookup tables.
     */
    private int attributeCount() {
        if (metadata != null) {
            return metadata.getAttributeCount();
        }
        return params.isEmpty() ? 0 : params.values().iterator().next().length;
    }

    /** Writes one {@code <class>} element per class value. */
    private void save(PrintStream out, String indent) {
        for (Map.Entry<Attribute, Object[]> entry : params.entrySet()) {
            Attribute c = entry.getKey();
            out.println(indent + "<class>");
            out.println(indent + "\t<params type=\"" + c.getClass().getName() + "\" name=\"" + escapeAttribute(c.getName()) + "\" >" + cdata(c.getValue()) + "</params>");
            int i = 0;
            for (Object o : entry.getValue()) {
                out.println(indent + "\t<attribute index=\"" + i + "\">");
                if (o instanceof Integer) { //class count
                    out.println(indent + "\t\t<count>" + o + "</count>");
                } else if (o instanceof Pair) { //continuous parameter
                    @SuppressWarnings("unchecked")
                    Pair<Double, Double> p = (Pair<Double, Double>) o;
                    out.println(indent + "\t\t<stats mean=\"" + p.first + "\" sd=\"" + p.second + "\"/>");
                } else { //discrete parameter
                    @SuppressWarnings("unchecked")
                    HashMap<String, Integer> freq = (HashMap<String, Integer>) o;
                    for (Map.Entry<String, Integer> ent : freq.entrySet()) {
                        // the frequency goes in "count" and the value in the
                        // body, which is how getParams reads them back
                        out.println(indent + "\t\t<item count=\"" + ent.getValue() + "\">" + cdata(ent.getKey()) + "</item>");
                    }
                }
                out.println(indent + "\t</attribute>");
                i++;
            }
            out.println(indent + "</class>");
        }
    }

    /** Escapes the characters that are not allowed verbatim in an XML attribute value. */
    private static String escapeAttribute(Object raw) {
        return String.valueOf(raw)
                .replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace("\"", "&quot;");
    }

    /** Wraps a value in a CDATA section, splitting any embedded "]]>" terminator. */
    private static String cdata(Object raw) {
        return "<![CDATA[" + String.valueOf(raw).replace("]]>", "]]]]><![CDATA[>") + "]]>";
    }

    /** Reads a required integer attribute of an XML element. */
    private static int intAttribute(Node node, String name) {
        return Integer.parseInt(node.getAttributes().getNamedItem(name).getTextContent());
    }

    /** Reads a required floating point attribute of an XML element. */
    private static double doubleAttribute(Node node, String name) {
        return Double.parseDouble(node.getAttributes().getNamedItem(name).getTextContent());
    }

    /**
     * Rebuilds the lookup tables from the root element written by
     * {@link #save(File)}. The metadata is not part of the file and stays
     * unset.
     */
    private NaiveBayes load(Node root) {
        outputIndex = intAttribute(root, "outputIndex");
        totalCount = intAttribute(root, "totalCount");
        params = new HashMap<>();
        int attributeCount = intAttribute(root, "attributes");
        NodeList children = root.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node clazz = children.item(i);
            if (!clazz.getNodeName().equals("class")) {
                continue; // whitespace text nodes and anything unexpected
            }
            NodeList p = clazz.getChildNodes();
            Attribute key = null;
            for (int j = 0; j < p.getLength(); j++) {
                Node current = p.item(j);
                if (current.getNodeName().equals("params")) {
                    key = Attribute.load(current);
                } else if (current.getNodeName().equals("attribute")) {
                    int index = intAttribute(current, "index");
                    Object[] table = params.get(key);
                    if (table == null) {
                        table = new Object[attributeCount];
                        params.put(key, table);
                    }
                    table[index] = getParams(current);
                }
            }
        }
        return this;
    }

    /**
     * Decodes the content of one {@code <attribute>} element: an
     * {@code Integer} for a {@code <count>}, a mean/variance {@code Pair} for
     * a {@code <stats>}, otherwise the frequency table built from the
     * {@code <item>} children (empty if there are none).
     */
    private Object getParams(Node root) {
        NodeList children = root.getChildNodes();
        HashMap<String, Integer> freq = new HashMap<>();
        for (int i = 0; i < children.getLength(); i++) {
            Node current = children.item(i);
            String tag = current.getNodeName();
            if (tag.equals("count")) {
                return Integer.valueOf(Integer.parseInt(current.getTextContent()));
            } else if (tag.equals("stats")) {
                double mean = doubleAttribute(current, "mean");
                double sd = doubleAttribute(current, "sd");
                return new Pair<>(mean, sd);
            } else if (tag.equals("item")) {
                freq.put(current.getTextContent(), intAttribute(current, "count"));
            }
        }
        return freq;
    }
}