MySQLDataSet.java
/*
* MIT License
*
* Copyright (c) 2009-2016 Ignacio Calderon <https://github.com/kronenthaler>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package libai.classifiers.dataset;
import libai.classifiers.Attribute;
import libai.common.Triplet;
import java.sql.*;
import java.util.*;
/**
* @author kronenthaler
*/
/**
 * {@link DataSet} implementation backed by a table (or view) of a MySQL database.
 * <p>
 * Every operation is translated into SQL against the backing table/view. Subsets created with
 * {@link #getSubset(int, int)} are materialized as views over the parent, and
 * {@link #splitKeepingRelation(double)} materializes two new tables.
 * </p>
 *
 * @author kronenthaler
 */
public class MySQLDataSet implements DataSet {
    /**
     * Sequence appended to subset view names so that views created within the same
     * millisecond do not collide.
     */
    private static int subsetSequence = 0;

    /** Index (0-based) of the output column. */
    private int outputIndex;
    /** Cached number of rows; -1 while not yet computed. */
    private int itemCount = -1;
    /** Name of the table or view backing this dataset. */
    private String tableName;
    /** Name of the original table; prefix of the views created by {@link #getSubset(int, int)}. */
    private String rootName;
    /** Index of the column used (before the output column) to order rows in range queries. */
    private int orderBy;
    private Connection connection;
    /** Column metadata obtained from the root table. */
    private ResultSetMetaData rsMetaData;
    /** Distinct values of the output column, computed lazily. */
    private Set<Attribute> classes = new HashSet<>();
    /**
     * Frequencies cache keyed by (lo, hi, fieldIndex). The ranges depend on {@link #orderBy},
     * so the cache is cleared whenever the ordering column changes.
     */
    private HashMap<Triplet<Integer, Integer, Integer>, HashMap<Attribute, Integer>> cacheFrequencies;
    /**
     * Metadata derived from {@link #rsMetaData}. Field indexes are 0-based, whereas JDBC
     * column indexes are 1-based (hence the {@code + 1}).
     */
    private MetaData metadata = new MetaData() {
        @Override
        public boolean isCategorical(int fieldIndex) {
            try {
                String type = rsMetaData.getColumnClassName(fieldIndex + 1);
                return "java.lang.String".equals(type);
            } catch (SQLException ex) {
                return false;
            }
        }

        @Override
        public int getAttributeCount() {
            try {
                return rsMetaData.getColumnCount();
            } catch (SQLException ex) {
                return 0;
            }
        }

        @Override
        public Set<Attribute> getClasses() {
            if (classes.isEmpty()) {
                initializeClasses();
            }
            return classes;
        }

        @Override
        public String getAttributeName(int fieldIndex) {
            try {
                return rsMetaData.getColumnName(fieldIndex + 1);
            } catch (SQLException e) {
                return "[" + fieldIndex + "]";
            }
        }
    };

    /**
     * Initializes the state shared by every constructor.
     *
     * @param output index of the output column, also used as the initial ordering column.
     */
    private MySQLDataSet(int output) {
        outputIndex = output;
        orderBy = output;
        cacheFrequencies = new HashMap<>();
    }

