/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta.imbalanced;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Online RUSBoost: an online boosting ensemble for class-imbalanced streams that
 * combines Poisson-based online boosting with random undersampling (RUS).
 *
 * <p>Instances whose class value is {@code 1.0} are treated as positive; every other
 * class value is counted as negative. For each instance, member {@code i} is trained
 * {@code k ~ Poisson(lambdaRus)} times, where {@code lambdaRus} is the current boosting
 * weight rescaled by the selected undersampling strategy. Optionally, one ADWIN detector
 * per member tracks that member's error rate; when a detector reports an increase, the
 * member with the highest estimated error is reset.
 */
public class OnlineRUSBoost
extends AbstractClassifier
implements MultiClassClassifier,
CapabilitiesHandler {
    private static final long serialVersionUID = 1L;

    /**
     * Lower bound applied to a member's error before computing its vote weight, so a
     * member with zero observed error gets a large but finite weight instead of an
     * infinite one (an infinite weight turns the combined vote into NaN).
     */
    private static final double MIN_MEMBER_ERROR = 1.0E-10;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "meta.AdaptiveRandomForest");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The size of the ensemble.", 10, 1, Integer.MAX_VALUE);
    public IntOption samplingRateOption = new IntOption("samplingRate", 'i', "The sampling rate of the positive instances.", 3, 1, 10);
    public MultiChoiceOption algorithmImplementationOption = new MultiChoiceOption("algorithmImplementation", 'a', "The implementation of RUSBoost to use.", new String[]{"Fixed class ration", "Fixed example distribution", "Fixed sampling rate"}, new String[]{"ClassRation", "ExampleDistribution", "SamplingRate"}, 0);
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'd', "Disable ADWIN drift detection.");
    protected Classifier baseLearner;
    protected int nEstimators;
    protected int samplingRate;
    /** Chosen strategy index: 0 = fixed class ratio, 1 = fixed example distribution, 2 = fixed sampling rate. */
    protected int algorithmImplementation;
    protected boolean driftDetection;
    protected ArrayList<Classifier> ensemble;
    /** One ADWIN per member fed with 1.0 on misclassification; only allocated when drift detection is on. */
    protected ArrayList<ADWIN> adwinEnsemble;
    /** Per member: cumulative boosting weight of instances it classified correctly. */
    protected ArrayList<Double> lambdaSc;
    /** Per member: cumulative boosting weight of positive instances. */
    protected ArrayList<Double> lambdaPos;
    /** Per member: cumulative boosting weight of negative instances. */
    protected ArrayList<Double> lambdaNeg;
    /** Per member: cumulative boosting weight of instances it misclassified. */
    protected ArrayList<Double> lambdaSw;
    /** Per member: weighted error lambdaSw / (lambdaSc + lambdaSw). */
    protected ArrayList<Double> epsilon;
    /** Number of positive instances seen so far (one increment per instance). */
    protected double nPositive;
    /** Number of negative instances seen so far (one increment per instance). */
    protected double nNegative;

    @Override
    public String getPurposeString() {
        return "Online RUSBoost is the adaptation of the ensemble learner to data streams.";
    }

    /**
     * Rebuilds the ensemble from the options and clears all per-member statistics
     * and class counts.
     */
    @Override
    public void resetLearningImpl() {
        this.baseLearner = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
        this.baseLearner.resetLearning();
        this.nEstimators = this.ensembleSizeOption.getValue();
        this.samplingRate = this.samplingRateOption.getValue();
        this.algorithmImplementation = this.algorithmImplementationOption.getChosenIndex();
        this.driftDetection = !this.disableDriftDetectionOption.isSet();
        this.ensemble = new ArrayList<>();
        if (this.driftDetection) {
            this.adwinEnsemble = new ArrayList<>();
        }
        this.lambdaSc = new ArrayList<>();
        this.lambdaPos = new ArrayList<>();
        this.lambdaNeg = new ArrayList<>();
        this.lambdaSw = new ArrayList<>();
        this.epsilon = new ArrayList<>();
        for (int i = 0; i < this.nEstimators; ++i) {
            this.addMember();
        }
        this.nPositive = 0.0;
        this.nNegative = 0.0;
        this.classifierRandom = new Random(this.randomSeed);
    }

    /**
     * Trains every member sequentially on {@code instance} using online boosting with
     * random undersampling, then optionally resets the worst member after drift.
     *
     * @param instance the labelled training instance (class 1.0 = positive)
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.ensemble.isEmpty()) {
            this.resetLearningImpl();
        }
        this.adjustEnsembleSize(instance.numClasses());
        boolean positive = instance.classValue() == 1.0;
        // Count each instance once, so these really are the numbers of positive and
        // negative examples seen (not multiplied by the ensemble size).
        if (positive) {
            this.nPositive += 1.0;
        } else {
            this.nNegative += 1.0;
        }
        double lambda = 1.0;
        boolean changeDetected = false;
        for (int i = 0; i < this.ensemble.size(); ++i) {
            if (positive) {
                this.lambdaPos.set(i, this.lambdaPos.get(i) + lambda);
            } else {
                this.lambdaNeg.set(i, this.lambdaNeg.get(i) + lambda);
            }
            double lambdaRus = this.computeSamplingLambda(i, positive, lambda);
            int k = MiscUtils.poisson(lambdaRus, this.classifierRandom);
            for (int b = 0; b < k; ++b) {
                this.ensemble.get(i).trainOnInstance(instance);
            }
            // Single prediction per member, reused for both the error and drift updates.
            double prediction = Utils.maxIndex(this.ensemble.get(i).getVotesForInstance(instance));
            boolean correct = prediction == instance.classValue();
            lambda = this.updateMemberError(i, correct, lambda);
            if (this.driftDetection && this.updateDriftDetector(i, correct)) {
                changeDetected = true;
            }
        }
        if (changeDetected && this.driftDetection) {
            this.resetWorstMember();
        }
    }

    /**
     * Computes the Poisson rate used to train member {@code i} on the current instance,
     * according to the selected undersampling strategy.
     *
     * @param i        member index
     * @param positive whether the current instance is positive
     * @param lambda   current boosting weight of the instance for this member
     * @return the sampling rate; 1.0 when the strategy's counts are not yet defined
     */
    private double computeSamplingLambda(int i, boolean positive, double lambda) {
        double pos = this.lambdaPos.get(i);
        double neg = this.lambdaNeg.get(i);
        double total = this.nPositive + this.nNegative;
        if (this.algorithmImplementation == 0) {
            // NOTE(review): falls back to 1.0 (not lambda) until both classes were seen — confirm intent.
            if (positive) {
                if (this.nNegative != 0.0) {
                    return lambda * ((pos + neg) / (pos + neg * ((double)this.samplingRate * (this.nPositive / this.nNegative))) * ((double)(this.samplingRate + 1) * this.nPositive / total));
                }
            } else if (this.nPositive != 0.0) {
                return lambda * ((pos + neg) / (pos + neg * (this.nNegative / (this.nPositive * (double)this.samplingRate))) * ((double)(this.samplingRate + 1) * this.nPositive / total));
            }
            return 1.0;
        }
        if (this.algorithmImplementation == 1) {
            return positive
                    ? lambda * this.nPositive / total / (pos / (pos + neg))
                    : lambda * (double)this.samplingRate * this.nPositive / total / (neg / (pos + neg));
        }
        if (this.algorithmImplementation == 2) {
            // Negatives are undersampled by the configured rate.
            return positive ? lambda : lambda / (double)this.samplingRate;
        }
        return 1.0;
    }

    /**
     * Adds {@code lambda} to the member's correct/wrong totals, recomputes its weighted
     * error and returns the boosting weight passed on to the next member.
     */
    private double updateMemberError(int i, boolean correct, double lambda) {
        if (correct) {
            this.lambdaSc.set(i, this.lambdaSc.get(i) + lambda);
        } else {
            this.lambdaSw.set(i, this.lambdaSw.get(i) + lambda);
        }
        double error = this.lambdaSw.get(i) / (this.lambdaSc.get(i) + this.lambdaSw.get(i));
        this.epsilon.set(i, error);
        if (correct) {
            return error != 1.0 ? lambda / (2.0 * (1.0 - error)) : lambda;
        }
        return error != 0.0 ? lambda / (2.0 * error) : lambda;
    }

    /**
     * Feeds the member's outcome to its ADWIN detector.
     *
     * @return true when ADWIN shrank its window and the estimated error went up
     */
    private boolean updateDriftDetector(int i, boolean correct) {
        ADWIN detector = this.adwinEnsemble.get(i);
        double errorBefore = detector.getEstimation();
        // 1.0 marks a misclassification, so the ADWIN estimate is the member's error rate.
        boolean windowChanged = detector.setInput(correct ? 0.0 : 1.0);
        return windowChanged && detector.getEstimation() > errorBefore;
    }

    /** Resets the member whose ADWIN reports the highest error, together with its detector. */
    private void resetWorstMember() {
        double maxError = 0.0;
        int worst = -1;
        for (int i = 0; i < this.ensemble.size(); ++i) {
            double error = this.adwinEnsemble.get(i).getEstimation();
            if (maxError < error) {
                maxError = error;
                worst = i;
            }
        }
        if (worst != -1) {
            this.ensemble.get(worst).resetLearning();
            this.adwinEnsemble.set(worst, new ADWIN());
        }
    }

    /**
     * Combines the members' votes. Each member's vote is normalized and then scaled by
     * {@code log((1 - eps) / eps)}. Members with no vote mass, or with a non-positive
     * weight (error of at least 0.5), are left out.
     *
     * @param instance the instance to classify (a copy is passed to the members)
     * @return the combined, unnormalized class votes
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.size(); ++i) {
            DoubleVector vote = new DoubleVector(this.ensemble.get(i).getVotesForInstance(testInstance));
            if (!(vote.sumOfValues() > 0.0)) {
                continue;
            }
            double weight = this.memberWeight(i);
            if (!(weight > 0.0)) {
                continue;
            }
            // Normalize first, then weight: scaling before normalizing would cancel the weight.
            vote.normalize();
            vote.scaleValues(weight);
            combinedVote.addValues(vote);
        }
        return combinedVote.getArrayRef();
    }

    /** Boosting weight log((1 - eps) / eps) of member {@code i}, with eps floored at MIN_MEMBER_ERROR. */
    private double memberWeight(int i) {
        double error = Math.max(this.epsilon.get(i), MIN_MEMBER_ERROR);
        return Math.log((1.0 - error) / error);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder arg0, int arg1) {
    }

    /** No model-specific measurements; returns an empty array rather than null. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /**
     * Grows the ensemble so it has at least {@code nClasses} members.
     *
     * @param nClasses number of classes of the incoming instance
     */
    protected void adjustEnsembleSize(int nClasses) {
        while (this.nEstimators < nClasses) {
            this.addMember();
            ++this.nEstimators;
        }
    }

    /** Appends a fresh copy of the base learner with zeroed statistics (and an ADWIN if enabled). */
    private void addMember() {
        this.ensemble.add(this.baseLearner.copy());
        if (this.driftDetection) {
            this.adwinEnsemble.add(new ADWIN());
        }
        this.lambdaSc.add(0.0);
        this.lambdaPos.add(0.0);
        this.lambdaNeg.add(0.0);
        this.lambdaSw.add(0.0);
        this.epsilon.add(0.0);
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (this.getClass() == OnlineRUSBoost.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
}

