/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreeLearner;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * A single extremely randomized decision tree ("Extra Tree") usable for both
 * classification and regression.
 * <p>
 * At each node, up to {@link #getSelectionCount() selectionCount} randomly ordered
 * features that are not constant within the node's data are evaluated:
 * <ul>
 * <li>numeric features are split at a threshold drawn uniformly between the
 * observed minimum and maximum;</li>
 * <li>categorical features are split either into two random groups of the observed
 * categories (when {@link #isBinaryCategoricalSplitting() binary splitting} is
 * enabled, or always when the feature has exactly two categories) or into one branch
 * per category.</li>
 * </ul>
 * The candidate with the largest gain is kept. Classification gain uses the
 * configured {@link ImpurityScore.ImpurityMeasure}; regression gain is the weighted
 * relative reduction in target variance.
 * <p>
 * A node becomes a leaf when it holds fewer than {@link #getStopSize() stopSize}
 * points, when it is already pure (zero impurity / zero variance), or when no
 * evaluated feature can split it.
 */
public class ExtraTree implements Classifier, Regressor, TreeLearner, Parameterized {
    private static final long serialVersionUID = 7433728970041876327L;
    /** Subsets with fewer than this many points become leaves. */
    private int stopSize;
    /** Maximum number of non-constant features evaluated per node. */
    private int selectionCount;
    /** Description of the target variable, captured by classification training. */
    private CategoricalData predicting;
    private boolean binaryCategoricalSplitting = true;
    /**
     * Number of numeric features in the training data; categorical feature indices
     * reported by {@code featuresUsed()} are offset by this amount.
     */
    private int numNumericFeatures;
    private ImpurityScore.ImpurityMeasure impMeasure = ImpurityScore.ImpurityMeasure.NMI;
    private TreeNodeVisitor root;

    /**
     * Creates a tree that evaluates every feature at each node and stops splitting
     * subsets with fewer than 5 points.
     */
    public ExtraTree() {
        this(Integer.MAX_VALUE, 5);
    }

    /**
     * @param selectionCount the maximum number of non-constant features evaluated per node
     * @param stopSize subsets with fewer than this many points become leaves
     * @throws ArithmeticException if {@code stopSize} is not positive (same check as
     *         {@link #setStopSize(int)})
     */
    public ExtraTree(int selectionCount, int stopSize) {
        if (stopSize <= 0) {
            throw new ArithmeticException("The stopping size must be a positive value");
        }
        this.stopSize = stopSize;
        this.selectionCount = selectionCount;
    }

    /** Sets the impurity measure used to score classification splits. */
    public void setImpurityMeasure(ImpurityScore.ImpurityMeasure impurityMeasure) {
        this.impMeasure = impurityMeasure;
    }

    /** @return the impurity measure used to score classification splits */
    public ImpurityScore.ImpurityMeasure getImpurityMeasure() {
        return this.impMeasure;
    }

    /**
     * Sets the minimum subset size that may still be split.
     *
     * @throws ArithmeticException if {@code stopSize} is not positive
     */
    public void setStopSize(int stopSize) {
        if (stopSize <= 0) {
            throw new ArithmeticException("The stopping size must be a positive value");
        }
        this.stopSize = stopSize;
    }

    /** @return the minimum subset size that may still be split */
    public int getStopSize() {
        return this.stopSize;
    }

    /**
     * Sets the maximum number of non-constant features evaluated at each node. A value
     * of zero or less means no feature is evaluated, so the tree is a single leaf.
     */
    public void setSelectionCount(int selectionCount) {
        this.selectionCount = selectionCount;
    }

    /** @return the maximum number of non-constant features evaluated at each node */
    public int getSelectionCount() {
        return this.selectionCount;
    }

    /**
     * Chooses between binary splits (a random subset of categories goes left) and
     * multi-way splits (one branch per category) for categorical features.
     */
    public void setBinaryCategoricalSplitting(boolean binaryCategoricalSplitting) {
        this.binaryCategoricalSplitting = binaryCategoricalSplitting;
    }

    /** @return {@code true} if categorical features are split in two groups */
    public boolean isBinaryCategoricalSplitting() {
        return this.binaryCategoricalSplitting;
    }

    /** Classifies {@code data}; the tree must have been trained with {@code trainC} first. */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.root.classify(data);
    }

    /** Trains sequentially; {@code threadPool} is ignored. */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /** Builds a classification tree from {@code dataSet}, replacing any previous model. */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        Random rand = RandomUtil.getRandom();
        Stack<List<DataPointPair<Integer>>> reusableLists = new Stack<List<DataPointPair<Integer>>>();
        IntList features = new IntList(dataSet.getNumFeatures());
        ListUtils.addRange(features, 0, dataSet.getNumFeatures(), 1);
        List<DataPointPair<Integer>> data = dataSet.getAsDPPList();
        this.predicting = dataSet.getPredicting();
        ImpurityScore score = new ImpurityScore(this.predicting.getNumOfCategories(), this.impMeasure);
        for (DataPointPair<Integer> dpp : data) {
            score.addPoint(dpp.getDataPoint(), (int) dpp.getPair());
        }
        this.numNumericFeatures = dataSet.getNumNumericalVars();
        this.root = this.trainC(score, data, features, dataSet.getCategories(), rand, reusableLists);
    }

    /**
     * Recursively grows a classification subtree.
     * <p>
     * Feature indices below {@code catInfo.length} are categorical; the rest are
     * numeric, offset by {@code catInfo.length}. {@code subSet} is cleared and
     * recycled into {@code reusableLists} once its split is chosen. {@code features}
     * is shared by the whole recursion: it is reordered here, and a categorical
     * feature consumed by a multi-way split is removed only while its own children are
     * built.
     *
     * @return the subtree, or {@code null} for an empty subset (a disabled path)
     */
    private TreeNodeVisitor trainC(ImpurityScore setScore, List<DataPointPair<Integer>> subSet, List<Integer> features, CategoricalData[] catInfo, Random rand, Stack<List<DataPointPair<Integer>>> reusableLists) {
        if (subSet.size() < this.stopSize || setScore.getScore() == 0.0) {
            if (subSet.isEmpty()) {
                return null;
            }
            return new NodeC(setScore.getResults());
        }
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        int bestAttribute = -1;
        ImpurityScore[] bestScores = null;
        List<List<DataPointPair<Integer>>> bestSplit = null;
        IntSet bestLeftSide = null;

        Collections.shuffle(features, rand);
        int evaluated = 0;
        for (int i = 0; i < features.size() && evaluated < this.selectionCount; ++i) {
            int a = features.get(i);
            double threshold = Double.NaN;
            IntSet leftSide = null;
            ImpurityScore[] scores;
            List<List<DataPointPair<Integer>>> aSplit;
            if (a < catInfo.length) {
                int vals = catInfo[a].getNumOfCategories();
                IntSet valuesInUse = distinctCategories(subSet, a, vals);
                if (valuesInUse.size() < 2) {
                    // Constant within this subset: no split on it can separate anything.
                    continue;
                }
                if (this.binaryCategoricalSplitting || vals == 2) {
                    leftSide = randomLeftSide(valuesInUse, vals, rand);
                    scores = this.createScores(2);
                    aSplit = newSplit(2, reusableLists);
                    for (DataPointPair<Integer> dpp : subSet) {
                        DataPoint dp = dpp.getDataPoint();
                        int dest = leftSide.contains(Integer.valueOf(dp.getCategoricalValue(a))) ? 0 : 1;
                        scores[dest].addPoint(dp, (int) dpp.getPair());
                        aSplit.get(dest).add(dpp);
                    }
                } else {
                    scores = this.createScores(vals);
                    aSplit = newSplit(vals, reusableLists);
                    for (DataPointPair<Integer> dpp : subSet) {
                        DataPoint dp = dpp.getDataPoint();
                        int category = dp.getCategoricalValue(a);
                        scores[category].addPoint(dp, (int) dpp.getPair());
                        aSplit.get(category).add(dpp);
                    }
                }
            } else {
                int numerA = a - catInfo.length;
                double[] range = numericRange(subSet, numerA);
                if (!(range[0] < range[1])) {
                    // Constant (or NaN-containing) feature: every point would land on one
                    // side, so the child would see the same subset again.
                    continue;
                }
                threshold = rand.nextDouble() * (range[1] - range[0]) + range[0];
                scores = this.createScores(2);
                aSplit = newSplit(2, reusableLists);
                for (DataPointPair<Integer> dpp : subSet) {
                    int dest = dpp.getVector().get(numerA) <= threshold ? 0 : 1;
                    aSplit.get(dest).add(dpp);
                    scores[dest].addPoint(dpp.getDataPoint(), (int) dpp.getPair());
                }
            }
            ++evaluated;
            double gain = ImpurityScore.gain(setScore, scores);
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = a;
                bestThreshold = threshold;
                bestScores = scores;
                if (bestSplit != null) {
                    fillStack(reusableLists, bestSplit);
                }
                bestSplit = aSplit;
                bestLeftSide = leftSide;
            } else {
                fillStack(reusableLists, aSplit);
            }
        }
        // The parent's list is no longer needed; recycle it for deeper nodes.
        fillStack(reusableLists, Collections.singletonList(subSet));
        if (bestAttribute < 0) {
            // No feature could split this subset (or none were evaluated).
            return new NodeC(setScore.getResults());
        }

        NodeC toReturn;
        boolean featureRemoved = false;
        if (bestAttribute < catInfo.length) {
            if (bestLeftSide != null) {
                toReturn = new NodeCCat(bestAttribute, bestLeftSide, setScore.getResults(), this.numNumericFeatures);
            } else {
                toReturn = new NodeCCat(bestAttribute, bestSplit.size(), setScore.getResults(), this.numNumericFeatures);
                // Each child holds a single category, so the feature is useless below here.
                features.remove(Integer.valueOf(bestAttribute));
                featureRemoved = true;
            }
        } else {
            toReturn = new NodeCNum(bestAttribute - catInfo.length, bestThreshold, setScore.getResults());
        }
        for (int i = 0; i < toReturn.children.length; ++i) {
            toReturn.children[i] = this.trainC(bestScores[i], bestSplit.get(i), features, catInfo, rand, reusableLists);
        }
        if (featureRemoved) {
            // Other subtrees may still split on this feature.
            features.add(bestAttribute);
        }
        return toReturn;
    }

    /**
     * Recursively grows a regression subtree. Same conventions as the classification
     * {@code trainC} helper; gain is {@code 1 - sum(w_i / W * var_i / var)}.
     * <p>
     * NOTE(review): an empty subset (possible with multi-way categorical splits)
     * becomes a leaf predicting the mean of empty statistics; confirm what
     * {@code OnLineStatistics.getMean()} returns in that case.
     */
    private TreeNodeVisitor train(OnLineStatistics setScore, List<DataPointPair<Double>> subSet, List<Integer> features, CategoricalData[] catInfo, Random rand, Stack<List<DataPointPair<Double>>> reusableLists) {
        double setVariance = setScore.getVarance();
        if (subSet.size() < this.stopSize || setVariance <= 0.0 || Double.isNaN(setVariance)) {
            return new NodeR(setScore.getMean());
        }
        double totalWeight = setScore.getSumOfWeights();
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        int bestAttribute = -1;
        OnLineStatistics[] bestScores = null;
        List<List<DataPointPair<Double>>> bestSplit = null;
        IntSet bestLeftSide = null;

        Collections.shuffle(features, rand);
        int evaluated = 0;
        for (int i = 0; i < features.size() && evaluated < this.selectionCount; ++i) {
            int a = features.get(i);
            double threshold = Double.NaN;
            IntSet leftSide = null;
            OnLineStatistics[] stats;
            List<List<DataPointPair<Double>>> aSplit;
            if (a < catInfo.length) {
                int vals = catInfo[a].getNumOfCategories();
                IntSet valuesInUse = distinctCategories(subSet, a, vals);
                if (valuesInUse.size() < 2) {
                    // Constant within this subset: no split on it can separate anything.
                    continue;
                }
                if (this.binaryCategoricalSplitting || vals == 2) {
                    leftSide = randomLeftSide(valuesInUse, vals, rand);
                    stats = this.createStats(2);
                    aSplit = newSplit(2, reusableLists);
                    for (DataPointPair<Double> dpp : subSet) {
                        DataPoint dp = dpp.getDataPoint();
                        int dest = leftSide.contains(Integer.valueOf(dp.getCategoricalValue(a))) ? 0 : 1;
                        stats[dest].add(dpp.getPair(), dp.getWeight());
                        aSplit.get(dest).add(dpp);
                    }
                } else {
                    stats = this.createStats(vals);
                    aSplit = newSplit(vals, reusableLists);
                    for (DataPointPair<Double> dpp : subSet) {
                        DataPoint dp = dpp.getDataPoint();
                        int category = dp.getCategoricalValue(a);
                        stats[category].add(dpp.getPair(), dp.getWeight());
                        aSplit.get(category).add(dpp);
                    }
                }
            } else {
                int numerA = a - catInfo.length;
                double[] range = numericRange(subSet, numerA);
                if (!(range[0] < range[1])) {
                    // Constant (or NaN-containing) feature: would produce a one-sided split.
                    continue;
                }
                threshold = rand.nextDouble() * (range[1] - range[0]) + range[0];
                stats = this.createStats(2);
                aSplit = newSplit(2, reusableLists);
                for (DataPointPair<Double> dpp : subSet) {
                    int dest = dpp.getVector().get(numerA) <= threshold ? 0 : 1;
                    aSplit.get(dest).add(dpp);
                    stats[dest].add(dpp.getPair(), dpp.getDataPoint().getWeight());
                }
            }
            ++evaluated;
            double gain = 1.0;
            for (OnLineStatistics stat : stats) {
                gain -= stat.getSumOfWeights() / totalWeight * (stat.getVarance() / setVariance);
            }
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = a;
                bestThreshold = threshold;
                bestScores = stats;
                if (bestSplit != null) {
                    fillStack(reusableLists, bestSplit);
                }
                bestSplit = aSplit;
                bestLeftSide = leftSide;
            } else {
                fillStack(reusableLists, aSplit);
            }
        }
        // The parent's list is no longer needed; recycle it for deeper nodes.
        fillStack(reusableLists, Collections.singletonList(subSet));
        if (bestAttribute < 0) {
            return new NodeR(setScore.getMean());
        }

        NodeR toReturn;
        boolean featureRemoved = false;
        if (bestAttribute < catInfo.length) {
            if (bestLeftSide != null) {
                toReturn = new NodeRCat(bestAttribute, bestLeftSide, setScore.getMean(), this.numNumericFeatures);
            } else {
                toReturn = new NodeRCat(bestAttribute, bestSplit.size(), setScore.getMean(), this.numNumericFeatures);
                // Each child holds a single category, so the feature is useless below here.
                features.remove(Integer.valueOf(bestAttribute));
                featureRemoved = true;
            }
        } else {
            toReturn = new NodeRNum(bestAttribute - catInfo.length, bestThreshold, setScore.getMean());
        }
        for (int i = 0; i < toReturn.children.length; ++i) {
            toReturn.children[i] = this.train(bestScores[i], bestSplit.get(i), features, catInfo, rand, reusableLists);
        }
        if (featureRemoved) {
            // Other subtrees may still split on this feature.
            features.add(bestAttribute);
        }
        return toReturn;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /** Returns a deep copy of this learner, including any trained tree. */
    @Override
    public ExtraTree clone() {
        ExtraTree clone = new ExtraTree(this.selectionCount, this.stopSize);
        clone.impMeasure = this.impMeasure;
        clone.binaryCategoricalSplitting = this.binaryCategoricalSplitting;
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        if (this.root != null) {
            clone.root = this.root.clone();
        }
        clone.numNumericFeatures = this.numNumericFeatures;
        return clone;
    }

    @Override
    public TreeNodeVisitor getTreeNodeVisitor() {
        return this.root;
    }

    /** Returns the set of categorical values that attribute {@code catIndex} takes in {@code subSet}. */
    private static <T> IntSet distinctCategories(List<DataPointPair<T>> subSet, int catIndex, int numCategories) {
        IntSet inUse = new IntSet(numCategories * 2);
        for (DataPointPair<T> dpp : subSet) {
            inUse.add(Integer.valueOf(dpp.getDataPoint().getCategoricalValue(catIndex)));
        }
        return inUse;
    }

    /** Returns {@code {min, max}} of numeric attribute {@code numerIndex} over {@code subSet}. */
    private static <T> double[] numericRange(List<DataPointPair<T>> subSet, int numerIndex) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (DataPointPair<T> dpp : subSet) {
            double val = dpp.getVector().get(numerIndex);
            min = Math.min(min, val);
            max = Math.max(max, val);
        }
        return new double[] {min, max};
    }

    /**
     * Picks a random non-empty proper subset of {@code valuesInUse} (which must hold at
     * least two values) to send down the left branch.
     */
    private static IntSet randomLeftSide(IntSet valuesInUse, int numCategories, Random rand) {
        IntSet leftSide = new IntSet(numCategories);
        int n = rand.nextInt(valuesInUse.size() - 1) + 1;
        ListUtils.randomSample(valuesInUse, leftSide, n, rand);
        return leftSide;
    }

    /** Creates a split of {@code count} empty lists, reusing recycled lists where possible. */
    private static <T> List<List<T>> newSplit(int count, Stack<List<T>> reusableLists) {
        List<List<T>> split = new ArrayList<List<T>>(count);
        fillList(count, reusableLists, split);
        return split;
    }

    /** Appends {@code listsToAdd} empty lists to {@code aSplit}, popping recycled ones first. */
    private static <T> void fillList(int listsToAdd, Stack<List<T>> reusableLists, List<List<T>> aSplit) {
        for (int j = 0; j < listsToAdd; ++j) {
            aSplit.add(reusableLists.isEmpty() ? new ArrayList<T>() : reusableLists.pop());
        }
    }

    /** Clears every list in {@code aSplit} and pushes it onto {@code reusableLists}. */
    private static <T> void fillStack(Stack<List<T>> reusableLists, List<List<T>> aSplit) {
        for (List<T> list : aSplit) {
            list.clear();
            reusableLists.push(list);
        }
    }

    /** Creates {@code count} empty impurity accumulators for the current target. */
    private ImpurityScore[] createScores(int count) {
        ImpurityScore[] scores = new ImpurityScore[count];
        for (int j = 0; j < scores.length; ++j) {
            scores[j] = new ImpurityScore(this.predicting.getNumOfCategories(), this.impMeasure);
        }
        return scores;
    }

    /** Predicts a value for {@code data}; the tree must have been trained with {@code train} first. */
    @Override
    public double regress(DataPoint data) {
        return this.root.regress(data);
    }

    /** Trains sequentially; {@code threadPool} is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.train(dataSet);
    }

    /** Builds a regression tree from {@code dataSet}, replacing any previous model. */
    @Override
    public void train(RegressionDataSet dataSet) {
        Random rand = RandomUtil.getRandom();
        Stack<List<DataPointPair<Double>>> reusableLists = new Stack<List<DataPointPair<Double>>>();
        IntList features = new IntList(dataSet.getNumFeatures());
        ListUtils.addRange(features, 0, dataSet.getNumFeatures(), 1);
        List<DataPointPair<Double>> data = dataSet.getAsDPPList();
        OnLineStatistics score = new OnLineStatistics();
        for (DataPointPair<Double> dpp : data) {
            score.add(dpp.getPair(), dpp.getDataPoint().getWeight());
        }
        this.numNumericFeatures = dataSet.getNumNumericalVars();
        this.root = this.train(score, data, features, dataSet.getCategories(), rand, reusableLists);
    }

    /** Creates {@code count} empty weighted running-statistics accumulators. */
    private OnLineStatistics[] createStats(int count) {
        OnLineStatistics[] stats = new OnLineStatistics[count];
        for (int i = 0; i < stats.length; ++i) {
            stats[i] = new OnLineStatistics();
        }
        return stats;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /** Copies {@code values} into a sorted array so it can be searched with binary search. */
    private static int[] toSortedArray(Set<Integer> values) {
        int[] sorted = new int[values.size()];
        int pos = 0;
        for (int value : values) {
            sorted[pos++] = value;
        }
        Arrays.sort(sorted);
        return sorted;
    }

    /**
     * Routes {@code dp} at a categorical node. For a multi-way node
     * ({@code leftBranch == null}) the branch is the category value itself; otherwise
     * it is 0 if the value is in {@code leftBranch} and 1 if not.
     */
    private static int categoricalPath(DataPoint dp, int catAtt, int[] leftBranch) {
        int value = dp.getCategoricalValues()[catAtt];
        if (leftBranch == null) {
            return value;
        }
        return Arrays.binarySearch(leftBranch, value) < 0 ? 1 : 0;
    }

    /** Routes {@code dp} at a numeric node: 0 if the value is {@code <= threshold}, else 1. */
    private static int numericPath(DataPoint dp, int numerAtt, double threshold) {
        return dp.getNumericalValues().get(numerAtt) <= threshold ? 0 : 1;
    }

    /**
     * Regression node splitting on a categorical attribute. It is static and stores the
     * numeric-feature offset itself, so cloned trees do not stay bound to the
     * {@code ExtraTree} that built them.
     */
    private static class NodeRCat extends NodeR {
        private static final long serialVersionUID = 5868393594474661054L;
        private int catAtt;
        private int[] leftBranch;
        private int numNumericFeatures;

        /** Multi-way split: one child per category value. */
        public NodeRCat(int catAtt, int children, double result, int numNumericFeatures) {
            super(result, children);
            this.catAtt = catAtt;
            this.leftBranch = null;
            this.numNumericFeatures = numNumericFeatures;
        }

        /** Binary split: categories in {@code left} go to child 0, all others to child 1. */
        public NodeRCat(int catAtt, Set<Integer> left, double result, int numNumericFeatures) {
            super(result, 2);
            this.catAtt = catAtt;
            this.leftBranch = toSortedArray(left);
            this.numNumericFeatures = numNumericFeatures;
        }

        public NodeRCat(NodeRCat toClone) {
            super(toClone);
            this.catAtt = toClone.catAtt;
            this.numNumericFeatures = toClone.numNumericFeatures;
            if (toClone.leftBranch != null) {
                this.leftBranch = Arrays.copyOf(toClone.leftBranch, toClone.leftBranch.length);
            }
        }

        @Override
        public int getPath(DataPoint dp) {
            return categoricalPath(dp, this.catAtt, this.leftBranch);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.catAtt + this.numNumericFeatures);
            return used;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeRCat(this);
        }
    }

    /** Regression node splitting on a numeric attribute at a fixed threshold. */
    private static class NodeRNum extends NodeR {
        private static final long serialVersionUID = -6775472771777960211L;
        private int numerAtt;
        private double threshold;

        public NodeRNum(int numerAtt, double threshold, double result) {
            super(result, 2);
            this.numerAtt = numerAtt;
            this.threshold = threshold;
        }

        public NodeRNum(NodeRNum toClone) {
            super(toClone);
            this.numerAtt = toClone.numerAtt;
            this.threshold = toClone.threshold;
        }

        @Override
        public int getPath(DataPoint dp) {
            return numericPath(dp, this.numerAtt, this.threshold);
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeRNum(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.numerAtt);
            return used;
        }
    }

    /** Regression leaf (or base for internal regression nodes) predicting a constant. */
    private static class NodeR extends NodeBase {
        private static final long serialVersionUID = -2461046505444129890L;
        private double result;

        public NodeR(double result) {
            this.result = result;
        }

        public NodeR(double result, int children) {
            super(children);
            this.result = result;
        }

        public NodeR(NodeR toClone) {
            super(toClone);
            this.result = toClone.result;
        }

        @Override
        public double localRegress(DataPoint dp) {
            return this.result;
        }

        @Override
        public int getPath(DataPoint dp) {
            return -1;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeR(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            return Collections.<Integer>emptySet();
        }
    }

    /**
     * Common child bookkeeping. {@code children} is {@code null} for leaves; a
     * {@code null} entry marks a disabled path.
     */
    private static abstract class NodeBase extends TreeNodeVisitor {
        private static final long serialVersionUID = 6783491817922690901L;
        protected TreeNodeVisitor[] children;

        public NodeBase() {
        }

        public NodeBase(int children) {
            this.children = new TreeNodeVisitor[children];
        }

        public NodeBase(NodeBase toClone) {
            if (toClone.children != null) {
                this.children = new TreeNodeVisitor[toClone.children.length];
                for (int i = 0; i < toClone.children.length; ++i) {
                    if (toClone.children[i] != null) {
                        this.children[i] = toClone.children[i].clone();
                    }
                }
            }
        }

        /** @return the number of child slots, or 0 for a leaf */
        @Override
        public int childrenCount() {
            return this.children == null ? 0 : this.children.length;
        }

        @Override
        public boolean isLeaf() {
            if (this.children == null) {
                return true;
            }
            for (TreeNodeVisitor child : this.children) {
                if (child != null) {
                    return false;
                }
            }
            return true;
        }

        /** @return the requested child, or {@code null} if absent or out of range */
        @Override
        public TreeNodeVisitor getChild(int child) {
            if (this.children == null || child < 0 || child >= this.children.length) {
                return null;
            }
            return this.children[child];
        }

        @Override
        public void disablePath(int child) {
            if (!this.isLeaf()) {
                this.children[child] = null;
            }
        }

        @Override
        public boolean isPathDisabled(int child) {
            if (this.isLeaf()) {
                return true;
            }
            return this.children[child] == null;
        }
    }

    /** Classification leaf (or base for internal classification nodes). */
    private static class NodeC extends NodeBase {
        private static final long serialVersionUID = -3977497656918695759L;
        private CategoricalResults crResult;

        public NodeC(CategoricalResults crResult) {
            this.crResult = crResult;
            this.children = null;
        }

        public NodeC(CategoricalResults crResult, int children) {
            super(children);
            this.crResult = crResult;
        }

        public NodeC(NodeC toClone) {
            super(toClone);
            this.crResult = toClone.crResult.clone();
        }

        @Override
        public CategoricalResults localClassify(DataPoint dp) {
            return this.crResult;
        }

        @Override
        public int getPath(DataPoint dp) {
            return -1;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeC(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            return Collections.<Integer>emptySet();
        }
    }

    /** Classification node splitting on a numeric attribute at a fixed threshold. */
    private static class NodeCNum extends NodeC {
        private static final long serialVersionUID = 3967180517059509869L;
        private int numerAtt;
        private double threshold;

        public NodeCNum(int numerAtt, double threshold, CategoricalResults crResult) {
            super(crResult, 2);
            this.numerAtt = numerAtt;
            this.threshold = threshold;
        }

        public NodeCNum(NodeCNum toClone) {
            super(toClone);
            this.numerAtt = toClone.numerAtt;
            this.threshold = toClone.threshold;
        }

        @Override
        public int getPath(DataPoint dp) {
            return numericPath(dp, this.numerAtt, this.threshold);
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeCNum(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.numerAtt);
            return used;
        }
    }

    /**
     * Classification node splitting on a categorical attribute. It is static and stores
     * the numeric-feature offset itself, so cloned trees do not stay bound to the
     * {@code ExtraTree} that built them.
     */
    private static class NodeCCat extends NodeC {
        private static final long serialVersionUID = 7413428280703235600L;
        private int catAtt;
        private int[] leftBranch;
        private int numNumericFeatures;

        /** Multi-way split: one child per category value. */
        public NodeCCat(int catAtt, int children, CategoricalResults crResult, int numNumericFeatures) {
            super(crResult, children);
            this.catAtt = catAtt;
            this.leftBranch = null;
            this.numNumericFeatures = numNumericFeatures;
        }

        /** Binary split: categories in {@code left} go to child 0, all others to child 1. */
        public NodeCCat(int catAtt, Set<Integer> left, CategoricalResults crResult, int numNumericFeatures) {
            super(crResult, 2);
            this.catAtt = catAtt;
            this.leftBranch = toSortedArray(left);
            this.numNumericFeatures = numNumericFeatures;
        }

        public NodeCCat(NodeCCat toClone) {
            super(toClone);
            this.catAtt = toClone.catAtt;
            this.numNumericFeatures = toClone.numNumericFeatures;
            if (toClone.leftBranch != null) {
                this.leftBranch = Arrays.copyOf(toClone.leftBranch, toClone.leftBranch.length);
            }
        }

        @Override
        public int getPath(DataPoint dp) {
            return categoricalPath(dp, this.catAtt, this.leftBranch);
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeCCat(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.catAtt + this.numNumericFeatures);
            return used;
        }
    }
}

