/*
 * Decompiled with CFR 0.152.
 */
package jsat.parameters;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.parameters.DoubleParameter;
import jsat.parameters.IntParameter;
import jsat.parameters.ModelSearch;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;

/**
 * Model search that evaluates every combination of the candidate values
 * registered for each searched parameter.
 *
 * <p>Each combination is applied to the base model, the base model is cloned,
 * and the clone is scored with cross validation. The clone whose mean target
 * score is best is kept, and optionally re-trained on the full data set.
 *
 * <p>{@link #searchValues} and the inherited {@code searchParams} are parallel
 * lists: entry {@code i} of one belongs to entry {@code i} of the other.
 * Parameters that report themselves as warm are kept at the front of both
 * lists, because index 0 is the fastest-changing dimension of the grid and the
 * warm-start path chains consecutive models along that dimension.
 */
public class GridSearch
extends ModelSearch {
    private static final long serialVersionUID = -1987196172499143753L;
    /** Candidate values per searched parameter; index-aligned with {@code searchParams}. */
    private List<List<Double>> searchValues;
    /** Whether warm starts should be used when the base model supports them. */
    private boolean useWarmStarts = true;

    /**
     * Creates a grid search around a regressor.
     *
     * @param baseRegressor the regressor whose parameters will be searched
     * @param folds the number of cross validation folds
     */
    public GridSearch(Regressor baseRegressor, int folds) {
        super(baseRegressor, folds);
        this.searchValues = new ArrayList<List<Double>>();
    }

    /**
     * Creates a grid search around a classifier.
     *
     * @param baseClassifier the classifier whose parameters will be searched
     * @param folds the number of cross validation folds
     */
    public GridSearch(Classifier baseClassifier, int folds) {
        super(baseClassifier, folds);
        this.searchValues = new ArrayList<List<Double>>();
    }

    /**
     * Copy constructor. The candidate value lists are copied so that the new
     * object does not share them with {@code toCopy}.
     *
     * @param toCopy the search to copy
     */
    public GridSearch(GridSearch toCopy) {
        super(toCopy);
        this.useWarmStarts = toCopy.useWarmStarts;
        if (toCopy.searchValues != null) {
            this.searchValues = new ArrayList<List<Double>>(toCopy.searchValues.size());
            for (List<Double> values : toCopy.searchValues) {
                this.searchValues.add(new DoubleList(values));
            }
        }
    }

    /**
     * Same as {@link #autoAddParameters(DataSet, int)} with 10 values per
     * parameter.
     *
     * @param data the data used to obtain the parameter guesses
     * @return the number of parameters that were added
     */
    public int autoAddParameters(DataSet data) {
        return this.autoAddParameters(data, 10);
    }

    /**
     * Adds every double or int parameter of the base model that can provide a
     * guess distribution for {@code data}. The candidate values of a parameter
     * are the quantiles {@code (i + 1) / (paramsEach + 1)}, for
     * {@code i = 0 .. paramsEach - 1}, of its guess distribution; int
     * parameters use the rounded quantile.
     *
     * @param data the data used to obtain the parameter guesses
     * @param paramsEach the number of candidate values per parameter
     * @return the number of parameters that were added, {@code 0} if no
     *         parameter could provide a guess
     * @throws IllegalArgumentException if at least one parameter can be added
     *         but {@code paramsEach} is not positive
     */
    public int autoAddParameters(DataSet data, int paramsEach) {
        Parameterized obj = this.baseClassifier != null ? (Parameterized) this.baseClassifier : (Parameterized) this.baseRegressor;

        // Single pass: ask each parameter for its guess exactly once and
        // remember the ones that answered.
        List<Parameter> guessable = new ArrayList<Parameter>();
        List<Distribution> guesses = new ArrayList<Distribution>();
        for (Parameter param : obj.getParameters()) {
            Distribution dist = null;
            if (param instanceof DoubleParameter) {
                dist = ((DoubleParameter) param).getGuess(data);
            } else if (param instanceof IntParameter) {
                dist = ((IntParameter) param).getGuess(data);
            }
            if (dist != null) {
                guessable.add(param);
                guesses.add(dist);
            }
        }
        if (guessable.isEmpty()) {
            return 0;
        }
        if (paramsEach < 1) {
            // An empty candidate list would only fail later, inside training.
            throw new IllegalArgumentException("paramsEach must be positive, not " + paramsEach);
        }

        double[] quantiles = new double[paramsEach];
        for (int i = 0; i < quantiles.length; ++i) {
            quantiles[i] = ((double) i + 1.0) / ((double) paramsEach + 1.0);
        }

        for (int p = 0; p < guessable.size(); ++p) {
            Parameter param = guessable.get(p);
            Distribution dist = guesses.get(p);
            if (param instanceof DoubleParameter) {
                double[] vals = new double[paramsEach];
                for (int i = 0; i < vals.length; ++i) {
                    vals[i] = dist.invCdf(quantiles[i]);
                }
                this.addParameter((DoubleParameter) param, vals);
            } else {
                int[] vals = new int[paramsEach];
                for (int i = 0; i < vals.length; ++i) {
                    vals[i] = (int) Math.round(dist.invCdf(quantiles[i]));
                }
                this.addParameter((IntParameter) param, vals);
            }
        }
        return guessable.size();
    }

    /**
     * Sets whether warm starts should be used when the base model supports
     * them.
     *
     * @param useWarmStarts {@code true} to allow warm starts
     */
    public void setUseWarmStarts(boolean useWarmStarts) {
        this.useWarmStarts = useWarmStarts;
    }

    /**
     * @return {@code true} if warm starts are allowed
     */
    public boolean isUseWarmStarts() {
        return this.useWarmStarts;
    }

    /**
     * Registers a double valued parameter and its candidate values. The values
     * are stored sorted; for a warm parameter that does not prefer low-to-high
     * order they are stored in descending order.
     *
     * @param param the parameter to search over
     * @param initialSearchValues the candidate values; the array is not modified
     * @throws IllegalArgumentException if {@code param} is {@code null}
     */
    public void addParameter(DoubleParameter param, double ... initialSearchValues) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        this.registerSearchValues(param, param.isWarmParameter(), param.preferredLowToHigh(), initialSearchValues.clone());
    }

    /**
     * Registers the double valued parameter with the given name.
     *
     * @param name the name of the parameter
     * @param initialSearchValues the candidate values
     * @throws IllegalArgumentException if the named parameter is not a
     *         {@link DoubleParameter}
     */
    public void addParameter(String name, double ... initialSearchValues) {
        Parameter param = this.getParameterByName(name);
        if (!(param instanceof DoubleParameter)) {
            throw new IllegalArgumentException("Parameter " + name + " is not for double values");
        }
        this.addParameter((DoubleParameter)param, initialSearchValues);
    }

    /**
     * Registers an int valued parameter and its candidate values. The values
     * are stored sorted; for a warm parameter that does not prefer low-to-high
     * order they are stored in descending order.
     *
     * @param param the parameter to search over
     * @param initialSearchValues the candidate values; the array is not modified
     * @throws IllegalArgumentException if {@code param} is {@code null}
     */
    public void addParameter(IntParameter param, int ... initialSearchValues) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        double[] asDoubles = new double[initialSearchValues.length];
        for (int i = 0; i < asDoubles.length; ++i) {
            asDoubles[i] = initialSearchValues[i];
        }
        this.registerSearchValues(param, param.isWarmParameter(), param.preferredLowToHigh(), asDoubles);
    }

    /**
     * Registers the int valued parameter with the given name.
     *
     * @param name the name of the parameter
     * @param initialSearchValues the candidate values
     * @throws IllegalArgumentException if the named parameter is not an
     *         {@link IntParameter}
     */
    public void addParameter(String name, int ... initialSearchValues) {
        Parameter param = this.getParameterByName(name);
        if (!(param instanceof IntParameter)) {
            throw new IllegalArgumentException("Parameter " + name + " is not for int values");
        }
        this.addParameter((IntParameter)param, initialSearchValues);
    }

    /**
     * Stores a parameter and its candidate values at the same index of
     * {@code searchParams} and {@link #searchValues}.
     *
     * @param param the parameter being registered
     * @param warm whether the parameter reports itself as a warm parameter
     * @param lowToHigh whether the parameter prefers ascending order
     * @param values a private array of candidate values; it is sorted in place
     */
    private void registerSearchValues(Parameter param, boolean warm, boolean lowToHigh, double[] values) {
        // Build the complete value list before touching either parallel list,
        // so a failure here cannot leave them with different lengths.
        Arrays.sort(values);
        DoubleList sortedValues = new DoubleList(values.length);
        for (double v : values) {
            sortedValues.add(v);
        }
        if (warm && !lowToHigh) {
            Collections.reverse(sortedValues);
        }
        if (warm) {
            // Warm parameters go to the front of BOTH lists; setParameters
            // pairs the lists by index, so they must move together.
            this.searchParams.add(0, param);
            this.searchValues.add(0, sortedValues);
        } else {
            this.searchParams.add(param);
            this.searchValues.add(sortedValues);
        }
    }

    /**
     * Evaluates every grid point with cross validation and keeps the regressor
     * with the best mean target score. If {@code trainFinalModel} is set, that
     * regressor is then trained on all of {@code dataSet}.
     *
     * <p>If the calling thread is interrupted while waiting for the
     * evaluations, the interruption is logged, the interrupt flag is restored,
     * and the method returns without updating the trained regressor.
     *
     * @param dataSet the data to search and train on
     * @param threadPool the source of threads
     * @throws RuntimeException the first failure raised by an evaluation task
     */
    @Override
    public void train(final RegressionDataSet dataSet, final ExecutorService threadPool) {
        // The head of the queue is the best model: ascending mean when lower
        // scores are better, descending otherwise.
        final PriorityQueue<RegressionModelEvaluation> bestModels = new PriorityQueue<RegressionModelEvaluation>(this.folds, new Comparator<RegressionModelEvaluation>(){

            @Override
            public int compare(RegressionModelEvaluation t, RegressionModelEvaluation t1) {
                double v0 = t.getScoreStats(GridSearch.this.regressionTargetScore).getMean();
                double v1 = t1.getScoreStats(GridSearch.this.regressionTargetScore).getMean();
                int order = GridSearch.this.regressionTargetScore.lowerIsBetter() ? 1 : -1;
                return order * Double.compare(v0, v1);
            }
        });

        // One clone of the base regressor per grid point, index 0 varying fastest.
        int[] setTo = new int[this.searchParams.size()];
        final List<Regressor> paramsToEval = new ArrayList<Regressor>();
        do {
            this.setParameters(setTo);
            paramsToEval.add(this.baseRegressor.clone());
        } while (!this.incrementCombination(setTo));

        ExecutorService modelService = this.trainModelsInParallel ? threadPool : new FakeExecutor();

        final List<RegressionDataSet> preFolded;
        final List<RegressionDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<RegressionDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }

        // Failures raised inside tasks; rethrown on this thread after the latch opens.
        final List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        final CountDownLatch latch;
        boolean considerWarm = this.useWarmStarts && this.baseRegressor instanceof WarmRegressor;
        if (considerWarm && (!((WarmRegressor)this.baseRegressor).warmFromSameDataOnly() || this.reuseSameCVFolds)) {
            // One job per run of consecutive grid points that differ only in
            // the first parameter; each model in a run is warm started from
            // the models kept by the previous one.
            int stepSize = this.searchValues.get(0).size();
            int totalJobs = paramsToEval.size() / stepSize;
            latch = new CountDownLatch(totalJobs);
            for (int startPos = 0; startPos < paramsToEval.size(); startPos += stepSize) {
                final List<Regressor> subSet = paramsToEval.subList(startPos, startPos + stepSize);
                modelService.submit(new Runnable(){

                    @Override
                    public void run() {
                        try {
                            Regressor[] prevModels = null;
                            for (Regressor r : subSet) {
                                RegressionModelEvaluation rme = GridSearch.this.evaluateRegressor(r, dataSet, threadPool, preFolded, trainCombinations, true, prevModels);
                                prevModels = rme.getKeptModels();
                                synchronized (bestModels) {
                                    bestModels.add(rme);
                                }
                            }
                        } catch (RuntimeException ex) {
                            failures.add(ex);
                            throw ex;
                        } catch (Error err) {
                            failures.add(err);
                            throw err;
                        } finally {
                            // Always release the latch, otherwise a failed
                            // task would block the waiting thread forever.
                            latch.countDown();
                        }
                    }
                });
            }
        } else {
            latch = new CountDownLatch(paramsToEval.size());
            for (final Regressor toTrain : paramsToEval) {
                modelService.submit(new Runnable(){

                    @Override
                    public void run() {
                        try {
                            RegressionModelEvaluation rme = GridSearch.this.evaluateRegressor(toTrain, dataSet, threadPool, preFolded, trainCombinations, false, null);
                            synchronized (bestModels) {
                                bestModels.add(rme);
                            }
                        } catch (RuntimeException ex) {
                            failures.add(ex);
                            throw ex;
                        } catch (Error err) {
                            failures.add(err);
                            throw err;
                        } finally {
                            latch.countDown();
                        }
                    }
                });
            }
        }
        try {
            latch.await();
            GridSearch.rethrowFirst(failures);
            Regressor bestRegressor = bestModels.peek().getRegressor();
            if (this.trainFinalModel) {
                if (this.useWarmStarts && bestRegressor instanceof WarmRegressor && !((WarmRegressor)bestRegressor).warmFromSameDataOnly()) {
                    WarmRegressor wr = (WarmRegressor)bestRegressor;
                    if (threadPool instanceof FakeExecutor) {
                        wr.train(dataSet, (Regressor)wr.clone());
                    } else {
                        wr.train(dataSet, (Regressor)wr.clone(), threadPool);
                    }
                } else if (threadPool instanceof FakeExecutor) {
                    bestRegressor.train(dataSet);
                } else {
                    bestRegressor.train(dataSet, threadPool);
                }
            }
            this.trainedRegressor = bestRegressor;
        }
        catch (InterruptedException ex) {
            // Restore the flag so callers further up can see the interruption.
            Thread.currentThread().interrupt();
            Logger.getLogger(GridSearch.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Runs {@link #train(RegressionDataSet, ExecutorService)} with a
     * {@link FakeExecutor}.
     *
     * @param dataSet the data to search and train on
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, new FakeExecutor());
    }

    /**
     * Evaluates every grid point with cross validation and keeps the
     * classifier with the best mean target score. If {@code trainFinalModel}
     * is set, that classifier is then trained on all of {@code dataSet}.
     *
     * <p>If the calling thread is interrupted while waiting for the
     * evaluations, the interruption is logged, the interrupt flag is restored,
     * and the method returns without updating the trained classifier.
     *
     * @param dataSet the data to search and train on
     * @param threadPool the source of threads
     * @throws RuntimeException the first failure raised by an evaluation task
     */
    @Override
    public void trainC(final ClassificationDataSet dataSet, final ExecutorService threadPool) {
        // The head of the queue is the best model: ascending mean when lower
        // scores are better, descending otherwise.
        final PriorityQueue<ClassificationModelEvaluation> bestModels = new PriorityQueue<ClassificationModelEvaluation>(this.folds, new Comparator<ClassificationModelEvaluation>(){

            @Override
            public int compare(ClassificationModelEvaluation t, ClassificationModelEvaluation t1) {
                double v0 = t.getScoreStats(GridSearch.this.classificationTargetScore).getMean();
                double v1 = t1.getScoreStats(GridSearch.this.classificationTargetScore).getMean();
                int order = GridSearch.this.classificationTargetScore.lowerIsBetter() ? 1 : -1;
                return order * Double.compare(v0, v1);
            }
        });

        // One clone of the base classifier per grid point, index 0 varying fastest.
        int[] setTo = new int[this.searchParams.size()];
        final List<Classifier> paramsToEval = new ArrayList<Classifier>();
        do {
            this.setParameters(setTo);
            paramsToEval.add(this.baseClassifier.clone());
        } while (!this.incrementCombination(setTo));

        ExecutorService modelService = this.trainModelsInParallel ? threadPool : new FakeExecutor();

        final List<ClassificationDataSet> preFolded;
        final List<ClassificationDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<ClassificationDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }

        // Failures raised inside tasks; rethrown on this thread after the latch opens.
        final List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        final CountDownLatch latch;
        boolean considerWarm = this.useWarmStarts && this.baseClassifier instanceof WarmClassifier;
        if (considerWarm && (!((WarmClassifier)this.baseClassifier).warmFromSameDataOnly() || this.reuseSameCVFolds)) {
            // One job per run of consecutive grid points that differ only in
            // the first parameter; each model in a run is warm started from
            // the models kept by the previous one.
            int stepSize = this.searchValues.get(0).size();
            int totalJobs = paramsToEval.size() / stepSize;
            latch = new CountDownLatch(totalJobs);
            for (int startPos = 0; startPos < paramsToEval.size(); startPos += stepSize) {
                final List<Classifier> subSet = paramsToEval.subList(startPos, startPos + stepSize);
                modelService.submit(new Runnable(){

                    @Override
                    public void run() {
                        try {
                            Classifier[] prevModels = null;
                            for (Classifier c : subSet) {
                                ClassificationModelEvaluation cme = GridSearch.this.evaluateClassifier(c, dataSet, threadPool, preFolded, trainCombinations, true, prevModels);
                                prevModels = cme.getKeptModels();
                                synchronized (bestModels) {
                                    bestModels.add(cme);
                                }
                            }
                        } catch (RuntimeException ex) {
                            failures.add(ex);
                            throw ex;
                        } catch (Error err) {
                            failures.add(err);
                            throw err;
                        } finally {
                            // Always release the latch, otherwise a failed
                            // task would block the waiting thread forever.
                            latch.countDown();
                        }
                    }
                });
            }
        } else {
            latch = new CountDownLatch(paramsToEval.size());
            for (final Classifier toTrain : paramsToEval) {
                modelService.submit(new Runnable(){

                    @Override
                    public void run() {
                        try {
                            ClassificationModelEvaluation cme = GridSearch.this.evaluateClassifier(toTrain, dataSet, threadPool, preFolded, trainCombinations, false, null);
                            synchronized (bestModels) {
                                bestModels.add(cme);
                            }
                        } catch (RuntimeException ex) {
                            failures.add(ex);
                            throw ex;
                        } catch (Error err) {
                            failures.add(err);
                            throw err;
                        } finally {
                            latch.countDown();
                        }
                    }
                });
            }
        }
        try {
            latch.await();
            GridSearch.rethrowFirst(failures);
            Classifier bestClassifier = bestModels.peek().getClassifier();
            if (this.trainFinalModel) {
                if (this.useWarmStarts && bestClassifier instanceof WarmClassifier && !((WarmClassifier)bestClassifier).warmFromSameDataOnly()) {
                    WarmClassifier wc = (WarmClassifier)bestClassifier;
                    if (threadPool instanceof FakeExecutor) {
                        wc.trainC(dataSet, (Classifier)wc.clone());
                    } else {
                        wc.trainC(dataSet, (Classifier)wc.clone(), threadPool);
                    }
                } else if (threadPool instanceof FakeExecutor) {
                    bestClassifier.trainC(dataSet);
                } else {
                    bestClassifier.trainC(dataSet, threadPool);
                }
            }
            this.trainedClassifier = bestClassifier;
        }
        catch (InterruptedException ex) {
            // Restore the flag so callers further up can see the interruption.
            Thread.currentThread().interrupt();
            Logger.getLogger(GridSearch.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Runs {@link #trainC(ClassificationDataSet, ExecutorService)} with a
     * {@link FakeExecutor}.
     *
     * @param dataSet the data to search and train on
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * @return a copy of this search made with the copy constructor
     */
    @Override
    public GridSearch clone() {
        return new GridSearch(this);
    }

    /**
     * Cross validates one regressor and returns its evaluation.
     *
     * @param model the regressor to evaluate
     * @param dataSet the data to evaluate on
     * @param threadPool given to the evaluation only when models are not
     *        already being trained in parallel
     * @param preFolded the shared folds, used only if {@code reuseSameCVFolds}
     * @param trainCombinations the shared training sets, used only if
     *        {@code reuseSameCVFolds}
     * @param warm whether to keep the trained models and warm start from
     *        {@code warmFrom}
     * @param warmFrom the models to warm start from, may be {@code null}
     * @return the completed evaluation
     */
    private RegressionModelEvaluation evaluateRegressor(Regressor model, RegressionDataSet dataSet, ExecutorService threadPool, List<RegressionDataSet> preFolded, List<RegressionDataSet> trainCombinations, boolean warm, Regressor[] warmFrom) {
        RegressionModelEvaluation rme = this.trainModelsInParallel ? new RegressionModelEvaluation(model, dataSet) : new RegressionModelEvaluation(model, dataSet, threadPool);
        if (warm) {
            rme.setKeepModels(true);
            rme.setWarmModels(warmFrom);
        }
        rme.addScorer(this.regressionTargetScore.clone());
        if (this.reuseSameCVFolds) {
            rme.evaluateCrossValidation(preFolded, trainCombinations);
        } else {
            rme.evaluateCrossValidation(this.folds);
        }
        return rme;
    }

    /**
     * Cross validates one classifier and returns its evaluation.
     *
     * @param model the classifier to evaluate
     * @param dataSet the data to evaluate on
     * @param threadPool given to the evaluation only when models are not
     *        already being trained in parallel
     * @param preFolded the shared folds, used only if {@code reuseSameCVFolds}
     * @param trainCombinations the shared training sets, used only if
     *        {@code reuseSameCVFolds}
     * @param warm whether to keep the trained models and warm start from
     *        {@code warmFrom}
     * @param warmFrom the models to warm start from, may be {@code null}
     * @return the completed evaluation
     */
    private ClassificationModelEvaluation evaluateClassifier(Classifier model, ClassificationDataSet dataSet, ExecutorService threadPool, List<ClassificationDataSet> preFolded, List<ClassificationDataSet> trainCombinations, boolean warm, Classifier[] warmFrom) {
        ClassificationModelEvaluation cme = this.trainModelsInParallel ? new ClassificationModelEvaluation(model, dataSet) : new ClassificationModelEvaluation(model, dataSet, threadPool);
        if (warm) {
            cme.setKeepModels(true);
            cme.setWarmModels(warmFrom);
        }
        cme.addScorer(this.classificationTargetScore.clone());
        if (this.reuseSameCVFolds) {
            cme.evaluateCrossValidation(preFolded, trainCombinations);
        } else {
            cme.evaluateCrossValidation(this.folds);
        }
        return cme;
    }

    /**
     * Rethrows the first recorded task failure, if there is one.
     *
     * @param failures the failures recorded by the evaluation tasks; only
     *        {@link RuntimeException}s and {@link Error}s are ever recorded
     */
    private static void rethrowFirst(List<Throwable> failures) {
        if (failures.isEmpty()) {
            return;
        }
        Throwable first = failures.get(0);
        if (first instanceof Error) {
            throw (Error) first;
        }
        throw (RuntimeException) first;
    }

    /**
     * Advances {@code setTo} to the next grid point, treating it as a
     * mixed-radix counter whose digit {@code i} ranges over the candidate
     * values of parameter {@code i}. Index 0 is the fastest-changing digit.
     *
     * @param setTo the current value index of every parameter, updated in place
     * @return {@code true} once the last digit has run past its final value,
     *         meaning every combination has been visited
     */
    private boolean incrementCombination(int[] setTo) {
        final int last = setTo.length - 1;
        setTo[0]++;
        // Propagate carries; the last digit is allowed to overflow because
        // that overflow is the termination signal.
        for (int pos = 0; pos < last && setTo[pos] >= this.searchValues.get(pos).size(); ++pos) {
            setTo[pos] = 0;
            setTo[pos + 1]++;
        }
        return setTo[last] >= this.searchValues.get(last).size();
    }

    /**
     * Applies one grid point to the searched parameters.
     *
     * @param setTo for each parameter, the index of the candidate value to set
     */
    private void setParameters(int[] setTo) {
        for (int i = 0; i < setTo.length; ++i) {
            Parameter param = (Parameter)this.searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter)param).setValue(this.searchValues.get(i).get(setTo[i]));
            } else if (param instanceof IntParameter) {
                ((IntParameter)param).setValue(this.searchValues.get(i).get(setTo[i]).intValue());
            }
        }
    }
}

