/*
 * Decompiled with CFR 0.152.
 */
package jsat.io;

import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.io.DataWriter;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;

/**
 * Reads and writes datasets in JSAT's compact binary format.
 * <p>
 * Header layout, as produced by {@link #getWriter}:
 * <ol>
 * <li>the 7 bytes of {@link #MAGIC_NUMBER};</li>
 * <li>one byte holding the {@link DatasetTypeMarker} ordinal;</li>
 * <li>one byte holding the {@link FloatStorageMethod} ordinal;</li>
 * <li>an {@code int} count of numeric features and an {@code int} count of categorical features;</li>
 * <li>an {@code int} point count, written as {@code -1} meaning "read until end of stream";</li>
 * <li>a description (name, option count, option names) of every categorical feature.</li>
 * </ol>
 * One record per data point follows the header. For classification data the class label is
 * stored as one extra categorical feature, and for regression data the target is stored as one
 * extra numeric feature. This is what allows {@link #loadSimple(InputStream)} to read any file as
 * a plain {@link SimpleDataSet}.
 */
public class JSATData {
    /** The bytes {@code "JSAT_00"} that begin every file written by this class. */
    public static final byte[] MAGIC_NUMBER = new byte[]{74, 83, 65, 84, 95, 48, 48};
    /**
     * String encoding tag: one byte per character. Despite the name, {@code writeString} uses this
     * form for every character in the range 0x01 to 0xFF (Latin-1), not just 7-bit ASCII.
     */
    public static final byte STRING_ENCODING_ASCII = 0;
    /** String encoding tag: the string's {@code UTF-16} bytes (with byte-order mark). */
    public static final byte STRING_ENCODING_UTF_16 = 1;

    private JSATData() {
    }

    /**
     * Writes {@code dataset} to {@code outRaw}, automatically picking the most compact
     * floating-point storage that loses no information.
     *
     * @param dataset the data to write
     * @param outRaw  the destination stream; it is flushed once all points are written
     * @throws IOException if writing fails
     */
    public static <Type extends DataSet<Type>> void writeData(DataSet<Type> dataset, OutputStream outRaw) throws IOException {
        JSATData.writeData(dataset, outRaw, FloatStorageMethod.AUTO);
    }

    /**
     * Writes {@code dataset} to {@code outRaw} using the given floating-point storage method.
     * Classification labels and regression targets are written along with each point.
     *
     * @param dataset the data to write
     * @param outRaw  the destination stream; it is flushed once all points are written
     * @param fpStore how to encode floating-point values; {@link FloatStorageMethod#AUTO} is
     *                resolved against the data by {@link FloatStorageMethod#getMethod}
     * @throws IOException if writing fails
     */
    public static <Type extends DataSet<Type>> void writeData(DataSet<Type> dataset, OutputStream outRaw, FloatStorageMethod fpStore) throws IOException {
        FloatStorageMethod method = FloatStorageMethod.getMethod(dataset, fpStore);

        boolean isClassification = dataset instanceof ClassificationDataSet;
        boolean isRegression = dataset instanceof RegressionDataSet;

        DataWriter.DataSetType type;
        CategoricalData predicting = null;
        if (isClassification) {
            type = DataWriter.DataSetType.CLASSIFICATION;
            predicting = ((ClassificationDataSet) dataset).getPredicting();
        } else if (isRegression) {
            type = DataWriter.DataSetType.REGRESSION;
        } else {
            type = DataWriter.DataSetType.SIMPLE;
        }

        DataWriter dw = JSATData.getWriter(outRaw, dataset.getCategories(), dataset.getNumNumericalVars(), predicting, method, type);
        for (int i = 0; i < dataset.getSampleSize(); ++i) {
            // SIMPLE datasets have no label; the writer ignores the value for that type.
            double label = 0.0;
            if (isClassification) {
                label = ((ClassificationDataSet) dataset).getDataPointCategory(i);
            } else if (isRegression) {
                label = ((RegressionDataSet) dataset).getTargetValue(i);
            }
            dw.writePoint(dataset.getDataPoint(i), label);
        }
        dw.finish();
        outRaw.flush();
    }

