/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * Stacking (stacked generalization) ensemble that can be used as a
 * {@link Classifier}, a {@link Regressor}, or both.
 * <p>
 * Training builds a "meta" data set whose features are predictions of the base
 * models. The training data is split with {@code cvSet(folds)}. For every base
 * model and every fold, the model is trained on the remaining folds, and its
 * predictions on the held-out fold are written into the meta data set. The
 * aggregating model is then trained on the meta data set. Finally, each base
 * model is retrained on the full training data. With a single fold there is
 * nothing to hold out: each base model is trained once on that fold (all of the
 * data), and its predictions on that same data form the meta features.
 * <p>
 * Meta features for classification depend on the number of classes. For a
 * two-class problem, each base model contributes one value,
 * {@code 2 * P(class 0) - 1}. With more classes, it contributes its probability
 * for every class. For regression, each base model contributes its predicted
 * value.
 * <p>
 * If every supplied model implements both {@link Classifier} and
 * {@link Regressor}, the same model instances (and the same list) back both
 * roles, so this object can be trained and used for either task.
 */
public class Stacking
implements Classifier,
Regressor {
    private static final long serialVersionUID = -6173323872903232074L;

    /** Number of folds used by the constructors that do not take a folds argument. */
    public static final int DEFAULT_FOLDS = 3;

    /** Number of cross-validation folds used to build the meta data set. */
    private int folds;
    /**
     * Number of meta features contributed by each base model. It is set by the
     * training methods and is 0 until one of them runs.
     */
    private int weightsPerModel;
    /** Aggregating model for classification, or {@code null} when not usable as a classifier. */
    private Classifier aggregatingClassifier;
    /** Base models for classification, or {@code null} when not usable as a classifier. */
    private List<Classifier> baseClassifiers;
    /** Aggregating model for regression, or {@code null} when not usable as a regressor. */
    private Regressor aggregatingRegressor;
    /** Base models for regression, or {@code null} when not usable as a regressor. */
    private List<Regressor> baseRegressors;

    /**
     * Creates a stacking classifier.
     *
     * @param folds number of cross-validation folds, at least 1
     * @param aggregatingClassifier model trained on the base classifiers' outputs
     * @param baseClassifiers models whose outputs feed the aggregating model;
     *        the list is stored without copying
     * @throws IllegalArgumentException if fewer than 2 base classifiers are given
     *         or {@code folds < 1}
     */
    public Stacking(int folds, Classifier aggregatingClassifier, List<Classifier> baseClassifiers) {
        if (baseClassifiers.size() < 2) {
            throw new IllegalArgumentException("base classifiers must contain at least 2 elements, not " + baseClassifiers.size());
        }
        this.setFolds(folds);
        this.aggregatingClassifier = aggregatingClassifier;
        this.baseClassifiers = baseClassifiers;
        boolean allRegressors = aggregatingClassifier instanceof Regressor;
        for (Classifier cl : baseClassifiers) {
            if (!(cl instanceof Regressor)) {
                allRegressors = false;
                break;
            }
        }
        if (allRegressors) {
            // Cast via Object because Classifier and Regressor are unrelated interfaces.
            this.aggregatingRegressor = (Regressor) (Object) aggregatingClassifier;
            this.baseRegressors = sameList(baseClassifiers);
        }
    }

    /**
     * Creates a stacking classifier.
     *
     * @param folds number of cross-validation folds, at least 1
     * @param aggregatingClassifier model trained on the base classifiers' outputs
     * @param baseClassifiers at least 2 base classifiers
     */
    public Stacking(int folds, Classifier aggregatingClassifier, Classifier ... baseClassifiers) {
        this(folds, aggregatingClassifier, Arrays.asList(baseClassifiers));
    }

    /**
     * Creates a stacking classifier using {@link #DEFAULT_FOLDS} folds.
     *
     * @param aggregatingClassifier model trained on the base classifiers' outputs
     * @param baseClassifiers at least 2 base classifiers, stored without copying
     */
    public Stacking(Classifier aggregatingClassifier, List<Classifier> baseClassifiers) {
        this(DEFAULT_FOLDS, aggregatingClassifier, baseClassifiers);
    }

    /**
     * Creates a stacking classifier using {@link #DEFAULT_FOLDS} folds.
     *
     * @param aggregatingClassifier model trained on the base classifiers' outputs
     * @param baseClassifiers at least 2 base classifiers
     */
    public Stacking(Classifier aggregatingClassifier, Classifier ... baseClassifiers) {
        this(DEFAULT_FOLDS, aggregatingClassifier, baseClassifiers);
    }

    /**
     * Creates a stacking regressor.
     *
     * @param folds number of cross-validation folds, at least 1
     * @param aggregatingRegressor model trained on the base regressors' outputs
     * @param baseRegressors models whose outputs feed the aggregating model;
     *        the list is stored without copying
     * @throws IllegalArgumentException if fewer than 2 base regressors are given
     *         or {@code folds < 1}
     */
    public Stacking(int folds, Regressor aggregatingRegressor, List<Regressor> baseRegressors) {
        if (baseRegressors.size() < 2) {
            throw new IllegalArgumentException("base regressors must contain at least 2 elements, not " + baseRegressors.size());
        }
        this.setFolds(folds);
        this.aggregatingRegressor = aggregatingRegressor;
        this.baseRegressors = baseRegressors;
        boolean allClassifiers = aggregatingRegressor instanceof Classifier;
        for (Regressor reg : baseRegressors) {
            if (!(reg instanceof Classifier)) {
                allClassifiers = false;
                break;
            }
        }
        if (allClassifiers) {
            // Cast via Object because Classifier and Regressor are unrelated interfaces.
            this.aggregatingClassifier = (Classifier) (Object) aggregatingRegressor;
            this.baseClassifiers = sameList(baseRegressors);
        }
    }

    /**
     * Creates a stacking regressor.
     *
     * @param folds number of cross-validation folds, at least 1
     * @param aggregatingRegressor model trained on the base regressors' outputs
     * @param baseRegressors at least 2 base regressors
     */
    public Stacking(int folds, Regressor aggregatingRegressor, Regressor ... baseRegressors) {
        this(folds, aggregatingRegressor, Arrays.asList(baseRegressors));
    }

    /**
     * Creates a stacking regressor using {@link #DEFAULT_FOLDS} folds.
     *
     * @param aggregatingRegressor model trained on the base regressors' outputs
     * @param baseRegressors at least 2 base regressors, stored without copying
     */
    public Stacking(Regressor aggregatingRegressor, List<Regressor> baseRegressors) {
        this(DEFAULT_FOLDS, aggregatingRegressor, baseRegressors);
    }

    /**
     * Creates a stacking regressor using {@link #DEFAULT_FOLDS} folds.
     *
     * @param aggregatingRegressor model trained on the base regressors' outputs
     * @param baseRegressors at least 2 base regressors
     */
    public Stacking(Regressor aggregatingRegressor, Regressor ... baseRegressors) {
        this(DEFAULT_FOLDS, aggregatingRegressor, baseRegressors);
    }

    /**
     * Copy constructor. Clones every model of {@code toCopy}. When the source
     * shares its models between the classifier and regressor roles, the copy
     * shares its clones the same way.
     *
     * @param toCopy the object to copy
     */
    public Stacking(Stacking toCopy) {
        this.folds = toCopy.folds;
        this.weightsPerModel = toCopy.weightsPerModel;
        if (toCopy.aggregatingClassifier != null) {
            this.aggregatingClassifier = toCopy.aggregatingClassifier.clone();
            this.baseClassifiers = new ArrayList<Classifier>(toCopy.baseClassifiers.size());
            for (Classifier bc : toCopy.baseClassifiers) {
                this.baseClassifiers.add(bc.clone());
            }
            // Reference comparison through Object: the two fields have unrelated interface types.
            if ((Object) toCopy.aggregatingRegressor == toCopy.aggregatingClassifier) {
                this.aggregatingRegressor = (Regressor) (Object) this.aggregatingClassifier;
                this.baseRegressors = sameList(this.baseClassifiers);
            }
        } else {
            this.aggregatingRegressor = toCopy.aggregatingRegressor.clone();
            this.baseRegressors = new ArrayList<Regressor>(toCopy.baseRegressors.size());
            for (Regressor br : toCopy.baseRegressors) {
                this.baseRegressors.add(br.clone());
            }
        }
    }

    /**
     * Returns {@code list} itself, retyped without copying. Sharing the list keeps
     * the classifier and regressor views in sync. Only call this after checking
     * that every element implements {@code T}.
     */
    @SuppressWarnings("unchecked")
    private static <S, T> List<T> sameList(List<S> list) {
        return (List<T>) (List<?>) list;
    }

    /**
     * Sets the number of cross-validation folds used to build the meta data set.
     *
     * @param folds number of folds, at least 1
     * @throws IllegalArgumentException if {@code folds < 1}
     */
    public void setFolds(int folds) {
        if (folds < 1) {
            throw new IllegalArgumentException("Folds must be a positive integer, not " + folds);
        }
        this.folds = folds;
    }

    /**
     * @return the number of cross-validation folds used to build the meta data set
     */
    public int getFolds() {
        return this.folds;
    }

    /** Fails fast when the models given at construction are not all Classifiers. */
    private void requireClassifiers() {
        if (this.baseClassifiers == null) {
            throw new UnsupportedOperationException("Stacking models do not all implement Classifier");
        }
    }

    /** Fails fast when the models given at construction are not all Regressors. */
    private void requireRegressors() {
        if (this.baseRegressors == null) {
            throw new UnsupportedOperationException("Stacking models do not all implement Regressor");
        }
    }

    /**
     * Writes the meta features derived from one base classifier's prediction into
     * {@code target}. Training and prediction both use this method, so they
     * encode features identically.
     *
     * @param target meta feature vector to fill
     * @param model index of the base classifier that produced {@code pred}
     * @param pred the base classifier's prediction
     */
    private void setMetaFeatures(Vec target, int model, CategoricalResults pred) {
        if (this.weightsPerModel == 1) {
            // One feature per model: P(class 0) encoded as 2p - 1.
            target.set(model, pred.getProb(0) * 2.0 - 1.0);
        } else {
            int offset = model * this.weightsPerModel;
            for (int j = 0; j < this.weightsPerModel; ++j) {
                target.set(offset + j, pred.getProb(j));
            }
        }
    }

    /** Trains {@code model}, using {@code threadPool} only when it is non-null. */
    private static void trainClassifier(Classifier model, ClassificationDataSet data, ExecutorService threadPool) {
        if (threadPool == null) {
            model.trainC(data);
        } else {
            model.trainC(data, threadPool);
        }
    }

    /** Trains {@code model}, using {@code threadPool} only when it is non-null. */
    private static void trainRegressor(Regressor model, RegressionDataSet data, ExecutorService threadPool) {
        if (threadPool == null) {
            model.train(data);
        } else {
            model.train(data, threadPool);
        }
    }

    /**
     * Classifies {@code data}. Each base classifier predicts first, and the
     * aggregating classifier then classifies the resulting meta features.
     *
     * @throws UnsupportedOperationException if the models are not all Classifiers
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        this.requireClassifiers();
        int models = this.baseClassifiers.size();
        Vec metaFeatures = new DenseVector(this.weightsPerModel * models);
        for (int i = 0; i < models; ++i) {
            this.setMetaFeatures(metaFeatures, i, this.baseClassifiers.get(i).classify(data));
        }
        return this.aggregatingClassifier.classify(new DataPoint(metaFeatures));
    }

    /**
     * Trains the ensemble for classification (see the class documentation).
     *
     * @param dataSet training data
     * @param threadPool pool passed to the models' training, or {@code null} to
     *        train without one
     * @throws UnsupportedOperationException if the models are not all Classifiers
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.requireClassifiers();
        int models = this.baseClassifiers.size();
        int numClasses = dataSet.getClassSize();
        this.weightsPerModel = numClasses == 2 ? 1 : numClasses;
        ClassificationDataSet metaSet = new ClassificationDataSet(this.weightsPerModel * models, new CategoricalData[0], dataSet.getPredicting());
        List<ClassificationDataSet> dataFolds = dataSet.cvSet(this.folds);
        boolean singleFold = dataFolds.size() == 1;
        // Add meta points in fold order so that index pos lines up with the
        // pos-th held-out point visited below; this also carries over the weights.
        for (ClassificationDataSet cds : dataFolds) {
            for (int i = 0; i < cds.getSampleSize(); ++i) {
                Vec emptyFeatures = new DenseVector(this.weightsPerModel * models);
                metaSet.addDataPoint(emptyFeatures, cds.getDataPointCategory(i), cds.getDataPoint(i).getWeight());
            }
        }
        for (int c = 0; c < models; ++c) {
            Classifier cl = this.baseClassifiers.get(c);
            int pos = 0;
            for (int f = 0; f < dataFolds.size(); ++f) {
                ClassificationDataSet test = dataFolds.get(f);
                // With one fold nothing can be held out (combining "all but" it
                // would be empty), so train on that fold, i.e. all of the data.
                ClassificationDataSet train = singleFold ? test : ClassificationDataSet.comineAllBut(dataFolds, f);
                trainClassifier(cl, train, threadPool);
                for (int i = 0; i < test.getSampleSize(); ++i) {
                    CategoricalResults pred = cl.classify(test.getDataPoint(i));
                    this.setMetaFeatures(metaSet.getDataPoint(pos).getNumericalValues(), c, pred);
                    ++pos;
                }
            }
        }
        trainClassifier(this.aggregatingClassifier, metaSet, threadPool);
        // With a single fold the base models were already trained on all the data.
        if (!singleFold) {
            for (Classifier cl : this.baseClassifiers) {
                trainClassifier(cl, dataSet, threadPool);
            }
        }
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    /**
     * @return whether the aggregating model supports weighted data (the
     *         classifier's answer when one is present, otherwise the regressor's)
     */
    @Override
    public boolean supportsWeightedData() {
        if (this.aggregatingClassifier != null) {
            return this.aggregatingClassifier.supportsWeightedData();
        }
        return this.aggregatingRegressor.supportsWeightedData();
    }

    /**
     * Predicts a value for {@code data}. Each base regressor predicts first, and
     * the aggregating regressor then predicts from those values.
     *
     * @throws UnsupportedOperationException if the models are not all Regressors
     */
    @Override
    public double regress(DataPoint data) {
        this.requireRegressors();
        int models = this.baseRegressors.size();
        Vec metaFeatures = new DenseVector(models);
        for (int i = 0; i < models; ++i) {
            metaFeatures.set(i, this.baseRegressors.get(i).regress(data));
        }
        return this.aggregatingRegressor.regress(new DataPoint(metaFeatures));
    }

    /**
     * Trains the ensemble for regression (see the class documentation).
     *
     * @param dataSet training data
     * @param threadPool pool passed to the models' training, or {@code null} to
     *        train without one
     * @throws UnsupportedOperationException if the models are not all Regressors
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.requireRegressors();
        int models = this.baseRegressors.size();
        this.weightsPerModel = 1;
        RegressionDataSet metaSet = new RegressionDataSet(models, new CategoricalData[0]);
        List<RegressionDataSet> dataFolds = dataSet.cvSet(this.folds);
        boolean singleFold = dataFolds.size() == 1;
        // Add meta points in fold order so that index pos lines up with the
        // pos-th held-out point visited below; this also carries over the weights.
        for (RegressionDataSet rds : dataFolds) {
            for (int i = 0; i < rds.getSampleSize(); ++i) {
                metaSet.addDataPoint(new DataPoint(new DenseVector(models), rds.getDataPoint(i).getWeight()), rds.getTargetValue(i));
            }
        }
        for (int c = 0; c < models; ++c) {
            Regressor reg = this.baseRegressors.get(c);
            int pos = 0;
            for (int f = 0; f < dataFolds.size(); ++f) {
                RegressionDataSet test = dataFolds.get(f);
                // With one fold nothing can be held out (combining "all but" it
                // would be empty), so train on that fold, i.e. all of the data.
                RegressionDataSet train = singleFold ? test : RegressionDataSet.comineAllBut(dataFolds, f);
                trainRegressor(reg, train, threadPool);
                for (int i = 0; i < test.getSampleSize(); ++i) {
                    double pred = reg.regress(test.getDataPoint(i));
                    metaSet.getDataPoint(pos).getNumericalValues().set(c, pred);
                    ++pos;
                }
            }
        }
        trainRegressor(this.aggregatingRegressor, metaSet, threadPool);
        // With a single fold the base models were already trained on all the data.
        if (!singleFold) {
            for (Regressor reg : this.baseRegressors) {
                trainRegressor(reg, dataSet, threadPool);
            }
        }
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, null);
    }

    /** @return a deep copy made with {@link #Stacking(Stacking)} */
    @Override
    public Stacking clone() {
        return new Stacking(this);
    }
}

