/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta.imbalanced;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Online CSB2, the online version of the cost-sensitive boosting ensemble learner CSB2.
 *
 * <p>Every training instance is presented to the members in order. Member {@code i} is trained on
 * it {@code k ~ Poisson(lambda)} times, and the {@code lambda} handed to the next member is
 * recomputed from member {@code i}'s prediction and its accumulated error statistics. Class index
 * {@code 1} is treated as the positive class: predicting {@code 0} for it is charged
 * {@code costPositive}, every other misclassification is charged {@code costNegative}.
 *
 * <p>Unless disabled, one ADWIN detector per member tracks that member's error rate; when any
 * detector reports an increased error, the member with the highest estimated error is reset.
 */
public class OnlineCSB2
extends AbstractClassifier
implements MultiClassClassifier,
CapabilitiesHandler {
    private static final long serialVersionUID = 1L;

    /**
     * Clamp applied to a member's weighted error rate before deriving its voting weight, so that
     * {@code log((1 - epsilon) / epsilon)} stays finite for members with no recorded errors
     * ({@code epsilon == 0}) or with only errors ({@code epsilon == 1}).
     */
    private static final double EPSILON_CLAMP = 1.0E-6;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "meta.AdaptiveRandomForest");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The size of the ensemble.", 10, 1, Integer.MAX_VALUE);
    public FloatOption costPositiveOption = new FloatOption("costPositive", 'p', "The cost of misclassifying a positive sample.", 1.0, 0.1, 1.0);
    public FloatOption costNegativeOption = new FloatOption("costNegative", 'n', "The cost of misclassifying a negative sample.", 0.1, 0.1, 1.0);
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'd', "Disable the ADWIN drift detector.");

    /** Prepared, reset prototype that every ensemble member is copied from. */
    protected Classifier baseLearner;
    /** Current number of members; may be grown by {@link #adjustEnsembleSize(int)}. */
    protected int nEstimators;
    /** Cost of misclassifying a positive (class 1) sample as class 0. */
    protected double costPositive;
    /** Cost of any other misclassification. */
    protected double costNegative;
    /** Whether ADWIN-based drift detection is active. */
    protected boolean driftDetection;
    /** The ensemble members. */
    protected ArrayList<Classifier> ensemble;
    /** One ADWIN per member tracking its error rate; only allocated when drift detection is on. */
    protected ArrayList<ADWIN> adwinEnsemble;
    /** Per member: cost-weighted lambda mass of misclassifications other than "class 1 predicted as 0". */
    protected ArrayList<Double> lambdaFN;
    /** Per member: cost-weighted lambda mass of misclassifications where class 1 was predicted as 0. */
    protected ArrayList<Double> lambdaFP;
    /** Per member: total lambda mass offered to the member. */
    protected ArrayList<Double> lambdaSum;
    /** Per member: (uncosted) lambda mass of misclassified instances. */
    protected ArrayList<Double> lambdaSw;
    /** Per member: weighted error rate lambdaSw / lambdaSum, as of its last training step. */
    protected ArrayList<Double> epsilon;
    /** Per member: cost-weighted error rate (lambdaFP + lambdaFN) / lambdaSum, as of its last training step. */
    protected ArrayList<Double> wErr;

    @Override
    public String getPurposeString() {
        return "Online CSB2 is the online version of the ensemble learner CSB2";
    }

    /**
     * Re-reads all options, rebuilds the ensemble from a fresh copy of the base learner and clears
     * every per-member statistic.
     */
    @Override
    public void resetLearningImpl() {
        this.baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        this.baseLearner.resetLearning();
        this.nEstimators = this.ensembleSizeOption.getValue();
        this.costPositive = this.costPositiveOption.getValue();
        this.costNegative = this.costNegativeOption.getValue();
        this.driftDetection = !this.disableDriftDetectionOption.isSet();
        this.ensemble = new ArrayList<>();
        // adwinEnsemble is only allocated (and only ever dereferenced) when drift detection is on.
        if (this.driftDetection) {
            this.adwinEnsemble = new ArrayList<>();
        }
        this.lambdaFN = new ArrayList<>();
        this.lambdaFP = new ArrayList<>();
        this.lambdaSum = new ArrayList<>();
        this.lambdaSw = new ArrayList<>();
        this.epsilon = new ArrayList<>();
        this.wErr = new ArrayList<>();
        for (int i = 0; i < this.nEstimators; ++i) {
            this.addEnsembleMember();
        }
        this.classifierRandom = new Random(this.randomSeed);
    }

    /**
     * Trains every member on {@code instance} using online cost-sensitive boosting, then, if drift
     * detection is enabled and some member's error rate increased, resets the worst member.
     *
     * @param instance the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.ensemble.isEmpty()) {
            this.resetLearningImpl();
        }
        this.adjustEnsembleSize(instance.numClasses());
        double lambda = 1.0;
        boolean changeDetected = false;
        for (int i = 0; i < this.ensemble.size(); ++i) {
            Classifier member = this.ensemble.get(i);
            this.lambdaSum.set(i, this.lambdaSum.get(i) + lambda);
            double k = MiscUtils.poisson(lambda, this.classifierRandom);
            boolean trained = k > 0.0;
            for (int b = 0; b < k; ++b) {
                member.trainOnInstance(instance);
            }
            if (!trained && !this.driftDetection) {
                continue;
            }
            // One prediction (after this step's training) serves both the boosting update and
            // the drift detector.
            int predicted = Utils.maxIndex(member.getVotesForInstance(instance));
            if (trained) {
                lambda = this.updateBoostingStatistics(i, predicted, instance.classValue(), lambda);
            }
            if (this.driftDetection
                    && this.monitorMemberError(i, predicted != instance.classValue())) {
                changeDetected = true;
            }
        }
        if (changeDetected && this.driftDetection) {
            this.resetWorstMember();
        }
    }

    /**
     * Updates member {@code i}'s error statistics for one prediction and returns the Poisson rate
     * to hand to the next member.
     *
     * @param i         member index
     * @param predicted class index predicted by member {@code i}
     * @param trueClass the instance's true class value
     * @param lambda    the rate member {@code i} was trained with
     * @return the rate for the next member
     */
    private double updateBoostingStatistics(int i, int predicted, double trueClass, double lambda) {
        if (predicted == trueClass) {
            double eps = this.lambdaSw.get(i) / this.lambdaSum.get(i);
            double weightedErr = (this.lambdaFP.get(i) + this.lambdaFN.get(i)) / this.lambdaSum.get(i);
            this.epsilon.set(i, eps);
            this.wErr.set(i, weightedErr);
            // Guard against division by zero; otherwise lambda is passed on unchanged.
            if (eps + weightedErr != 0.0 && eps != 1.0) {
                return eps / ((1.0 - eps) * (eps + weightedErr));
            }
            return lambda;
        }
        double cost;
        if (predicted == 0 && trueClass == 1.0) {
            // Positive (class 1) sample misclassified as class 0.
            cost = this.costPositive;
            this.lambdaFP.set(i, this.lambdaFP.get(i) + cost * lambda);
        } else {
            // Any other misclassification: charged with the negative cost, consistently with the
            // lambda update below.
            cost = this.costNegative;
            this.lambdaFN.set(i, this.lambdaFN.get(i) + cost * lambda);
        }
        this.lambdaSw.set(i, this.lambdaSw.get(i) + lambda);
        double eps = this.lambdaSw.get(i) / this.lambdaSum.get(i);
        double weightedErr = (this.lambdaFP.get(i) + this.lambdaFN.get(i)) / this.lambdaSum.get(i);
        this.epsilon.set(i, eps);
        this.wErr.set(i, weightedErr);
        // eps and weightedErr are > 0 here: lambda was just added to both lambdaSw and a cost sum.
        return cost * lambda / (eps * weightedErr);
    }

    /**
     * Feeds member {@code i}'s ADWIN with 1 on a misclassification and 0 otherwise.
     *
     * @return {@code true} if ADWIN signalled a change and the estimated error rate went up
     */
    private boolean monitorMemberError(int i, boolean misclassified) {
        ADWIN detector = this.adwinEnsemble.get(i);
        double errorBefore = detector.getEstimation();
        boolean changed = detector.setInput(misclassified ? 1.0 : 0.0);
        return changed && detector.getEstimation() > errorBefore;
    }

    /** Resets the member whose ADWIN estimates the highest (strictly positive) error rate. */
    private void resetWorstMember() {
        double maxError = 0.0;
        int worst = -1;
        for (int i = 0; i < this.ensemble.size(); ++i) {
            double estimate = this.adwinEnsemble.get(i).getEstimation();
            if (maxError < estimate) {
                maxError = estimate;
                worst = i;
            }
        }
        if (worst != -1) {
            this.ensemble.get(worst).resetLearning();
            this.adwinEnsemble.set(worst, new ADWIN());
        }
    }

    /**
     * Combines the members' votes. Each non-empty vote is normalized to sum to 1 and then scaled
     * by the member's weight {@code log((1 - epsilon) / epsilon)}. Members whose weight is not
     * positive (weighted error of at least 0.5) are left out.
     *
     * @param instance the instance to classify (not modified; members see a copy)
     * @return the weighted, summed votes (empty if no member contributes)
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.size(); ++i) {
            DoubleVector vote = new DoubleVector(this.ensemble.get(i).getVotesForInstance(testInstance));
            if (!(vote.sumOfValues() > 0.0)) {
                continue;
            }
            double weight = this.memberWeight(i);
            if (!(weight > 0.0)) {
                continue;
            }
            // Normalize BEFORE weighting; normalizing afterwards would cancel the weight.
            vote.normalize();
            vote.scaleValues(weight);
            combinedVote.addValues(vote);
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Voting weight of member {@code i}: {@code log((1 - epsilon) / epsilon)} with epsilon clamped
     * to {@code [EPSILON_CLAMP, 1 - EPSILON_CLAMP]}, so the result is finite (NaN only if epsilon
     * itself is NaN).
     */
    private double memberWeight(int i) {
        double eps = Math.min(Math.max(this.epsilon.get(i), EPSILON_CLAMP), 1.0 - EPSILON_CLAMP);
        return Math.log((1.0 - eps) / eps);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder arg0, int arg1) {
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Grows the ensemble so that it has at least {@code nClasses} members; never shrinks it.
     *
     * @param nClasses number of classes of the current instance
     */
    protected void adjustEnsembleSize(int nClasses) {
        while (this.nEstimators < nClasses) {
            this.addEnsembleMember();
            ++this.nEstimators;
        }
    }

    /**
     * Appends a fresh copy of the base learner, a new ADWIN (when drift detection is on) and
     * zeroed statistics. Does not touch {@link #nEstimators}.
     */
    private void addEnsembleMember() {
        this.ensemble.add(this.baseLearner.copy());
        if (this.driftDetection) {
            this.adwinEnsemble.add(new ADWIN());
        }
        this.lambdaFP.add(0.0);
        this.lambdaFN.add(0.0);
        this.lambdaSum.add(0.0);
        this.lambdaSw.add(0.0);
        this.epsilon.add(0.0);
        this.wErr.add(0.0);
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (this.getClass() == OnlineCSB2.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
}

