/*
 * Decompiled with CFR 0.152.
 */
package jsat;

import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.InPlaceTransform;
import jsat.linear.ConstantVector;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.MatrixStatistics;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/**
 * Base class for an indexed collection of {@link DataPoint}s that share numeric and
 * categorical feature metadata.
 *
 * @param <Type> the concrete data set type returned by subset/split operations
 */
public abstract class DataSet<Type extends DataSet> {
    /** Number of numeric features per data point. */
    protected int numNumerVals;
    /** Descriptions of the categorical features, one entry per categorical feature. */
    protected CategoricalData[] categories;
    /** Names of the numeric features, or {@code null} when names are not tracked. */
    protected List<String> numericalVariableNames;
    /**
     * Column vectors built by {@link #getNumericColumn(int)} / {@link #getNumericColumns(Set)},
     * keyed by numeric column index. Held softly so the GC may reclaim them under memory pressure.
     */
    protected Map<Integer, SoftReference<Vec>> columnVecCache = new HashMap<Integer, SoftReference<Vec>>();

    /**
     * Sets the (lower-cased) name of numeric feature {@code i}.
     *
     * @param name the new name; it is lower-cased before use
     * @param i the numeric feature index
     * @return {@code false} if names are not tracked, {@code i} is out of range, or the
     *     lower-cased name is already in use; {@code true} if the name was set
     */
    public boolean setNumericName(String name, int i) {
        if (this.numericalVariableNames == null || i < 0 || i >= this.getNumNumericalVars()) {
            return false;
        }
        String lowered = name.toLowerCase();
        if (this.numericalVariableNames.contains(lowered)) {
            return false;
        }
        this.numericalVariableNames.set(i, lowered);
        return true;
    }

    /**
     * Returns the name of numeric feature {@code i}.
     *
     * @param i the numeric feature index
     * @return the name, or {@code null} if names are not tracked
     * @throws IndexOutOfBoundsException if {@code i} is not a valid numeric feature index
     */
    public String getNumericName(int i) {
        if (i < this.getNumNumericalVars() && i >= 0) {
            return this.numericalVariableNames == null ? null : this.numericalVariableNames.get(i);
        }
        throw new IndexOutOfBoundsException("Can not acces variable for invalid index  " + i);
    }

    /**
     * Returns the name of categorical feature {@code i}.
     *
     * @param i the categorical feature index
     * @return the category name reported by the feature's {@link CategoricalData}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid categorical feature index
     */
    public String getCategoryName(int i) {
        if (i < this.getNumCategoricalVars() && i >= 0) {
            return this.categories[i].getCategoryName();
        }
        throw new IndexOutOfBoundsException("Can not acces variable for invalid index  " + i);
    }

    /**
     * Applies {@code dt} to every data point, replacing each point with its transformed copy.
     *
     * @param dt the transform to apply
     */
    public void applyTransform(DataTransform dt) {
        this.applyTransform(dt, false);
    }

    /**
     * Applies {@code dt} to every data point (non-mutating), using {@code ex} when it is a real
     * executor and the calling thread otherwise.
     *
     * @param dt the transform to apply
     * @param ex the executor to use; {@code null} or a {@link FakeExecutor} runs serially
     */
    public void applyTransform(DataTransform dt, ExecutorService ex) {
        if (ex == null || ex instanceof FakeExecutor) {
            this.applyTransform(dt);
        } else {
            this.applyTransform(dt, false, ex);
        }
    }

    /**
     * Applies {@code dt} to every data point on the calling thread.
     *
     * @param dt the transform to apply
     * @param mutate if {@code true} and {@code dt} is an {@link InPlaceTransform}, points are
     *     modified in place instead of being replaced
     */
    public void applyTransform(DataTransform dt, boolean mutate) {
        this.applyTransform(dt, mutate, new FakeExecutor());
    }