    /**
     * Creates a view containing the rows [lo, hi) of {@code parent}, taken in the parent's
     * current ordering (ordering column first, then output column).
     *
     * @param parent dataset to take the rows from.
     * @param lo     first row (inclusive).
     * @param hi     last row (exclusive).
     */
    private MySQLDataSet(MySQLDataSet parent, int lo, int hi) {
        this(parent.outputIndex);
        connection = parent.connection;
        this.orderBy = parent.orderBy;
        // The timestamp alone is not unique: two subsets created in the same millisecond
        // would try to create the same view.
        this.tableName = parent.rootName + System.currentTimeMillis() + "_" + nextSubsetId();
        this.rootName = parent.rootName;
        this.rsMetaData = parent.rsMetaData;
        try (PreparedStatement stmt = connection.prepareStatement(
                String.format("CREATE VIEW `%s` AS SELECT * FROM `%s` ORDER BY `%s`, `%s` LIMIT ?,?",
                        this.tableName,
                        parent.tableName,
                        parent.metadata.getAttributeName(orderBy),
                        parent.metadata.getAttributeName(outputIndex)),
                ResultSet.TYPE_SCROLL_INSENSITIVE,
                ResultSet.CONCUR_READ_ONLY,
                ResultSet.CLOSE_CURSORS_AT_COMMIT)) {
            stmt.setInt(1, lo);
            stmt.setInt(2, hi - lo);
            stmt.executeUpdate();
            initializeClasses();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * Wraps an existing table. Only the column metadata is read here; rows are queried on demand.
     *
     * @param conn      open connection to the database holding the table.
     * @param tableName name of the table to wrap.
     * @param output    index (0-based) of the output column.
     */
    public MySQLDataSet(Connection conn, String tableName, int output) {
        this(output);
        connection = conn;
        this.tableName = tableName;
        this.rootName = tableName;
        // LIMIT 0: only the column definitions are needed, not the rows.
        try (PreparedStatement stmt = conn.prepareStatement(
                String.format("SELECT * FROM `%s` LIMIT 0",
                        tableName),
                ResultSet.TYPE_SCROLL_INSENSITIVE,
                ResultSet.CONCUR_READ_ONLY,
                ResultSet.CLOSE_CURSORS_AT_COMMIT)) {
            try (ResultSet rs = stmt.executeQuery()) {
                rsMetaData = rs.getMetaData();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Returns a fresh id for naming a subset view. */
    private static synchronized int nextSubsetId() {
        return ++subsetSequence;
    }

    /**
     * Creates a subset backed by a new view over this dataset.
     *
     * @param lo first row (inclusive).
     * @param hi last row (exclusive).
     * @return the subset; call {@link #clean()} or {@link #close()} on it to drop its view.
     */
    @Override
    public DataSet getSubset(int lo, int hi) {
        return new MySQLDataSet(this, lo, hi);
    }

    @Override
    public int getOutputIndex() {
        return outputIndex;
    }

    /**
     * Counts the rows of the backing table/view. The value is cached after the first
     * successful query.
     *
     * @return the number of rows, or 0 if the query fails.
     */
    @Override
    public int getItemsCount() {
        if (itemCount >= 0) {
            return itemCount;
        }
        try (PreparedStatement stmt = connection.prepareStatement(
                String.format("SELECT COUNT(*) FROM `%s`", tableName));
             ResultSet rs = stmt.executeQuery()) {
            if (rs.next()) {
                return itemCount = rs.getInt(1);
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return 0;
    }

    @Override
    public MetaData getMetaData() {
        return metadata;
    }

    @Override
    public Iterable<List<Attribute>> sortOver(final int fieldIndex) {
        return sortOver(0, getItemsCount(), fieldIndex);
    }

    /**
     * Makes {@code fieldIndex} the ordering column and returns the rows [lo, hi) in that
     * order (ties broken by the output column). Each call to {@code iterator()} on the
     * result runs a new query.
     *
     * @param lo         first row (inclusive).
     * @param hi         last row (exclusive); an empty range yields no rows.
     * @param fieldIndex column to order by.
     * @return an iterable over the selected rows.
     */
    @Override
    public Iterable<List<Attribute>> sortOver(final int lo, final int hi, final int fieldIndex) {
        if (orderBy != fieldIndex) {
            orderBy = fieldIndex;
            // cached range frequencies were computed under the previous ordering
            cacheFrequencies.clear();
        }
        final int count = Math.max(0, hi - lo);
        return new Iterable<List<Attribute>>() {
            @Override
            public Iterator<List<Attribute>> iterator() {
                String query = String.format("SELECT * FROM `%s` ORDER BY `%s`, `%s` LIMIT %d, %d",
                        tableName,
                        metadata.getAttributeName(fieldIndex),
                        metadata.getAttributeName(outputIndex),
                        lo, count);
                return openIterator(query, count);
            }
        };
    }

    /* TODO change the implementation to return datasets from the same type */
    /**
     * Splits the rows into two new tables, {@code <table>_a_<seed>} and {@code <table>_b_<seed>},
     * preserving the proportion of every output class: for each class, a share of
     * {@code proportion} of its rows (chosen in a seeded random order) goes to the first table
     * and the remaining rows to the second.
     *
     * @param proportion fraction of each class assigned to the first dataset.
     * @return the two datasets, or {@code null} if any SQL statement fails.
     */
    @Override
    public DataSet[] splitKeepingRelation(double proportion) {
        long seed = System.currentTimeMillis();
        String nameA = tableName + "_a_" + seed;
        String nameB = tableName + "_b_" + seed;
        String outputName = metadata.getAttributeName(outputIndex);
        try {
            createEmptyCopy(nameA);
            createEmptyCopy(nameB);
            MySQLDataSet a = new MySQLDataSet(connection, nameA, outputIndex);
            MySQLDataSet b = new MySQLDataSet(connection, nameB, outputIndex);
            int total = getItemsCount();
            HashMap<Attribute, Integer> freq = getFrequencies(0, total, outputIndex);
            // The class value is bound as a parameter instead of being pasted into a literal.
            // The same RAND(seed) gives both queries the same order, so A and B are disjoint.
            String baseQuery = "INSERT INTO `%s` SELECT * FROM `%s` WHERE `%s` = ? ORDER BY RAND(%d) LIMIT ?, ?";
            try (PreparedStatement insertA = connection.prepareStatement(
                    String.format(baseQuery, nameA, tableName, outputName, seed));
                 PreparedStatement insertB = connection.prepareStatement(
                    String.format(baseQuery, nameB, tableName, outputName, seed))) {
                for (Map.Entry<Attribute, Integer> entry : freq.entrySet()) {
                    String value = String.valueOf(entry.getKey().getValue());
                    int size = (int) (entry.getValue() * proportion);

                    insertA.setString(1, value);
                    insertA.setInt(2, 0);
                    insertA.setInt(3, size);
                    insertA.executeUpdate();

                    insertB.setString(1, value);
                    insertB.setInt(2, size);
                    insertB.setInt(3, total);
                    insertB.executeUpdate();
                }
            }
            return new DataSet[]{a, b};
        } catch (SQLException ex) {
            ex.printStackTrace();
        }
        return null;
    }

    /**
     * Creates an empty table named {@code name} with the same columns as this dataset.
     *
     * @throws SQLException if the table cannot be created.
     */
    private void createEmptyCopy(String name) throws SQLException {
        try (PreparedStatement stmt = connection.prepareStatement(String.format(
                "CREATE TABLE `%s` SELECT * FROM `%s` LIMIT 0", name, tableName))) {
            stmt.executeUpdate();
        }
    }

    /**
     * Iterates over all rows in the order returned by the database.
     *
     * @return an iterator over the rows; empty if the query fails.
     */
    @Override
    public Iterator<List<Attribute>> iterator() {
        return openIterator(String.format("SELECT * FROM `%s`", tableName), getItemsCount());
    }

    /** Drops the view backing this dataset, if it is a view. */
    public void clean() {
        try (PreparedStatement stmt = connection.prepareStatement(
                String.format("DROP VIEW IF EXISTS `%s`", tableName))) {
            stmt.executeUpdate();
        } catch (SQLException ex) {
            ex.printStackTrace();
        }
    }

    /** Loads the distinct values of the output column into {@link #classes}. */
    private void initializeClasses() {
        try (PreparedStatement stmt = connection.prepareStatement(
                String.format("SELECT DISTINCT(`%s`) FROM `%s`",
                        metadata.getAttributeName(outputIndex),
                        tableName));
             ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                classes.add(Attribute.getInstance(rs.getString(1), metadata.getAttributeName(outputIndex)));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs {@code query} and wraps its result in an iterator that owns the statement.
     * The statement is NOT closed here: closing it would also close the ResultSet the
     * iterator reads from. The iterator releases both once it is exhausted.
     *
     * @param query      SELECT to run.
     * @param itemsCount maximum number of rows to return.
     * @return the iterator, or an empty iterator if the query fails.
     */
    private Iterator<List<Attribute>> openIterator(String query, int itemsCount) {
        PreparedStatement stmt = null;
        try {
            stmt = connection.prepareStatement(query);
            return buildIterator(stmt.executeQuery(), itemsCount);
        } catch (SQLException e) {
            e.printStackTrace();
            closeQuietly(stmt);
            return Collections.<List<Attribute>>emptyIterator();
        }
    }

    /** Closes {@code resource} if it is not null, reporting (not propagating) failures. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Wraps a ResultSet into an iterator of rows.
     * <p>
     * {@code hasNext()} prefetches one row, so iteration ends when {@code itemsCount}
     * rows have been produced, when the ResultSet runs out of rows, or when reading fails.
     * The ResultSet and its statement are closed as soon as the iteration ends.
     * </p>
     *
     * @param rs         open result set; its statement is closed together with it.
     * @param itemsCount maximum number of rows to produce.
     * @return an iterator that throws {@link NoSuchElementException} past the end.
     */
    private Iterator<List<Attribute>> buildIterator(final ResultSet rs, final int itemsCount) {
        final String[] fieldNames = new String[metadata.getAttributeCount()];
        for (int i = 0; i < fieldNames.length; i++) {
            fieldNames[i] = metadata.getAttributeName(i);
        }
        return new Iterator<List<Attribute>>() {
            private int remaining = itemsCount;
            private List<Attribute> pending;
            private boolean finished;

            @Override
            public boolean hasNext() {
                if (pending == null && !finished) {
                    pending = fetch();
                }
                return pending != null;
            }

            @Override
            public List<Attribute> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                List<Attribute> record = pending;
                pending = null;
                return record;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Not supported yet.");
            }

            /** Reads the next row, or releases the resources and returns null at the end. */
            private List<Attribute> fetch() {
                try {
                    if (remaining > 0 && rs.next()) {
                        remaining--;
                        List<Attribute> record = new ArrayList<>(fieldNames.length);
                        for (String fieldName : fieldNames) {
                            record.add(Attribute.getInstance(rs.getString(fieldName), fieldName));
                        }
                        if (remaining == 0) {
                            release(); // last expected row: free the cursor right away
                        }
                        return record;
                    }
                } catch (SQLException e) {
                    e.printStackTrace();
                }
                release();
                return null;
            }

            /** Closes the ResultSet and the statement that produced it. */
            private void release() {
                finished = true;
                Statement owner = null;
                try {
                    owner = rs.getStatement();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
                closeQuietly(rs);
                closeQuietly(owner);
            }
        };
    }

    @Override
    public boolean allTheSameOutput() {
        return metadata.getClasses().size() == 1;
    }

    /**
     * Checks whether every row has the same values in all non-output columns.
     *
     * @return the most frequent output value when they are all the same; {@code null} when
     *         they differ, when the dataset is empty, or when a query fails.
     */
    @Override
    public Attribute allTheSame() {
        String outputName = metadata.getAttributeName(outputIndex);
        StringBuilder attributes = new StringBuilder();
        for (int i = 0; i < metadata.getAttributeCount(); i++) {
            if (i == outputIndex) {
                continue;
            }
            if (attributes.length() > 0) {
                attributes.append(',');
            }
            attributes.append('`').append(metadata.getAttributeName(i)).append('`');
        }
        try {
            // With no input columns every row is trivially "the same".
            if (attributes.length() > 0) {
                String groupQuery = String.format("SELECT %s, count(*) as size FROM `%s` GROUP BY %s",
                        attributes, tableName, attributes);
                int firstGroupSize;
                try (PreparedStatement stmt = connection.prepareStatement(groupQuery);
                     ResultSet rs = stmt.executeQuery()) {
                    if (!rs.next()) {
                        return null;
                    }
                    firstGroupSize = rs.getInt("size");
                }
                // a single group holding every row means all inputs are identical
                if (firstGroupSize != getItemsCount()) {
                    return null;
                }
            }
            String majorityQuery = String.format(
                    "SELECT `%s`, count(*) as count FROM `%s` GROUP BY `%s` ORDER BY count DESC LIMIT 1",
                    outputName, tableName, outputName);
            try (PreparedStatement stmt = connection.prepareStatement(majorityQuery);
                 ResultSet rs = stmt.executeQuery()) {
                if (rs.next()) {
                    return Attribute.getInstance(rs.getString(1), outputName);
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Counts the occurrences of each value of a categorical column within the rows [lo, hi)
     * of the current ordering. Results are cached per (lo, hi, fieldIndex).
     *
     * @param lo         first row (inclusive).
     * @param hi         last row (exclusive).
     * @param fieldIndex categorical column to count.
     * @return map from value to number of occurrences (empty if the query fails).
     * @throws IllegalArgumentException if the column is not categorical.
     */
    @Override
    public HashMap<Attribute, Integer> getFrequencies(int lo, int hi, int fieldIndex) {
        Triplet<Integer, Integer, Integer> key = new Triplet<>(lo, hi, fieldIndex);
        HashMap<Attribute, Integer> cached = cacheFrequencies.get(key);
        if (cached != null) {
            return cached;
        }
        if (!metadata.isCategorical(fieldIndex)) {
            throw new IllegalArgumentException("The attribute must be discrete");
        }
        HashMap<Attribute, Integer> freq = new HashMap<>();
        String fieldName = metadata.getAttributeName(fieldIndex);
        String query = String.format("SELECT `%s`, count(*) as count FROM (SELECT `%s` FROM `%s` ORDER BY `%s`,`%s` LIMIT %d,%d) as tmp GROUP BY `%s`",
                fieldName, fieldName, tableName, metadata.getAttributeName(orderBy),
                metadata.getAttributeName(outputIndex), lo, Math.max(0, hi - lo), fieldName);
        try (PreparedStatement stmt = connection.prepareStatement(query);
             ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                freq.put(Attribute.getInstance(rs.getString(fieldName), fieldName), rs.getInt("count"));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        cacheFrequencies.put(key, freq);
        return freq;
    }

    /** Renders every row, one per line, in the current ordering. */
    @Override
    public String toString() {
        Iterable<List<Attribute>> r = sortOver(orderBy);
        StringBuilder str = new StringBuilder();
        for (List<Attribute> l : r) {
            str.append(l.toString()).append("\n");
        }
        return str.toString();
    }

    /**
     * Drops the backing view; if that fails, drops the backing table instead.
     * <p>
     * WARNING: for a dataset created with the public constructor the backing object is the
     * caller's own table, which this method will DROP.
     * </p>
     */
    @Override
    public void close() {
        try (PreparedStatement stmt = connection.prepareStatement(String.format("DROP VIEW `%s`", tableName))) {
            stmt.executeUpdate();
        } catch (SQLException ex) {
            try (PreparedStatement stmt = connection.prepareStatement(String.format("DROP TABLE `%s`", tableName))) {
                stmt.executeUpdate();
            } catch (SQLException ex2) {
                ex2.printStackTrace();
            }
        }
    }
}