    /**
     * Creates a {@link DataWriter} that emits the JSAT binary format.
     *
     * @param out        the destination stream
     * @param catInfo    descriptions of the categorical features
     * @param dim        number of numeric features
     * @param predicting the class-label description, which must be non-null for
     *                   {@code CLASSIFICATION}
     * @param fpStore    a concrete storage method; {@code AUTO} cannot encode values
     * @param type       the kind of dataset being written
     * @return a writer for the given layout
     * @throws IOException              if the header cannot be written
     * @throws IllegalArgumentException if {@code fpStore} is {@code AUTO}
     */
    public static DataWriter getWriter(OutputStream out, CategoricalData[] catInfo, int dim, final CategoricalData predicting, final FloatStorageMethod fpStore, DataWriter.DataSetType type) throws IOException {
        if (fpStore == FloatStorageMethod.AUTO) {
            // AUTO's writeFP/readFP always throw, so such a file could never be written or read back.
            throw new IllegalArgumentException("fpStore must be a concrete storage method; resolve AUTO with FloatStorageMethod.getMethod first");
        }
        return new DataWriter(out, catInfo, dim, type) {

            @Override
            protected void writeHeader(CategoricalData[] catInfo, int dim, DataWriter.DataSetType type, OutputStream out) throws IOException {
                DataOutputStream dataOut = new DataOutputStream(out);
                dataOut.write(MAGIC_NUMBER);

                // The label/target is counted as one extra feature of the matching kind.
                int numNumeric = dim;
                int numCat = catInfo.length;
                DatasetTypeMarker marker = DatasetTypeMarker.STANDARD;
                if (type == DataWriter.DataSetType.REGRESSION) {
                    ++numNumeric;
                    marker = DatasetTypeMarker.REGRESSION;
                }
                if (type == DataWriter.DataSetType.CLASSIFICATION) {
                    ++numCat;
                    marker = DatasetTypeMarker.CLASSIFICATION;
                }

                dataOut.writeByte(marker.ordinal());
                dataOut.writeByte(fpStore.ordinal());
                dataOut.writeInt(numNumeric);
                dataOut.writeInt(numCat);
                dataOut.writeInt(-1); // point count unknown up front: reader reads until end of stream

                for (CategoricalData category : catInfo) {
                    JSATData.writeCategory(category, dataOut);
                }
                if (type == DataWriter.DataSetType.CLASSIFICATION) {
                    JSATData.writeCategory(predicting, dataOut);
                }
                dataOut.flush();
            }

            /*
             * Record layout: weight, categorical ints, [class label int], sparse flag, then either
             * (count, index/value pairs) or all dense values, then for regression
             * ([index if sparse], target value).
             */
            @Override
            protected void pointToBytes(DataPoint dp, double label, ByteArrayOutputStream byteOut) {
                try {
                    DataOutputStream dataOut = new DataOutputStream(byteOut);
                    boolean regression = this.type == DataWriter.DataSetType.REGRESSION;

                    fpStore.writeFP(dp.getWeight(), dataOut);
                    for (int val : dp.getCategoricalValues()) {
                        dataOut.writeInt(val);
                    }
                    if (this.type == DataWriter.DataSetType.CLASSIFICATION) {
                        dataOut.writeInt((int) label);
                    }

                    Vec numericVals = dp.getNumericalValues();
                    boolean sparse = numericVals.isSparse();
                    dataOut.writeBoolean(sparse);
                    if (sparse) {
                        // For regression the target is appended as one more sparse entry, so count it.
                        dataOut.writeInt(regression ? numericVals.nnz() + 1 : numericVals.nnz());
                        for (IndexValue iv : numericVals) {
                            dataOut.writeInt(iv.getIndex());
                            fpStore.writeFP(iv.getValue(), dataOut);
                        }
                    } else {
                        for (int j = 0; j < numericVals.length(); ++j) {
                            fpStore.writeFP(numericVals.get(j), dataOut);
                        }
                    }

                    if (regression) {
                        if (sparse) {
                            // Index of the target "feature": one past the last real numeric feature.
                            dataOut.writeInt(numericVals.length());
                        }
                        fpStore.writeFP(label, dataOut);
                    }
                    dataOut.flush();
                } catch (IOException ex) {
                    // The sink is an in-memory ByteArrayOutputStream, so this is not expected to happen.
                    Logger.getLogger(JSATData.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        };
    }

    /**
     * Loads a dataset, returning the type recorded in the file (simple, classification or
     * regression). The stream is closed afterwards.
     *
     * @param inRaw the source stream
     * @return the loaded dataset
     * @throws IOException if reading fails or the header is corrupt
     */
    public static DataSet<?> load(InputStream inRaw) throws IOException {
        return JSATData.load(inRaw, false);
    }

    /**
     * Loads any JSAT file as a {@link SimpleDataSet}. A classification label becomes a trailing
     * categorical feature, and a regression target becomes a trailing numeric feature. The stream
     * is closed afterwards.
     *
     * @param inRaw the source stream
     * @return the loaded dataset
     * @throws IOException if reading fails or the header is corrupt
     */
    public static SimpleDataSet loadSimple(InputStream inRaw) throws IOException {
        return (SimpleDataSet) JSATData.load(inRaw, true);
    }

    /**
     * Loads a file that was written as classification data. The stream is closed afterwards.
     *
     * @param inRaw the source stream
     * @return the loaded dataset
     * @throws IOException        if reading fails or the header is corrupt
     * @throws ClassCastException if the file does not hold classification data
     */
    public static ClassificationDataSet loadClassification(InputStream inRaw) throws IOException {
        return (ClassificationDataSet) JSATData.load(inRaw);
    }

    /**
     * Loads a file that was written as regression data. The stream is closed afterwards.
     *
     * @param inRaw the source stream
     * @return the loaded dataset
     * @throws IOException        if reading fails or the header is corrupt
     * @throws ClassCastException if the file does not hold regression data
     */
    public static RegressionDataSet loadRegression(InputStream inRaw) throws IOException {
        return (RegressionDataSet) JSATData.load(inRaw);
    }

    /**
     * Loads a dataset and closes the stream, whether loading succeeds or fails.
     *
     * @param inRaw           the source stream
     * @param forceAsStandard if true, ignore the stored dataset type and return a
     *                        {@link SimpleDataSet} in which the label/target is a regular feature
     * @return the loaded dataset
     * @throws IOException if reading fails or the header is corrupt
     */
    protected static DataSet<?> load(InputStream inRaw, boolean forceAsStandard) throws IOException {
        try (DataInputStream in = new DataInputStream(inRaw)) {
            return readDataSet(in, forceAsStandard);
        }
    }

    /** Parses the header and all point records from {@code in}, without closing it. */
    private static DataSet<?> readDataSet(DataInputStream in, boolean forceAsStandard) throws IOException {
        byte[] magicNumber = new byte[MAGIC_NUMBER.length];
        in.readFully(magicNumber);
        String magic = new String(magicNumber, StandardCharsets.US_ASCII);
        if (!magic.startsWith("JSAT_")) {
            throw new RuntimeException("data does not contain magic number");
        }

        DatasetTypeMarker marker = readOrdinal(in, DatasetTypeMarker.values(), "dataset type");
        FloatStorageMethod fpStore = readOrdinal(in, FloatStorageMethod.values(), "float storage method");
        if (fpStore == FloatStorageMethod.AUTO) {
            throw new IOException("Invalid float storage method AUTO in header");
        }
        int numNumeric = in.readInt();
        int numCat = in.readInt();
        int numPoints = in.readInt();

        if (forceAsStandard) {
            // The label/target columns are then simply read as ordinary features.
            marker = DatasetTypeMarker.STANDARD;
        }
        if (marker == DatasetTypeMarker.CLASSIFICATION) {
            --numCat; // the last categorical column holds the class label
        } else if (marker == DatasetTypeMarker.REGRESSION) {
            --numNumeric; // the last numeric column holds the regression target
        }

        CategoricalData[] categories = new CategoricalData[numCat];
        for (int i = 0; i < categories.length; ++i) {
            categories[i] = readCategory(in);
        }
        CategoricalData predicting = marker == DatasetTypeMarker.CLASSIFICATION ? readCategory(in) : null;

        DataSet<?> data;
        if (marker == DatasetTypeMarker.CLASSIFICATION) {
            data = new ClassificationDataSet(numNumeric, categories, predicting);
        } else if (marker == DatasetTypeMarker.REGRESSION) {
            data = new RegressionDataSet(numNumeric, categories);
        } else {
            data = new SimpleDataSet(categories, numNumeric);
        }

        if (numPoints < 0) {
            numPoints = Integer.MAX_VALUE; // -1 means "read until end of stream"
        }
        try {
            for (int i = 0; i < numPoints; ++i) {
                double weight = fpStore.readFP(in);
                int[] catVals = new int[numCat];
                for (int j = 0; j < catVals.length; ++j) {
                    catVals[j] = in.readInt();
                }
                double target = 0.0;
                if (marker == DatasetTypeMarker.CLASSIFICATION) {
                    target = in.readInt();
                }

                Vec numericVals;
                boolean sparse = in.readBoolean();
                if (sparse) {
                    int nnz = in.readInt();
                    if (marker == DatasetTypeMarker.REGRESSION) {
                        --nnz; // the target is stored as one trailing sparse entry, read below
                    }
                    int[] indices = new int[nnz];
                    double[] values = new double[nnz];
                    for (int j = 0; j < nnz; ++j) {
                        indices[j] = in.readInt();
                        values[j] = fpStore.readFP(in);
                    }
                    numericVals = new SparseVector(indices, values, numNumeric, nnz);
                } else {
                    numericVals = new DenseVector(numNumeric);
                    for (int j = 0; j < numNumeric; ++j) {
                        numericVals.set(j, fpStore.readFP(in));
                    }
                }

                if (marker == DatasetTypeMarker.REGRESSION) {
                    if (sparse) {
                        in.readInt(); // index of the target entry; not needed
                    }
                    target = fpStore.readFP(in);
                }

                DataPoint dp = new DataPoint(numericVals, catVals, categories, weight);
                if (marker == DatasetTypeMarker.CLASSIFICATION) {
                    ((ClassificationDataSet) data).addDataPoint(dp, (int) target);
                } else if (marker == DatasetTypeMarker.REGRESSION) {
                    ((RegressionDataSet) data).addDataPoint(dp, target);
                } else {
                    ((SimpleDataSet) data).add(dp);
                }
            }
        } catch (EOFException ignored) {
            // Writers store a point count of -1, so running out of input is the normal end of the data.
        }
        return data;
    }

    /**
     * Reads one unsigned byte and maps it to an enum constant by ordinal.
     *
     * @throws IOException if the byte is not a valid ordinal for {@code values}
     */
    private static <E extends Enum<E>> E readOrdinal(DataInputStream in, E[] values, String what) throws IOException {
        int ordinal = in.readUnsignedByte();
        if (ordinal >= values.length) {
            throw new IOException("Invalid " + what + " marker " + ordinal + " in header");
        }
        return values[ordinal];
    }

    /** Writes a categorical feature's name, its option count and each option name. */
    private static void writeCategory(CategoricalData category, DataOutputStream out) throws IOException {
        JSATData.writeString(category.getCategoryName(), out);
        int numOptions = category.getNumOfCategories();
        out.writeInt(numOptions);
        for (int i = 0; i < numOptions; ++i) {
            JSATData.writeString(category.getOptionName(i), out);
        }
    }

    /** Reads a categorical feature description written by {@link #writeCategory}. */
    private static CategoricalData readCategory(DataInputStream in) throws IOException {
        String name = JSATData.readString(in);
        int numOptions = in.readInt();
        CategoricalData category = new CategoricalData(numOptions);
        category.setCategoryName(name);
        for (int j = 0; j < numOptions; ++j) {
            category.setOptionName(JSATData.readString(in), j);
        }
        return category;
    }

    /**
     * Writes an encoding tag, a length and the string's bytes. Strings whose characters all lie in
     * 0x01..0xFF use one byte per character; anything else is written as UTF-16.
     */
    private static void writeString(String s, DataOutputStream out) throws IOException {
        boolean singleByte = true;
        for (int i = 0; i < s.length() && singleByte; ++i) {
            char c = s.charAt(i);
            singleByte = c > 0 && c < 0x100;
        }
        if (singleByte) {
            out.writeByte(STRING_ENCODING_ASCII);
            out.writeInt(s.length());
            for (int i = 0; i < s.length(); ++i) {
                out.writeByte(s.charAt(i)); // low 8 bits, i.e. the whole char here
            }
        } else {
            byte[] bytes = s.getBytes(StandardCharsets.UTF_16);
            out.writeByte(STRING_ENCODING_UTF_16);
            out.writeInt(bytes.length);
            out.write(bytes);
        }
    }

    /** Reads a string written by {@link #writeString}. */
    private static String readString(DataInputStream in) throws IOException {
        byte encoding = in.readByte();
        int length = in.readInt();
        switch (encoding) {
            case STRING_ENCODING_ASCII: {
                // Must read unsigned: writeString emits chars up to 0xFF in this form, and a signed
                // readByte would turn 0x80..0xFF into negative (invalid) code points.
                StringBuilder builder = new StringBuilder();
                for (int i = 0; i < length; ++i) {
                    builder.append((char) in.readUnsignedByte());
                }
                return builder.toString();
            }
            case STRING_ENCODING_UTF_16: {
                byte[] bytes = new byte[length];
                in.readFully(bytes);
                return new String(bytes, StandardCharsets.UTF_16);
            }
            default:
                throw new RuntimeException("Unkown string encoding value " + encoding);
        }
    }

    /**
     * How floating-point values (weights, numeric features, regression targets) are encoded.
     * The ordinal is written to the file header, so constants must not be reordered.
     */
    public static enum FloatStorageMethod {
        /** Placeholder meaning "choose the best method for the data"; cannot encode values itself. */
        AUTO{

            @Override
            protected void writeFP(double value, DataOutputStream out) throws IOException {
                throw new UnsupportedOperationException("Not supported .");
            }

            @Override
            protected double readFP(DataInputStream in) throws IOException {
                throw new UnsupportedOperationException("Not supported .");
            }

            @Override
            protected boolean noLoss(double orig) {
                return true;
            }
        }
        ,
        /** 8-byte IEEE double; always lossless. */
        FP64{

            @Override
            protected void writeFP(double value, DataOutputStream out) throws IOException {
                out.writeDouble(value);
            }

            @Override
            protected double readFP(DataInputStream in) throws IOException {
                return in.readDouble();
            }

            @Override
            protected boolean noLoss(double orig) {
                return true;
            }
        }
        ,
        /** 4-byte IEEE float. */
        FP32{

            @Override
            protected void writeFP(double value, DataOutputStream out) throws IOException {
                out.writeFloat((float) value);
            }

            @Override
            protected double readFP(DataInputStream in) throws IOException {
                return in.readFloat();
            }

            @Override
            protected boolean noLoss(double orig) {
                // Infinities and NaN survive a float round trip. A subtraction-based test
                // (Inf - Inf, NaN - NaN) would wrongly reject them and force FP64.
                return Double.isNaN(orig) || (double) (float) orig == orig;
            }
        }
        ,
        /** 2-byte signed integer; only lossless for whole numbers in [-32768, 32767]. */
        SHORT{

            @Override
            protected void writeFP(double value, DataOutputStream out) throws IOException {
                out.writeShort(Math.min(Math.max((int) value, Short.MIN_VALUE), Short.MAX_VALUE));
            }

            @Override
            protected double readFP(DataInputStream in) throws IOException {
                return in.readShort();
            }

            @Override
            protected boolean noLoss(double orig) {
                return -32768.0 <= orig && orig <= 32767.0 && orig == Math.rint(orig);
            }
        }
        ,
        /** 1-byte signed integer; only lossless for whole numbers in [-128, 127]. */
        BYTE{

            @Override
            protected void writeFP(double value, DataOutputStream out) throws IOException {
                out.writeByte(Math.min(Math.max((int) value, -128), 127));
            }

            @Override
            protected double readFP(DataInputStream in) throws IOException {
                return in.readByte();
            }

            @Override
            protected boolean noLoss(double orig) {
                return -128.0 <= orig && orig <= 127.0 && orig == Math.rint(orig);
            }
        }
        ,
        /** 1-byte unsigned integer; only lossless for whole numbers in [0, 255]. */
        U_BYTE{

            @Override
            protected void writeFP(double value, DataOutputStream out) throws IOException {
                out.writeByte(Math.min(Math.max((int) value, 0), 255));
            }

            @Override
            protected double readFP(DataInputStream in) throws IOException {
                return in.readByte() & 0xFF;
            }

            @Override
            protected boolean noLoss(double orig) {
                return 0.0 <= orig && orig <= 255.0 && orig == Math.rint(orig);
            }
        };


        /** Encodes {@code value} onto {@code out}. */
        protected abstract void writeFP(double value, DataOutputStream out) throws IOException;

        /** Decodes one value written by {@link #writeFP}. */
        protected abstract double readFP(DataInputStream in) throws IOException;

        /** Returns whether {@code orig} is reproduced exactly by a write/read round trip. */
        protected abstract boolean noLoss(double orig);

        /**
         * Resolves {@link #AUTO} to the smallest method that stores every numeric value, weight and
         * (for regression) target of {@code data} without loss. Preference order is BYTE, U_BYTE,
         * SHORT, FP32, FP64. Any other method is returned unchanged.
         *
         * @param data   the dataset to inspect
         * @param method the requested method
         * @return a concrete storage method
         */
        public static <Type extends DataSet<Type>> FloatStorageMethod getMethod(DataSet<Type> data, FloatStorageMethod method) {
            if (method != AUTO) {
                return method;
            }
            EnumSet<FloatStorageMethod> candidates = EnumSet.complementOf(EnumSet.of(AUTO));
            // FP64 never loses information, so once it is the only candidate left the scan can stop.
            for (int i = 0; i < data.getSampleSize() && candidates.size() > 1; ++i) {
                DataPoint dp = data.getDataPoint(i);
                for (IndexValue iv : dp.getNumericalValues()) {
                    retainLossless(candidates, iv.getValue());
                    if (candidates.size() == 1) {
                        break;
                    }
                }
                retainLossless(candidates, dp.getWeight());
            }
            if (data instanceof RegressionDataSet) {
                for (IndexValue iv : ((RegressionDataSet) data).getTargetValues()) {
                    if (candidates.size() == 1) {
                        break;
                    }
                    retainLossless(candidates, iv.getValue());
                }
            }

            if (candidates.contains(BYTE)) {
                return BYTE;
            }
            if (candidates.contains(U_BYTE)) {
                return U_BYTE;
            }
            if (candidates.contains(SHORT)) {
                return SHORT;
            }
            if (candidates.contains(FP32)) {
                return FP32;
            }
            return FP64;
        }

        /** Removes from {@code candidates} every method that cannot store {@code value} exactly. */
        private static void retainLossless(EnumSet<FloatStorageMethod> candidates, double value) {
            Iterator<FloatStorageMethod> it = candidates.iterator();
            while (it.hasNext()) {
                if (!it.next().noLoss(value)) {
                    it.remove();
                }
            }
        }
    }

    /**
     * Kind of dataset stored in a file. The ordinal is written to the file header, so constants
     * must not be reordered.
     */
    public static enum DatasetTypeMarker {
        STANDARD,
        REGRESSION,
        CLASSIFICATION;

    }
}

