/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransformProcess;
import jsat.exceptions.UntrainedModelException;
import jsat.math.OnLineStatistics;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.regression.evaluation.RegressionScore;
import jsat.utils.SystemInfo;
import jsat.utils.random.RandomUtil;

/**
 * Evaluates a {@link Regressor} on a {@link RegressionDataSet}, either by
 * cross validation or against a separate test set.
 * <p>
 * For every train/test split the prototype regressor is cloned, trained on a
 * transformed copy of the training data, and used to predict every point of
 * the test set. After an evaluation the following can be queried:
 * <ul>
 * <li>statistics of the per-point squared error;</li>
 * <li>the accumulated training and prediction times;</li>
 * <li>statistics, across splits, of every {@link RegressionScore} registered
 * through {@link #addScorer(RegressionScore)}.</li>
 * </ul>
 * When an {@link ExecutorService} is supplied, training uses the regressor's
 * parallel {@code train} overload. Prediction over the test set is then split
 * into {@link SystemInfo#LogicalCores} tasks, unless the test set has fewer
 * points than that.
 */
public class RegressionModelEvaluation {
    private final Regressor regressor;
    private final RegressionDataSet dataSet;
    /** Optional pool for parallel training and prediction; may be {@code null}. */
    private final ExecutorService threadpool;
    /** Squared error of every test point of the last evaluation, each added with the point's weight. */
    private OnLineStatistics sqrdErrorStats;
    private long totalTrainingTime = 0L;
    private long totalClassificationTime = 0L;
    private DataTransformProcess dtp;
    /** Each registered score mapped to the statistics of its value across splits (insertion ordered). */
    private final Map<RegressionScore, OnLineStatistics> scoreMap;
    private boolean keepModels = false;
    /** Models trained during the last evaluation, or {@code null} when {@link #keepModels} was false. */
    private Regressor[] keptModels;
    /** Optional per-split starting models used when the regressor is a {@link WarmRegressor}. */
    private Regressor[] warmModels;

