/*
 * Decompiled with CFR 0.152.
 */
package jsat.parameters;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.distributions.Distribution;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.DoubleParameter;
import jsat.parameters.IntParameter;
import jsat.parameters.ModelSearch;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.random.XORWOW;

/**
 * Model search that evaluates a fixed number of randomly sampled parameter
 * configurations with cross validation and keeps the best-scoring model.
 * <p>
 * Each searched parameter is paired with a {@link Distribution}. For every
 * trial, one value is drawn from each distribution by passing a uniform random
 * number through {@link Distribution#invCdf(double)}. The value is written into
 * the parameter (int parameters receive the rounded value), and the base model
 * is then cloned.
 */
public class RandomSearch
extends ModelSearch {
    /** Number of random configurations evaluated per training call; always at least 1. */
    private int trials = 25;
    /** Sampling distributions; entry i belongs to entry i of {@code searchParams}. */
    private List<Distribution> searchValues;

    /**
     * Creates a random search over a regressor.
     *
     * @param baseRegressor the regressor whose parameters will be searched
     * @param folds the number of cross-validation folds used to score each trial
     */
    public RandomSearch(Regressor baseRegressor, int folds) {
        super(baseRegressor, folds);
        this.searchValues = new ArrayList<Distribution>();
    }

    /**
     * Creates a random search over a classifier.
     *
     * @param baseClassifier the classifier whose parameters will be searched
     * @param folds the number of cross-validation folds used to score each trial
     */
    public RandomSearch(Classifier baseClassifier, int folds) {
        super(baseClassifier, folds);
        this.searchValues = new ArrayList<Distribution>();
    }

    /**
     * Copy constructor. The sampling distributions are deep-copied with
     * {@code clone()}.
     *
     * @param toCopy the search to copy
     */
    public RandomSearch(RandomSearch toCopy) {
        super(toCopy);
        this.trials = toCopy.trials;
        this.searchValues = new ArrayList<Distribution>(toCopy.searchValues.size());
        for (Distribution d : toCopy.searchValues) {
            this.searchValues.add(d.clone());
        }
    }

    /**
     * Adds to the search every double or int parameter of the base model that
     * can supply a guessed distribution for the given data.
     *
     * @param data the data set used to produce the guesses
     * @return the number of parameters that were added to the search
     */
    public int autoAddParameters(DataSet data) {
        Parameterized obj = this.baseClassifier != null
                ? (Parameterized) ((Object) this.baseClassifier)
                : (Parameterized) ((Object) this.baseRegressor);
        int totalParms = 0;
        for (Parameter param : obj.getParameters()) {
            if (param instanceof DoubleParameter) {
                DoubleParameter doubleParam = (DoubleParameter) param;
                Distribution dist = doubleParam.getGuess(data);
                if (dist != null) {
                    this.addParameter(doubleParam, dist);
                    ++totalParms;
                }
            } else if (param instanceof IntParameter) {
                IntParameter intParam = (IntParameter) param;
                Distribution dist = intParam.getGuess(data);
                if (dist != null) {
                    this.addParameter(intParam, dist);
                    ++totalParms;
                }
            }
        }
        return totalParms;
    }

    /**
     * Sets the number of random configurations to evaluate.
     *
     * @param trials the number of trials, must be at least 1
     * @throws IllegalArgumentException if {@code trials < 1}
     */
    public void setTrials(int trials) {
        if (trials < 1) {
            throw new IllegalArgumentException("number of trials must be positive, not " + trials);
        }
        this.trials = trials;
    }

    /**
     * @return the number of random configurations evaluated per training call
     */
    public int getTrials() {
        return this.trials;
    }

    /**
     * Adds a double parameter to the search, sampled from a copy of {@code dist}.
     *
     * @param param the parameter to search over
     * @param dist the distribution to sample values from
     * @throws IllegalArgumentException if {@code param} or {@code dist} is null
     */
    public void addParameter(DoubleParameter param, Distribution dist) {
        this.addSearchPair(param, dist);
    }

    /**
     * Adds an int parameter to the search. Values are sampled from a copy of
     * {@code dist} and rounded to the nearest int.
     *
     * @param param the parameter to search over
     * @param dist the distribution to sample values from
     * @throws IllegalArgumentException if {@code param} or {@code dist} is null
     */
    public void addParameter(IntParameter param, Distribution dist) {
        this.addSearchPair(param, dist);
    }

    /**
     * Adds the named double or int parameter to the search.
     *
     * @param name the name of the parameter, as resolved by {@code getParameterByName}
     * @param dist the distribution to sample values from
     * @throws IllegalArgumentException if no such parameter exists, if it is not
     *         a double or int parameter, or if {@code dist} is null
     */
    public void addParameter(String name, Distribution dist) {
        Parameter param = this.getParameterByName(name);
        if (param == null) {
            throw new IllegalArgumentException("Parameter " + name + " does not exist");
        }
        if (param instanceof DoubleParameter) {
            this.addParameter((DoubleParameter) param, dist);
        } else if (param instanceof IntParameter) {
            this.addParameter((IntParameter) param, dist);
        } else {
            throw new IllegalArgumentException("Parameter " + name + " is not for double or int values");
        }
    }

    /**
     * Validates and appends a (parameter, distribution) pair. Both arguments are
     * checked and the distribution is cloned before either list is modified.
     * That keeps {@code searchParams} and {@code searchValues} index-aligned
     * even when a call fails.
     */
    private void addSearchPair(Parameter param, Distribution dist) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        if (dist == null) {
            throw new IllegalArgumentException("null not allowed for distribution");
        }
        Distribution copy = dist.clone();
        this.searchParams.add(param);
        this.searchValues.add(copy);
    }

    /**
     * Draws one value from each search distribution and writes it into the
     * matching entry of {@code searchParams}. Int parameters receive the
     * rounded value. The parameters presumably belong to the base model, so
     * the next clone of it carries the sampled values.
     *
     * @param rand source of the uniform numbers fed to {@code invCdf}
     */
    private void sampleSearchParameters(Random rand) {
        for (int i = 0; i < this.searchParams.size(); ++i) {
            double sampledValue = this.searchValues.get(i).invCdf(rand.nextDouble());
            Parameter param = (Parameter) this.searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter) param).setValue(sampledValue);
            } else if (param instanceof IntParameter) {
                ((IntParameter) param).setValue((int) Math.round(sampledValue));
            }
        }
    }

    /**
     * Evaluates {@link #getTrials()} randomly sampled configurations of the
     * base classifier with cross validation. The best one is stored as the
     * trained classifier and, if {@code trainFinalModel} is set, re-fitted on
     * the whole data set.
     *
     * @param dataSet the data to search over
     * @param threadPool the executor to use, or {@code null} for single-threaded work
     * @throws FailedToFitException if interrupted while waiting for the evaluations
     */
    @Override
    public void trainC(final ClassificationDataSet dataSet, final ExecutorService threadPool) {
        // The queue head is the best model: ascending order when lower scores
        // are better, descending otherwise.
        final PriorityQueue<ClassificationModelEvaluation> bestModels = new PriorityQueue<ClassificationModelEvaluation>(this.trials, new Comparator<ClassificationModelEvaluation>() {

            @Override
            public int compare(ClassificationModelEvaluation t, ClassificationModelEvaluation t1) {
                double v0 = t.getScoreStats(RandomSearch.this.classificationTargetScore).getMean();
                double v1 = t1.getScoreStats(RandomSearch.this.classificationTargetScore).getMean();
                int order = RandomSearch.this.classificationTargetScore.lowerIsBetter() ? 1 : -1;
                return order * Double.compare(v0, v1);
            }
        });

        List<Classifier> paramsToEval = new ArrayList<Classifier>(this.trials);
        Random rand = new XORWOW();
        for (int trial = 0; trial < this.trials; ++trial) {
            this.sampleSearchParameters(rand);
            paramsToEval.add(this.baseClassifier.clone());
        }

        ExecutorService modelService = this.trainModelsInParallel && threadPool != null ? threadPool : new FakeExecutor();

        final List<ClassificationDataSet> preFolded;
        final ArrayList<ClassificationDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<ClassificationDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }

        final CountDownLatch latch = new CountDownLatch(paramsToEval.size());
        final AtomicReference<RuntimeException> failure = new AtomicReference<RuntimeException>();
        for (final Classifier c : paramsToEval) {
            modelService.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        // When models are evaluated in parallel, each evaluation runs
                        // single-threaded; otherwise the evaluation may use threadPool.
                        ClassificationModelEvaluation cme = RandomSearch.this.trainModelsInParallel ? new ClassificationModelEvaluation(c, dataSet) : new ClassificationModelEvaluation(c, dataSet, threadPool);
                        cme.addScorer(RandomSearch.this.classificationTargetScore.clone());
                        if (RandomSearch.this.reuseSameCVFolds) {
                            cme.evaluateCrossValidation(preFolded, trainCombinations);
                        } else {
                            cme.evaluateCrossValidation(RandomSearch.this.folds);
                        }
                        synchronized (bestModels) {
                            bestModels.add(cme);
                        }
                    } catch (RuntimeException ex) {
                        // An executor would swallow this inside a discarded Future;
                        // keep the first failure so the caller can rethrow it.
                        failure.compareAndSet(null, ex);
                    } finally {
                        // Always release the latch, otherwise a failed trial hangs await().
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        RuntimeException firstFailure = failure.get();
        if (firstFailure != null) {
            throw firstFailure;
        }

        Classifier bestClassifier = bestModels.peek().getClassifier();
        if (this.trainFinalModel) {
            // A null pool (from trainC(dataSet)) must use the single-threaded overload.
            if (threadPool == null || threadPool instanceof FakeExecutor) {
                bestClassifier.trainC(dataSet);
            } else {
                bestClassifier.trainC(dataSet, threadPool);
            }
        }
        this.trainedClassifier = bestClassifier;
    }

    /**
     * Single-threaded variant of {@link #trainC(ClassificationDataSet, ExecutorService)}.
     *
     * @param dataSet the data to search over
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    /**
     * Evaluates {@link #getTrials()} randomly sampled configurations of the
     * base regressor with cross validation. The best one is stored as the
     * trained regressor and, if {@code trainFinalModel} is set, re-fitted on
     * the whole data set.
     *
     * @param dataSet the data to search over
     * @param threadPool the executor to use, or {@code null} for single-threaded work
     * @throws FailedToFitException if interrupted while waiting for the evaluations
     */
    @Override
    public void train(final RegressionDataSet dataSet, final ExecutorService threadPool) {
        // The queue head is the best model: ascending order when lower scores
        // are better, descending otherwise.
        final PriorityQueue<RegressionModelEvaluation> bestModels = new PriorityQueue<RegressionModelEvaluation>(this.trials, new Comparator<RegressionModelEvaluation>() {

            @Override
            public int compare(RegressionModelEvaluation t, RegressionModelEvaluation t1) {
                double v0 = t.getScoreStats(RandomSearch.this.regressionTargetScore).getMean();
                double v1 = t1.getScoreStats(RandomSearch.this.regressionTargetScore).getMean();
                int order = RandomSearch.this.regressionTargetScore.lowerIsBetter() ? 1 : -1;
                return order * Double.compare(v0, v1);
            }
        });

        List<Regressor> paramsToEval = new ArrayList<Regressor>(this.trials);
        Random rand = new XORWOW();
        for (int trial = 0; trial < this.trials; ++trial) {
            this.sampleSearchParameters(rand);
            paramsToEval.add(this.baseRegressor.clone());
        }

        ExecutorService modelService = this.trainModelsInParallel && threadPool != null ? threadPool : new FakeExecutor();

        final List<RegressionDataSet> preFolded;
        final ArrayList<RegressionDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<RegressionDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }

        final CountDownLatch latch = new CountDownLatch(paramsToEval.size());
        final AtomicReference<RuntimeException> failure = new AtomicReference<RuntimeException>();
        for (final Regressor r : paramsToEval) {
            modelService.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        // When models are evaluated in parallel, each evaluation runs
                        // single-threaded; otherwise the evaluation may use threadPool.
                        RegressionModelEvaluation rme = RandomSearch.this.trainModelsInParallel ? new RegressionModelEvaluation(r, dataSet) : new RegressionModelEvaluation(r, dataSet, threadPool);
                        rme.addScorer(RandomSearch.this.regressionTargetScore.clone());
                        if (RandomSearch.this.reuseSameCVFolds) {
                            rme.evaluateCrossValidation(preFolded, trainCombinations);
                        } else {
                            rme.evaluateCrossValidation(RandomSearch.this.folds);
                        }
                        synchronized (bestModels) {
                            bestModels.add(rme);
                        }
                    } catch (RuntimeException ex) {
                        // An executor would swallow this inside a discarded Future;
                        // keep the first failure so the caller can rethrow it.
                        failure.compareAndSet(null, ex);
                    } finally {
                        // Always release the latch, otherwise a failed trial hangs await().
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        RuntimeException firstFailure = failure.get();
        if (firstFailure != null) {
            throw firstFailure;
        }

        Regressor bestRegressor = bestModels.peek().getRegressor();
        if (this.trainFinalModel) {
            // A null pool (from train(dataSet)) must use the single-threaded overload.
            if (threadPool == null || threadPool instanceof FakeExecutor) {
                bestRegressor.train(dataSet);
            } else {
                bestRegressor.train(dataSet, threadPool);
            }
        }
        this.trainedRegressor = bestRegressor;
    }

    /**
     * Single-threaded variant of {@link #train(RegressionDataSet, ExecutorService)}.
     *
     * @param dataSet the data to search over
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, null);
    }

    /**
     * @return a copy of this search, made with the copy constructor
     */
    @Override
    public RandomSearch clone() {
        return new RandomSearch(this);
    }
}

