/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionBase;
import jsat.math.rootfinding.Zeroin;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.ListUtils;

/**
 * Stochastic gradient boosting for regression under the squared-error loss.
 * <p>
 * Training first fits an initial model (a clone of the optional strong learner,
 * otherwise a clone of the weak learner) to the original targets. It then runs
 * {@code maxIterations} boosting rounds; each round fits a fresh clone of the
 * weak learner to the current residuals using a random subset of the training
 * points. Every fitted model is weighted by its least-squares optimal
 * coefficient multiplied by the learning rate, and {@link #regress(DataPoint)}
 * returns the weighted sum of all model outputs.
 */
public class StochasticGradientBoosting implements Regressor, Parameterized {
    private static final long serialVersionUID = -2855154397476855293L;
    /** Default fraction of the training set sampled for each weak learner. */
    public static final double DEFAULT_TRAINING_PROPORTION = 0.5;
    /** Default shrinkage multiplied into every model's coefficient. */
    public static final double DEFAULT_LEARNING_RATE = 0.1;

    /** Fraction of the data set, in (0, 1], used to fit each weak learner. */
    private double trainingProportion;
    /** Prototype cloned for every boosting round (and for the initial model when no strong learner is set). */
    private final Regressor weakLearner;
    /** Optional prototype for the initial model; {@code null} means "use the weak learner". */
    private final Regressor strongLearner;
    /** Trained models in the order they were fit; {@code null} until {@link #train} runs. */
    private List<Regressor> F;
    /** Coefficient applied to the model at the same index in {@link #F}. */
    private List<Double> coef;
    /** Shrinkage in (0, 1] applied to each model's least-squares coefficient. */
    private double learningRate;
    /** Number of boosting rounds performed after the initial model; non-negative. */
    private int maxIterations;

    /**
     * Creates a booster whose initial model is a clone of {@code strongLearner}.
     *
     * @param strongLearner   prototype for the initial model, or {@code null} to use the weak learner
     * @param weakLearner     prototype cloned for every boosting round
     * @param maxIterations   number of boosting rounds after the initial model; must be {@code >= 0}
     * @param learningRate    shrinkage applied to each coefficient, in (0, 1]
     * @param trainingPortion fraction of the data used to fit each weak learner, in (0, 1]
     * @throws ArithmeticException if any numeric argument is out of range or NaN
     */
    public StochasticGradientBoosting(Regressor strongLearner, Regressor weakLearner, int maxIterations, double learningRate, double trainingPortion) {
        checkMaxIterations(maxIterations);
        checkLearningRate(learningRate);
        checkTrainingProportion(trainingPortion);
        this.strongLearner = strongLearner;
        this.weakLearner = weakLearner;
        this.maxIterations = maxIterations;
        this.learningRate = learningRate;
        this.trainingProportion = trainingPortion;
    }

    /**
     * Creates a booster whose initial model is a clone of the weak learner.
     *
     * @param weakLearner     prototype cloned for the initial model and every boosting round
     * @param maxIterations   number of boosting rounds after the initial model; must be {@code >= 0}
     * @param learningRate    shrinkage applied to each coefficient, in (0, 1]
     * @param trainingPortion fraction of the data used to fit each weak learner, in (0, 1]
     * @throws ArithmeticException if any numeric argument is out of range or NaN
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations, double learningRate, double trainingPortion) {
        this(null, weakLearner, maxIterations, learningRate, trainingPortion);
    }

    /**
     * Creates a booster using {@link #DEFAULT_TRAINING_PROPORTION}.
     *
     * @param weakLearner   prototype cloned for the initial model and every boosting round
     * @param maxIterations number of boosting rounds after the initial model; must be {@code >= 0}
     * @param learningRate  shrinkage applied to each coefficient, in (0, 1]
     * @throws ArithmeticException if any numeric argument is out of range or NaN
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations, double learningRate) {
        this(weakLearner, maxIterations, learningRate, DEFAULT_TRAINING_PROPORTION);
    }

    /**
     * Creates a booster using {@link #DEFAULT_LEARNING_RATE} and
     * {@link #DEFAULT_TRAINING_PROPORTION}.
     *
     * @param weakLearner   prototype cloned for the initial model and every boosting round
     * @param maxIterations number of boosting rounds after the initial model; must be {@code >= 0}
     * @throws ArithmeticException if {@code maxIterations} is negative
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations) {
        this(weakLearner, maxIterations, DEFAULT_LEARNING_RATE);
    }

    /**
     * Sets the number of boosting rounds performed after the initial model.
     *
     * @param maxIterations the number of rounds; must be {@code >= 0}
     * @throws ArithmeticException if {@code maxIterations} is negative
     */
    public void setMaxIterations(int maxIterations) {
        checkMaxIterations(maxIterations);
        this.maxIterations = maxIterations;
    }