    /**
     * Applies {@code dt} to every data point, splitting the work into
     * {@code SystemInfo.LogicalCores} tasks submitted to {@code ex}. Task {@code k} handles
     * indices {@code k, k + cores, k + 2*cores, ...}. After all tasks finish, the column cache
     * is cleared and the feature metadata is refreshed from the first data point. If numeric
     * names are tracked, they are regenerated as {@code "TN1", "TN2", ...}.
     *
     * @param dt the transform to apply
     * @param mutate if {@code true} and {@code dt} is an {@link InPlaceTransform}, points are
     *     modified in place instead of being replaced
     * @param ex the executor to run on; {@code null} means run serially
     * @throws RuntimeException (or the original {@link Error}) if any task failed
     */
    public void applyTransform(final DataTransform dt, boolean mutate, ExecutorService ex) {
        if (ex == null) {
            ex = new FakeExecutor();
        }
        final InPlaceTransform ipt = mutate && dt instanceof InPlaceTransform ? (InPlaceTransform) dt : null;
        final int workers = SystemInfo.LogicalCores;
        final CountDownLatch latch = new CountDownLatch(workers);
        // First failure raised by any worker. Published to this thread by the latch's happens-before edge.
        final Throwable[] failure = new Throwable[1];
        for (int id = 0; id < workers; ++id) {
            final int start = id;
            ex.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = start; i < DataSet.this.getSampleSize(); i += workers) {
                            if (ipt != null) {
                                ipt.mutableTransform(DataSet.this.getDataPoint(i));
                            } else {
                                DataSet.this.setDataPoint(i, dt.transform(DataSet.this.getDataPoint(i)));
                            }
                        }
                    } catch (Throwable t) {
                        synchronized (failure) {
                            if (failure[0] == null) {
                                failure[0] = t;
                            }
                        }
                    } finally {
                        // Always count down, otherwise a failing task would leave await() blocked forever.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException ex1) {
            Logger.getLogger(DataSet.class.getName()).log(Level.SEVERE, null, ex1);
            Thread.currentThread().interrupt();
            return;
        }
        // Data may have changed (even partially), so any cached columns are stale.
        this.columnVecCache.clear();
        Throwable t = failure[0];
        if (t != null) {
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            }
            if (t instanceof Error) {
                throw (Error) t;
            }
            throw new RuntimeException(t);
        }
        if (this.getSampleSize() > 0) {
            this.numNumerVals = this.getDataPoint(0).numNumericalValues();
            this.categories = this.getDataPoint(0).getCategoricalData();
        }
        this.resetNumericNames();
    }

    /**
     * Replaces the numeric part of every data point, keeping categorical values, categorical
     * metadata and weights.
     *
     * @param newNumericFeatures the new numeric vectors, one per data point, in index order
     * @throws RuntimeException if the list size differs from {@link #getSampleSize()}
     */
    public void replaceNumericFeatures(List<Vec> newNumericFeatures) {
        if (this.getSampleSize() != newNumericFeatures.size()) {
            throw new RuntimeException("Input list does not have the same not of dataums as the dataset");
        }
        for (int i = 0; i < newNumericFeatures.size(); ++i) {
            DataPoint dp_i = this.getDataPoint(i);
            this.setDataPoint(i, new DataPoint(newNumericFeatures.get(i), dp_i.getCategoricalValues(), dp_i.getCategoricalData(), dp_i.getWeight()));
        }
        // The numeric columns changed, so cached column vectors must not be reused.
        this.columnVecCache.clear();
        if (this.getSampleSize() > 0) {
            this.numNumerVals = this.getDataPoint(0).numNumericalValues();
        }
        this.resetNumericNames();
    }

    /** Regenerates numeric feature names as {@code "TN1".."TNd"} when names are tracked. */
    private void resetNumericNames() {
        if (this.numericalVariableNames == null) {
            return;
        }
        this.numericalVariableNames.clear();
        for (int i = 0; i < this.getNumNumericalVars(); ++i) {
            this.numericalVariableNames.add("TN" + (i + 1));
        }
    }

    /**
     * Returns the data point at index {@code i}.
     *
     * @param i the data point index
     * @return the data point
     */
    public abstract DataPoint getDataPoint(int i);

    /**
     * Replaces the data point at index {@code i}.
     *
     * @param i the data point index
     * @param dp the new data point
     */
    public abstract void setDataPoint(int i, DataPoint dp);

    /**
     * Computes per-column statistics of the numeric features. NaN values are treated as missing
     * and their weight is excluded. The weight of entries the vector iterator does not report
     * (presumably the implicit zeros of sparse vectors — verify against {@code Vec}'s iterator)
     * is added as a single {@code 0.0} observation.
     *
     * @param useWeights whether to weight each point by {@link DataPoint#getWeight()} (else 1.0)
     * @return one {@link OnLineStatistics} per numeric column
     */
    public OnLineStatistics[] getOnlineColumnStats(boolean useWeights) {
        OnLineStatistics[] stats = new OnLineStatistics[this.numNumerVals];
        for (int i = 0; i < stats.length; ++i) {
            stats[i] = new OnLineStatistics();
        }
        double totalSoW = 0.0;
        double[] nanWeight = new double[this.numNumerVals];
        Iterator<DataPoint> iter = this.getDataPointIterator();
        while (iter.hasNext()) {
            DataPoint dp = iter.next();
            double weight = useWeights ? dp.getWeight() : 1.0;
            totalSoW += weight;
            for (IndexValue iv : dp.getNumericalValues()) {
                if (Double.isNaN(iv.getValue())) {
                    nanWeight[iv.getIndex()] += weight;
                    continue;
                }
                stats[iv.getIndex()].add(iv.getValue(), weight);
            }
        }
        for (int i = 0; i < stats.length; ++i) {
            // Sums are accumulated in different orders, so the difference can come out as a tiny
            // negative number. Clamp it because a weight can never be negative.
            double residual = totalSoW - stats[i].getSumOfWeights() - nanWeight[i];
            stats[i].add(0.0, Math.max(0.0, residual));
        }
        return stats;
    }

    /**
     * Returns statistics of the per-point fraction {@code nnz / numNumericalVars}.
     *
     * @return statistics over all data points of the non-zero fraction
     */
    public OnLineStatistics getOnlineDenseStats() {
        OnLineStatistics stats = new OnLineStatistics();
        double N = this.getNumNumericalVars();
        for (int i = 0; i < this.getSampleSize(); ++i) {
            stats.add((double) this.getDataPoint(i).getNumericalValues().nnz() / N);
        }
        return stats;
    }

    /**
     * Computes the column means and the diagonal of the covariance (the column variances).
     *
     * @return a two-element array: {@code [means, variances]}
     */
    public Vec[] getColumnMeanVariance() {
        int d = this.getNumNumericalVars();
        Vec means = new DenseVector(d);
        Vec variances = new DenseVector(d);
        MatrixStatistics.meanVector(means, this);
        MatrixStatistics.covarianceDiag(means, variances, this);
        return new Vec[]{means, variances};
    }

    /**
     * Returns a read-only iterator over the data points in index order. The sample size is
     * captured when the iterator is created.
     *
     * @return an iterator whose {@code remove()} is unsupported
     */
    public Iterator<DataPoint> getDataPointIterator() {
        return new Iterator<DataPoint>() {
            int cur = 0;
            final int to = DataSet.this.getSampleSize();

            @Override
            public boolean hasNext() {
                return this.cur < this.to;
            }

            @Override
            public DataPoint next() {
                if (this.cur >= this.to) {
                    throw new NoSuchElementException();
                }
                return DataSet.this.getDataPoint(this.cur++);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("This operation is not supported for DataSet");
            }
        };
    }

    /**
     * Returns the number of data points.
     *
     * @return the number of data points
     */
    public abstract int getSampleSize();

    /**
     * Returns the number of categorical features.
     *
     * @return the number of categorical features
     */
    public int getNumCategoricalVars() {
        return this.categories.length;
    }

    /**
     * Returns the number of numeric features.
     *
     * @return the number of numeric features
     */
    public int getNumNumericalVars() {
        return this.numNumerVals;
    }

    /**
     * Returns the categorical feature descriptions. The internal array is returned, not a copy.
     *
     * @return the categorical feature descriptions
     */
    public CategoricalData[] getCategories() {
        return this.categories;
    }

    /**
     * Builds a data set of the concrete type containing only the given indices.
     *
     * @param indices the data point indices to include, in order
     * @return the subset
     */
    protected abstract Type getSubset(List<Integer> indices);

    /**
     * Returns the subset of points with no missing values. A value is missing when it is a NaN
     * numeric value or a negative categorical value.
     *
     * @return the subset of complete data points
     */
    public Type getMissingDropped() {
        IntList hasNoMissing = new IntList();
        for (int i = 0; i < this.getSampleSize(); ++i) {
            DataPoint dp = this.getDataPoint(i);
            boolean missing = dp.getNumericalValues().countNaNs() > 0;
            for (int c : dp.getCategoricalValues()) {
                if (missing) {
                    break;
                }
                missing = c < 0;
            }
            if (!missing) {
                hasNoMissing.add(Integer.valueOf(i));
            }
        }
        return this.getSubset(hasNoMissing);
    }

    /**
     * Randomly partitions the data into consecutive chunks of a shuffled index order. Each chunk
     * receives roughly the given fraction of the data.
     *
     * @param rand the source of randomness for the shuffle
     * @param splits non-negative fractions, summing to at most 1 (a 0.001 tolerance is allowed)
     * @return one data set per fraction, in the order given
     * @throws IllegalArgumentException if {@code splits} is empty, a fraction is negative or
     *     NaN, or the fractions sum past the tolerance
     */
    public List<Type> randomSplit(Random rand, double ... splits) {
        if (splits.length < 1) {
            throw new IllegalArgumentException("Input array of split fractions must be non-empty");
        }
        IntList randOrder = new IntList(this.getSampleSize());
        ListUtils.addRange(randOrder, 0, this.getSampleSize(), 1);
        Collections.shuffle(randOrder, rand);
        int n = randOrder.size();
        int[] stops = new int[splits.length];
        double sum = 0.0;
        for (int i = 0; i < splits.length; ++i) {
            if (!(splits[i] >= 0.0)) {
                throw new IllegalArgumentException("Split fraction at index " + i + " must be non-negative, not " + splits[i]);
            }
            if ((sum += splits[i]) >= 1.001) {
                throw new IllegalArgumentException("Input splits sum is greater than 1 by index " + i + " reaching a sum of " + sum);
            }
            // The 0.001 tolerance can round a stop past the end, so clamp it to the list size.
            stops[i] = (int) Math.min((long) n, Math.round(sum * (double) n));
        }
        ArrayList<Type> datasets = new ArrayList<Type>(splits.length);
        int prev = 0;
        for (int stop : stops) {
            datasets.add(this.getSubset(randOrder.subList(prev, stop)));
            prev = stop;
        }
        return datasets;
    }

    /**
     * Randomly partitions the data using a freshly constructed {@link XORWOW} generator.
     *
     * @param splits the fractions for each chunk
     * @return one data set per fraction
     * @see #randomSplit(Random, double...)
     */
    public List<Type> randomSplit(double ... splits) {
        return this.randomSplit(new XORWOW(), splits);
    }

    /**
     * Splits the data into {@code folds} random, (nearly) equally sized folds.
     *
     * @param folds the number of folds, at least 1
     * @param rand the source of randomness
     * @return the folds
     * @throws IllegalArgumentException if {@code folds < 1}
     */
    public List<Type> cvSet(int folds, Random rand) {
        if (folds < 1) {
            throw new IllegalArgumentException("Number of folds must be positive, not " + folds);
        }
        double[] splits = new double[folds];
        Arrays.fill(splits, 1.0 / (double) folds);
        return this.randomSplit(rand, splits);
    }

    /**
     * Splits the data into {@code folds} random folds using a new {@link XORWOW} generator.
     *
     * @param folds the number of folds, at least 1
     * @return the folds
     */
    public List<Type> cvSet(int folds) {
        return this.cvSet(folds, new XORWOW());
    }

    /**
     * Returns a new list holding references to every data point.
     *
     * @return a new list of the data points
     */
    public List<DataPoint> getDataPoints() {
        ArrayList<DataPoint> list = new ArrayList<DataPoint>(this.getSampleSize());
        for (int i = 0; i < this.getSampleSize(); ++i) {
            list.add(this.getDataPoint(i));
        }
        return list;
    }

    /**
     * Returns a new list holding each data point's numeric vector (not copies).
     *
     * @return the numeric vectors in index order
     */
    public List<Vec> getDataVectors() {
        ArrayList<Vec> vecs = new ArrayList<Vec>(this.getSampleSize());
        for (int i = 0; i < this.getSampleSize(); ++i) {
            vecs.add(this.getDataPoint(i).getNumericalValues());
        }
        return vecs;
    }

    /**
     * Returns numeric column {@code i} as a vector over all data points. The column is sparse
     * when the mean non-zero fraction (see {@link #getSparsityStats()}) is below 0.6, and dense
     * otherwise. Results are cached by soft reference.
     *
     * @param i the numeric column index
     * @return the column vector
     * @throws IndexOutOfBoundsException if {@code i} is not a valid numeric column index
     */
    public Vec getNumericColumn(int i) {
        if (i < 0 || i >= this.getNumNumericalVars()) {
            throw new IndexOutOfBoundsException("There is no index for column " + i);
        }
        SoftReference<Vec> cachedRef = this.columnVecCache.get(i);
        Vec cached = cachedRef == null ? null : cachedRef.get();
        if (cached != null) {
            return cached;
        }
        DenseVector dv = new DenseVector(this.getSampleSize());
        for (int j = 0; j < this.getSampleSize(); ++j) {
            dv.set(j, this.getDataPoint(j).getNumericalValues().get(i));
        }
        Vec toRet = this.getSparsityStats().getMean() < 0.6 ? new SparseVector(dv) : dv;
        // Cache as Vec: toRet may be sparse or dense, so no concrete-type cast is valid here.
        this.columnVecCache.put(i, new SoftReference<Vec>(toRet));
        return toRet;
    }

    /**
     * Counts missing values over all data points. A value is missing when it is a NaN numeric
     * value or a negative categorical value.
     *
     * @return the total number of missing values
     */
    public long countMissingValues() {
        long missing = 0L;
        for (int i = 0; i < this.getSampleSize(); ++i) {
            DataPoint dp = this.getDataPoint(i);
            missing += (long) dp.getNumericalValues().countNaNs();
            for (int c : dp.getCategoricalValues()) {
                if (c < 0) {
                    ++missing;
                }
            }
        }
        return missing;
    }

    /**
     * Returns every numeric column as a vector.
     *
     * @return one vector per numeric column
     * @see #getNumericColumns(Set)
     */
    public Vec[] getNumericColumns() {
        return this.getNumericColumns(Collections.<Integer>emptySet());
    }

    /**
     * Returns the numeric columns as vectors, reusing cached columns where still available.
     * Skipped columns are {@code null} in the result. New columns are sparse when the mean
     * non-zero fraction is below 0.6, and dense otherwise.
     *
     * @param skipColumns column indices to leave out
     * @return one entry per numeric column ({@code null} for skipped columns)
     */
    public Vec[] getNumericColumns(Set<Integer> skipColumns) {
        boolean sparse = this.getSparsityStats().getMean() < 0.6;
        int n = this.getSampleSize();
        Vec[] cols = new Vec[this.getNumNumericalVars()];
        // true for columns taken from the cache, which are already filled and must not be written
        boolean[] dontSet = new boolean[cols.length];
        for (int i = 0; i < cols.length; ++i) {
            if (skipColumns.contains(i)) {
                continue;
            }
            SoftReference<Vec> cachedRef = this.columnVecCache.get(i);
            Vec cached = cachedRef == null ? null : cachedRef.get();
            if (cached != null) {
                cols[i] = cached;
                dontSet[i] = true;
                continue;
            }
            cols[i] = sparse ? new SparseVector(n) : new DenseVector(n);
            // Cache as Vec: the column may be sparse or dense, so no concrete-type cast is valid.
            this.columnVecCache.put(i, new SoftReference<Vec>(cols[i]));
        }
        for (int i = 0; i < n; ++i) {
            for (IndexValue iv : this.getDataPoint(i).getNumericalValues()) {
                int col = iv.getIndex();
                if (cols[col] == null || dontSet[col]) {
                    continue;
                }
                cols[col].set(i, iv.getValue());
            }
        }
        return cols;
    }

    /**
     * Copies the numeric features into a new dense {@code sampleSize x numNumericalVars} matrix.
     *
     * @return a new dense matrix, one row per data point
     */
    public Matrix getDataMatrix() {
        DenseMatrix matrix = new DenseMatrix(this.getSampleSize(), this.getNumNumericalVars());
        for (int i = 0; i < this.getSampleSize(); ++i) {
            Vec row = this.getDataPoint(i).getNumericalValues();
            for (int j = 0; j < row.length(); ++j) {
                matrix.set(i, j, row.get(j));
            }
        }
        return matrix;
    }

    /**
     * Returns a matrix backed by the data points' numeric vectors (no copying).
     *
     * @return a {@link MatrixOfVecs} over {@link #getDataVectors()}
     */
    public Matrix getDataMatrixView() {
        return new MatrixOfVecs(this.getDataVectors());
    }

    /**
     * Returns the total number of features.
     *
     * @return categorical plus numeric feature count
     */
    public int getNumFeatures() {
        return this.getNumCategoricalVars() + this.getNumNumericalVars();
    }

    /**
     * Returns a shallow clone of this data set.
     *
     * @return the shallow clone
     */
    public abstract DataSet<Type> shallowClone();

    /**
     * Returns a shallow clone in which every data point is a new {@link DataPoint} object. The
     * copies share this data set's numeric vectors, categorical values and categorical metadata,
     * and start with the same weights. Weights can then be changed on the clone without touching
     * the original points.
     *
     * @return the clone with fresh data point objects
     */
    public DataSet getTwiceShallowClone() {
        DataSet<Type> clone = this.shallowClone();
        for (int i = 0; i < clone.getSampleSize(); ++i) {
            DataPoint d = this.getDataPoint(i);
            // Carry the weight over; the 3-arg constructor would silently drop it.
            DataPoint sd = new DataPoint(d.getNumericalValues(), d.getCategoricalValues(), d.getCategoricalData(), d.getWeight());
            clone.setDataPoint(i, sd);
        }
        return clone;
    }

    /**
     * Returns statistics of each point's non-zero fraction. Sparse vectors contribute
     * {@code nnz / length}; non-sparse vectors contribute 1.0.
     *
     * @return the sparsity statistics over all data points
     */
    public OnLineStatistics getSparsityStats() {
        OnLineStatistics stats = new OnLineStatistics();
        for (int i = 0; i < this.getSampleSize(); ++i) {
            Vec v = this.getDataPoint(i).getNumericalValues();
            stats.add(v.isSparse() ? (double) v.nnz() / (double) v.length() : 1.0);
        }
        return stats;
    }

    /**
     * Returns every data point's weight.
     *
     * @return a {@link ConstantVector} if all weights are equal, otherwise a {@link DenseVector};
     *     an empty {@link DenseVector} for an empty data set
     */
    public Vec getDataWeights() {
        int N = this.getSampleSize();
        if (N == 0) {
            return new DenseVector(0);
        }
        double weight = this.getDataPoint(0).getWeight();
        // Allocated lazily, only once a weight differing from the first is seen.
        double[] weights = null;
        for (int i = 1; i < N; ++i) {
            double w_i = this.getDataPoint(i).getWeight();
            if (weights == null && weight == w_i) {
                continue;
            }
            if (weights == null) {
                weights = new double[N];
                Arrays.fill(weights, 0, i, weight);
            }
            weights[i] = w_i;
        }
        if (weights == null) {
            return new ConstantVector(weight, N);
        }
        return new DenseVector(weights);
    }
}