    /**
     * Creates a new evaluator.
     *
     * @param regressor the prototype regressor; it is cloned for every split and never trained itself
     * @param dataSet the data set used for cross validation, or as the training set for {@link #evaluateTestSet}
     * @param threadpool the pool used for parallel training and prediction, or {@code null} to run single-threaded
     */
    public RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet, ExecutorService threadpool) {
        this.regressor = regressor;
        this.dataSet = dataSet;
        this.threadpool = threadpool;
        this.dtp = new DataTransformProcess();
        this.scoreMap = new LinkedHashMap<RegressionScore, OnLineStatistics>();
    }

    /**
     * Creates a new single-threaded evaluator.
     *
     * @param regressor the prototype regressor; it is cloned for every split
     * @param dataSet the data set to evaluate on
     */
    public RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet) {
        this(regressor, dataSet, null);
    }

    /**
     * Controls whether the models trained during the next evaluation are retained.
     *
     * @param keepModels {@code true} to keep every trained model for {@link #getKeptModels()}
     */
    public void setKeepModels(boolean keepModels) {
        this.keepModels = keepModels;
    }

    /**
     * @return whether trained models are retained after an evaluation
     */
    public boolean isKeepModels() {
        return this.keepModels;
    }

    /**
     * Returns the models trained by the most recent evaluation.
     *
     * @return one model per split (index {@code i} is fold {@code i}; a test-set
     * evaluation yields a single model), or {@code null} if keeping models was
     * disabled when that evaluation started
     */
    public Regressor[] getKeptModels() {
        return this.keptModels;
    }

    /**
     * Supplies starting models for warm training. They are used only when the
     * prototype regressor is a {@link WarmRegressor}. Entry {@code i} is used
     * for fold {@code i}, and entry 0 is used by {@link #evaluateTestSet}. The
     * array must therefore have at least as many entries as there are splits.
     *
     * @param warmModels the per-split warm-start models
     */
    public void setWarmModels(Regressor ... warmModels) {
        this.warmModels = warmModels;
    }

    /**
     * Sets the transforms to learn on each training split and apply to the
     * matching test data. A clone of {@code dtp} is stored.
     *
     * @param dtp the transform process to use
     */
    public void setDataTransformProcess(DataTransformProcess dtp) {
        this.dtp = dtp.clone();
    }

    /**
     * Performs cross validation using the default random source to form the folds.
     *
     * @param folds the number of folds, at least 2
     */
    public void evaluateCrossValidation(int folds) {
        this.evaluateCrossValidation(folds, RandomUtil.getRandom());
    }

    /**
     * Performs cross validation over the data set given at construction.
     *
     * @param folds the number of folds, at least 2
     * @param rand the random source used to split the data into folds
     * @throws UntrainedModelException if {@code folds < 2}
     */
    public void evaluateCrossValidation(int folds, Random rand) {
        if (folds < 2) {
            throw new UntrainedModelException("Model could not be evaluated because " + folds + " is < 2, and not valid for cross validation");
        }
        List<RegressionDataSet> lcds = this.dataSet.cvSet(folds, rand);
        this.evaluateCrossValidation(lcds);
    }

    /**
     * Performs cross validation over pre-made folds. Fold {@code i} is tested
     * with a model trained on all the other folds combined.
     *
     * @param lcds the folds
     */
    public void evaluateCrossValidation(List<RegressionDataSet> lcds) {
        List<RegressionDataSet> trainCombinations = new ArrayList<RegressionDataSet>(lcds.size());
        for (int i = 0; i < lcds.size(); ++i) {
            trainCombinations.add(RegressionDataSet.comineAllBut(lcds, i));
        }
        this.evaluateCrossValidation(lcds, trainCombinations);
    }

    /**
     * Performs cross validation over pre-made folds and their matching training sets.
     *
     * @param lcds the test folds
     * @param trainCombinations {@code trainCombinations.get(i)} is the training data for fold {@code i}
     */
    public void evaluateCrossValidation(List<RegressionDataSet> lcds, List<RegressionDataSet> trainCombinations) {
        this.resetForEvaluation(lcds.size());
        for (int i = 0; i < lcds.size(); ++i) {
            this.evaluationWork(trainCombinations.get(i), lcds.get(i), i);
        }
    }

    /**
     * Trains on the data set given at construction and evaluates on {@code testSet}.
     *
     * @param testSet the data to evaluate the trained model on
     */
    public void evaluateTestSet(RegressionDataSet testSet) {
        this.resetForEvaluation(1);
        this.evaluationWork(this.dataSet, testSet, 0);
    }

    /**
     * Clears all statistics from a previous evaluation. When models are being
     * kept, it also allocates room for {@code modelCount} trained models.
     */
    private void resetForEvaluation(int modelCount) {
        this.sqrdErrorStats = new OnLineStatistics();
        this.totalClassificationTime = 0L;
        this.totalTrainingTime = 0L;
        // Per-score statistics describe one evaluation, just like sqrdErrorStats.
        for (Map.Entry<RegressionScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            entry.setValue(new OnLineStatistics());
        }
        // Always reassign so a stale array from an earlier run is never reused.
        this.keptModels = this.keepModels ? new Regressor[modelCount] : null;
    }

    /**
     * Trains a model for one split, evaluates it on the test data, and folds the
     * results into the accumulated statistics.
     *
     * @param trainSet the training data (left unmodified; a shallow clone is transformed)
     * @param testSet the test data
     * @param index the split index, used for warm models and kept models
     */
    private void evaluationWork(RegressionDataSet trainSet, RegressionDataSet testSet, int index) {
        RegressionDataSet transformedTrain = trainSet.shallowClone();
        DataTransformProcess curProcess = this.dtp.clone();
        curProcess.learnApplyTransforms(transformedTrain);

        long startTrain = System.currentTimeMillis();
        Regressor trained = this.trainModel(transformedTrain, index);
        this.totalTrainingTime += System.currentTimeMillis() - startTrain;
        if (this.keptModels != null) {
            this.keptModels[index] = trained;
        }

        Map<RegressionScore, RegressionScore> scoresToUpdate = this.prepareScores();
        AtomicReference<Exception> failure = new AtomicReference<Exception>();
        CountDownLatch latch = this.runEvaluators(testSet, curProcess, scoresToUpdate, trained, failure);
        try {
            latch.await();
        } catch (InterruptedException ex) {
            Logger.getLogger(RegressionModelEvaluation.class.getName()).log(Level.SEVERE, null, ex);
            Thread.currentThread().interrupt();
            return;
        }

        Exception err = failure.get();
        if (err != null) {
            if (err instanceof RuntimeException) {
                throw (RuntimeException) err;
            }
            throw new RuntimeException("Regression evaluation failed on the test set", err);
        }

        for (Map.Entry<RegressionScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            RegressionScore score = entry.getKey().clone();
            score.prepare();
            score.addResults(scoresToUpdate.get(score));
            entry.getValue().add(score.getScore());
        }
    }

    /**
     * Clones the prototype regressor and trains it. It warm-starts from
     * {@code warmModels[index]} when warm models are set and the regressor
     * supports them, and uses the thread pool when one is available.
     */
    private Regressor trainModel(RegressionDataSet trainSet, int index) {
        Regressor model = this.regressor.clone();
        if (this.warmModels != null && model instanceof WarmRegressor) {
            WarmRegressor wr = (WarmRegressor) model;
            if (this.threadpool != null) {
                wr.train(trainSet, this.warmModels[index], this.threadpool);
            } else {
                wr.train(trainSet, this.warmModels[index]);
            }
        } else if (this.threadpool != null) {
            model.train(trainSet, this.threadpool);
        } else {
            model.train(trainSet);
        }
        return model;
    }

    /**
     * Creates a prepared clone of every registered score. Each clone is used
     * as both key and value so that worker tasks can merge into it.
     */
    private Map<RegressionScore, RegressionScore> prepareScores() {
        Map<RegressionScore, RegressionScore> prepared = new HashMap<RegressionScore, RegressionScore>();
        for (RegressionScore key : this.scoreMap.keySet()) {
            RegressionScore score = key.clone();
            score.prepare();
            prepared.put(score, score);
        }
        return prepared;
    }

    /**
     * Runs prediction over {@code testSet}. The work runs inline when there is
     * no pool or the test set is smaller than {@link SystemInfo#LogicalCores}.
     * Otherwise it is split into that many contiguous blocks submitted to the
     * pool.
     *
     * @return a latch that reaches zero once every evaluator task has finished,
     * whether it succeeded or failed
     */
    private CountDownLatch runEvaluators(RegressionDataSet testSet, DataTransformProcess curProcess,
            Map<RegressionScore, RegressionScore> scoresToUpdate, Regressor trained,
            AtomicReference<Exception> failure) {
        int n = testSet.getSampleSize();
        if (n < SystemInfo.LogicalCores || this.threadpool == null) {
            CountDownLatch latch = new CountDownLatch(1);
            new Evaluator(testSet, curProcess, 0, n, scoresToUpdate, trained, latch, failure).run();
            return latch;
        }
        CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
        // Here n >= LogicalCores, so blockSize >= 1. The first 'extra' blocks get one
        // extra point, giving exactly LogicalCores blocks, which matches the latch count.
        int blockSize = n / SystemInfo.LogicalCores;
        int extra = n % SystemInfo.LogicalCores;
        int start = 0;
        while (start < n) {
            int end = start + blockSize;
            if (extra-- > 0) {
                ++end;
            }
            this.threadpool.submit(new Evaluator(testSet, curProcess, start, end, scoresToUpdate, trained, latch, failure));
            start = end;
        }
        return latch;
    }

    /**
     * Registers a score to compute on every split. Its per-split values are
     * collected in the statistics returned by {@link #getScoreStats}.
     *
     * @param scorer the score to track
     */
    public void addScorer(RegressionScore scorer) {
        this.scoreMap.put(scorer, new OnLineStatistics());
    }

    /**
     * @param score a score previously passed to {@link #addScorer}
     * @return statistics of that score across the splits of the last evaluation, or {@code null} if it was not registered
     */
    public OnLineStatistics getScoreStats(RegressionScore score) {
        return this.scoreMap.get(score);
    }

    /**
     * Prints each registered score's mean and, in parentheses, its standard
     * deviation to standard output, one score per line.
     */
    public void prettyPrintRegressionScores() {
        int nameLength = 10;
        for (Map.Entry<RegressionScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            nameLength = Math.max(nameLength, entry.getKey().getName().length() + 2);
        }
        String pfx = "%-" + nameLength;
        for (Map.Entry<RegressionScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            System.out.printf(pfx + "s %-5f (%-5f)\n", entry.getKey().getName(), entry.getValue().getMean(), entry.getValue().getStandardDeviation());
        }
    }

    /** @return the minimum per-point squared error of the last evaluation */
    public double getMinError() {
        return this.sqrdErrorStats.getMin();
    }

    /** @return the maximum per-point squared error of the last evaluation */
    public double getMaxError() {
        return this.sqrdErrorStats.getMax();
    }

    /** @return the mean per-point squared error (points added with their weights) of the last evaluation */
    public double getMeanError() {
        return this.sqrdErrorStats.getMean();
    }

    /** @return the standard deviation of the per-point squared error of the last evaluation */
    public double getErrorStndDev() {
        return this.sqrdErrorStats.getStandardDeviation();
    }

    /** @return milliseconds spent cloning and training models during the last evaluation */
    public long getTotalTrainingTime() {
        return this.totalTrainingTime;
    }

    /**
     * @return milliseconds spent inside {@code regress} calls during the last
     * evaluation, summed over all evaluator tasks (so it can exceed elapsed
     * time when run in parallel)
     */
    public long getTotalClassificationTime() {
        return this.totalClassificationTime;
    }

    /** @return the prototype regressor given at construction */
    public Regressor getRegressor() {
        return this.regressor;
    }

    /**
     * Predicts the test points in {@code [start, end)} and merges its results
     * into the shared statistics and scores. Every merge into shared state is
     * guarded by the outer {@code sqrdErrorStats} monitor. The latch is always
     * counted down, even on failure. The first failure is recorded in
     * {@code failure}.
     */
    private class Evaluator implements Runnable {
        final RegressionDataSet testSet;
        final DataTransformProcess curProccess;
        final int start;
        final int end;
        final CountDownLatch latch;
        long localPredictionTime;
        final Regressor toUse;
        final Map<RegressionScore, RegressionScore> scoresToUpdate;
        final AtomicReference<Exception> failure;

        Evaluator(RegressionDataSet testSet, DataTransformProcess curProccess, int start, int end,
                Map<RegressionScore, RegressionScore> scoresToUpdate, Regressor toUse, CountDownLatch latch,
                AtomicReference<Exception> failure) {
            this.testSet = testSet;
            this.curProccess = curProccess;
            this.start = start;
            this.end = end;
            this.latch = latch;
            this.localPredictionTime = 0L;
            this.toUse = toUse;
            this.scoresToUpdate = scoresToUpdate;
            this.failure = failure;
        }

        @Override
        public void run() {
            try {
                // Per-task score copies avoid locking on every prediction.
                HashSet<RegressionScore> localScores = new HashSet<RegressionScore>();
                for (RegressionScore shared : this.scoresToUpdate.keySet()) {
                    localScores.add(shared.clone());
                }
                OnLineStatistics errorStats = RegressionModelEvaluation.this.sqrdErrorStats;
                for (int i = this.start; i < this.end; ++i) {
                    DataPoint di = this.testSet.getDataPoint(i);
                    double trueVal = this.testSet.getTargetValue(i);
                    DataPoint tranDP = this.curProccess.transform(di);
                    long startTime = System.currentTimeMillis();
                    double predVal = this.toUse.regress(tranDP);
                    this.localPredictionTime += System.currentTimeMillis() - startTime;
                    double sqrdError = Math.pow(trueVal - predVal, 2.0);
                    for (RegressionScore score : localScores) {
                        score.addResult(predVal, trueVal, di.getWeight());
                    }
                    synchronized (errorStats) {
                        errorStats.add(sqrdError, di.getWeight());
                    }
                }
                synchronized (errorStats) {
                    RegressionModelEvaluation.this.totalClassificationTime += this.localPredictionTime;
                    for (RegressionScore score : localScores) {
                        this.scoresToUpdate.get(score).addResults(score);
                    }
                }
            } catch (Exception ex) {
                // Keep the first failure; evaluationWork rethrows it after the latch opens.
                this.failure.compareAndSet(null, ex);
            } finally {
                // Must always run, otherwise latch.await() in evaluationWork never returns.
                this.latch.countDown();
            }
        }
    }
}