    /** @return the number of boosting rounds performed after the initial model */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the shrinkage multiplied into every model's coefficient.
     *
     * @param learningRate a value in (0, 1]
     * @throws ArithmeticException if the value is outside (0, 1] or NaN
     */
    public void setLearningRate(double learningRate) {
        checkLearningRate(learningRate);
        this.learningRate = learningRate;
    }

    /** @return the shrinkage multiplied into every model's coefficient */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the fraction of the training set randomly sampled to fit each weak learner.
     *
     * @param trainingProportion a value in (0, 1]
     * @throws ArithmeticException if the value is outside (0, 1] or NaN
     */
    public void setTrainingProportion(double trainingProportion) {
        checkTrainingProportion(trainingProportion);
        this.trainingProportion = trainingProportion;
    }

    /** @return the fraction of the training set used to fit each weak learner */
    public double getTrainingProportion() {
        return this.trainingProportion;
    }

    /**
     * Predicts the target as the coefficient-weighted sum of every trained model's output.
     *
     * @param data the point to predict
     * @return the boosted prediction
     * @throws UntrainedModelException if {@link #train} has not completed
     */
    @Override
    public double regress(DataPoint data) {
        if (this.F == null || this.F.isEmpty()) {
            throw new UntrainedModelException();
        }
        double result = 0.0;
        for (int i = 0; i < this.F.size(); ++i) {
            result += this.F.get(i).regress(data) * this.coef.get(i);
        }
        return result;
    }

    /**
     * Fits the boosted ensemble, replacing any previously trained models.
     *
     * @param dataSet    the training data
     * @param threadPool executor forwarded to each learner's {@code train}; when {@code null}
     *                   or a {@link FakeExecutor}, learners are trained single-threaded
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        // Pairs whose targets are overwritten with the current residuals every round.
        List<DataPointPair<Double>> backingResidsList = dataSet.getAsDPPList();
        // One initial model plus one model per boosting round.
        this.F = new ArrayList<Regressor>(this.maxIterations + 1);
        this.coef = new DoubleList(this.maxIterations + 1);

        Regressor initial = this.strongLearner == null ? this.weakLearner.clone() : this.strongLearner.clone();
        trainLearner(initial, dataSet, threadPool);
        this.F.add(initial);
        this.coef.add(this.learningRate * getMinimizingErrorConst(backingResidsList, initial));

        // Running sum of coef[i] * F[i](x_j) over the models folded in so far,
        // so earlier models never need to be re-evaluated.
        double[] currPredictions = new double[dataSet.getSampleSize()];
        // NOTE(review): the residual updates made through `resids` are only seen by
        // getMinimizingErrorConst/randomSample (which read backingResidsList) if
        // usingDPPList wraps the list rather than copying it; this code assumes it does — confirm.
        RegressionDataSet resids = RegressionDataSet.usingDPPList(backingResidsList);
        int n = resids.getSampleSize();
        // Never sample zero points for a non-empty set, and never more than n.
        int randSampleSize = Math.min(n, Math.max(1, (int) Math.round(n * this.trainingProportion)));
        ArrayList<DataPointPair<Double>> randSampleList = new ArrayList<DataPointPair<Double>>(randSampleSize);
        Random rand = new Random();
        for (int iter = 0; iter < this.maxIterations; ++iter) {
            updateResiduals(dataSet, resids, currPredictions, this.F.get(iter), this.coef.get(iter));

            randSampleList.clear();
            ListUtils.randomSample(backingResidsList, randSampleList, randSampleSize, rand);
            Regressor h = this.weakLearner.clone();
            trainLearner(h, RegressionDataSet.usingDPPList(randSampleList), threadPool);

            this.F.add(h);
            this.coef.add(this.learningRate * getMinimizingErrorConst(backingResidsList, h));
        }
    }

    /**
     * Trains {@code learner} on {@code data}, using the thread pool only when it is a real executor.
     */
    private static void trainLearner(Regressor learner, RegressionDataSet data, ExecutorService threadPool) {
        if (threadPool == null || threadPool instanceof FakeExecutor) {
            learner.train(data);
        } else {
            learner.train(data, threadPool);
        }
    }

    /**
     * Adds the most recent model's weighted output to the running predictions and
     * rewrites each target in {@code resids} as the original target minus that prediction.
     */
    private static void updateResiduals(RegressionDataSet dataSet, RegressionDataSet resids, double[] currPredictions, Regressor lastF, double lastCoef) {
        for (int j = 0; j < resids.getSampleSize(); ++j) {
            currPredictions[j] += lastCoef * lastF.regress(resids.getDataPoint(j));
            resids.setTargetValue(j, dataSet.getTargetValue(j) - currPredictions[j]);
        }
    }

    /**
     * Returns the constant {@code c} minimising {@code sum((target - c * h(x))^2)}
     * over the given pairs.
     * <p>
     * Setting the derivative to zero gives the closed form
     * {@code c = sum(h * target) / sum(h * h)}. Unlike a bracketed root search, this
     * is exact and cannot fail when the optimum is large. If {@code h} predicts zero
     * everywhere, every {@code c} is equally good and 0 is returned.
     */
    private static double getMinimizingErrorConst(List<DataPointPair<Double>> backingResidsList, Regressor h) {
        double hDotTarget = 0.0;
        double hDotH = 0.0;
        for (DataPointPair<Double> dpp : backingResidsList) {
            double hEst = h.regress(dpp.getDataPoint());
            double target = dpp.getPair();
            hDotTarget += hEst * target;
            hDotH += hEst * hEst;
        }
        if (hDotH == 0.0) {
            return 0.0;
        }
        return hDotTarget / hDotH;
    }

    /** Rejects a negative number of boosting rounds. */
    private static void checkMaxIterations(int maxIterations) {
        if (maxIterations < 0) {
            throw new ArithmeticException("Invalid number of iterations");
        }
    }

    /** Rejects a learning rate outside (0, 1] or NaN. */
    private static void checkLearningRate(double learningRate) {
        if (learningRate > 1.0 || learningRate <= 0.0 || Double.isNaN(learningRate)) {
            throw new ArithmeticException("Invalid learning rate");
        }
    }

    /** Rejects a training proportion outside (0, 1] or NaN. */
    private static void checkTrainingProportion(double trainingProportion) {
        if (trainingProportion > 1.0 || trainingProportion <= 0.0 || Double.isNaN(trainingProportion)) {
            throw new ArithmeticException("Training Proportion is invalid");
        }
    }

    /**
     * Trains single-threaded; equivalent to {@code train(dataSet, null)}.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, null);
    }

    /**
     * @return true only if the weak learner, and the strong learner when one is set,
     *         both support weighted data
     */
    @Override
    public boolean supportsWeightedData() {
        if (this.strongLearner != null) {
            return this.strongLearner.supportsWeightedData() && this.weakLearner.supportsWeightedData();
        }
        return this.weakLearner.supportsWeightedData();
    }

    /**
     * Returns a deep copy: the learner prototypes and any trained models are cloned,
     * and the coefficient list is copied.
     */
    @Override
    public StochasticGradientBoosting clone() {
        Regressor strongCopy = this.strongLearner == null ? null : this.strongLearner.clone();
        StochasticGradientBoosting clone = new StochasticGradientBoosting(strongCopy, this.weakLearner.clone(), this.maxIterations, this.learningRate, this.trainingProportion);
        if (this.F != null) {
            clone.F = new ArrayList<Regressor>(this.F.size());
            for (Regressor f : this.F) {
                clone.F.add(f.clone());
            }
        }
        if (this.coef != null) {
            clone.coef = new DoubleList(this.coef.size());
            for (double d : this.coef) {
                clone.coef.add(d);
            }
        }
        return clone;
    }

    /** @return the tunable parameters discovered from this object's getter/setter pairs */
    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * @param paramName the name of a parameter
     * @return the matching parameter, or whatever the parameter map yields for an unknown name
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

