/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta.imbalanced;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Online Under-Over Bagging for class-imbalanced data streams.
 *
 * <p>For every training instance, ensemble member {@code i} (0-based) is trained
 * {@code k ~ Poisson(lambda)} times, where {@code lambda = (i + 1) / nEstimators}.
 * For instances whose class value is {@code 1.0} that rate is multiplied by
 * {@code samplingRate}. Members with a small index therefore see fewer copies of each
 * instance (under-sampling), and class {@code 1.0} is drawn {@code samplingRate} times
 * more often than the other classes.
 *
 * <p>Unless disabled, each member has its own {@link ADWIN} detector fed with that
 * member's 0/1 error (1 = misclassified). When any detector signals a change and its
 * error estimate has gone up, the member with the highest estimated error is reset.
 */
public class OnlineUnderOverBagging
extends AbstractClassifier
implements MultiClassClassifier,
CapabilitiesHandler {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "meta.AdaptiveRandomForest");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The size of the ensemble.", 10, 1, Integer.MAX_VALUE);
    public IntOption samplingRateOption = new IntOption("samplingRate", 'i', "The sampling rate of the positive instances.", 2, 1, 10);
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'd', "Disable ADWIN drift detection.");

    /** Untrained prototype that every ensemble member is copied from. */
    protected Classifier baseLearner;
    /** Current number of members; may grow in {@link #adjustEnsembleSize(int)}. */
    protected int nEstimators;
    /** Multiplier applied to the Poisson rate for instances of class {@code 1.0}. */
    protected int samplingRate;
    /** Whether per-member ADWIN drift detection is active. */
    protected boolean driftDetection;
    /** Ensemble members; index {@code i} determines the member's sampling rate. */
    protected ArrayList<Classifier> ensemble;
    /** One error detector per member (same indexing); {@code null} when drift detection is off. */
    protected ArrayList<ADWIN> adwinEnsemble;

    @Override
    public String getPurposeString() {
        return "OnlineUnderOverBagging is the online version of the UnderOverBagging ensemble learner for class-imbalanced data streams, from S. Wang, L. L. Minku and X. Yao";
    }

    /**
     * Rebuilds the ensemble from the options: prepares and resets the base learner,
     * creates {@code ensembleSize} copies of it (plus one ADWIN per copy when drift
     * detection is enabled), and reseeds {@code classifierRandom}.
     */
    @Override
    public void resetLearningImpl() {
        this.baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        this.baseLearner.resetLearning();
        this.nEstimators = this.ensembleSizeOption.getValue();
        this.samplingRate = this.samplingRateOption.getValue();
        this.driftDetection = !this.disableDriftDetectionOption.isSet();
        this.ensemble = new ArrayList<>();
        if (this.driftDetection) {
            this.adwinEnsemble = new ArrayList<>();
        }
        for (int i = 0; i < this.nEstimators; ++i) {
            this.ensemble.add(this.baseLearner.copy());
            if (this.driftDetection) {
                this.adwinEnsemble.add(new ADWIN());
            }
        }
        this.classifierRandom = new Random(this.randomSeed);
    }

    /**
     * Trains every member on a Poisson-resampled number of copies of {@code instance}.
     * When drift detection is on, it then updates each member's ADWIN detector and
     * resets the worst member if a drift was signalled.
     *
     * @param instance the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.ensemble == null || this.ensemble.isEmpty()) {
            this.resetLearningImpl();
        }
        this.adjustEnsembleSize(instance.numClasses());
        boolean changeDetected = false;
        for (int i = 0; i < this.ensemble.size(); ++i) {
            Classifier member = this.ensemble.get(i);
            int k = MiscUtils.poisson(this.samplingLambda(i, instance), this.classifierRandom);
            for (int b = 0; b < k; ++b) {
                member.trainOnInstance(instance);
            }
            // Every detector must see every instance, so update before combining the flag.
            if (this.driftDetection && this.updateDriftDetector(i, instance)) {
                changeDetected = true;
            }
        }
        if (changeDetected) {
            this.resetWorstMember();
        }
    }

    /**
     * Returns the Poisson rate used to resample {@code instance} for one member.
     * The rate is {@code (memberIndex + 1) / nEstimators}, multiplied by
     * {@code samplingRate} when the class value is {@code 1.0}.
     */
    private double samplingLambda(int memberIndex, Instance instance) {
        double a = (double) (memberIndex + 1) / (double) this.nEstimators;
        return instance.classValue() == 1.0 ? a * this.samplingRate : a;
    }

    /**
     * Feeds the 0/1 error of member {@code memberIndex} on {@code instance} to that
     * member's ADWIN detector.
     *
     * @return {@code true} if ADWIN signalled a change and the error estimate increased
     */
    private boolean updateDriftDetector(int memberIndex, Instance instance) {
        ADWIN detector = this.adwinEnsemble.get(memberIndex);
        int predicted = Utils.maxIndex(this.ensemble.get(memberIndex).getVotesForInstance(instance));
        double errorBefore = detector.getEstimation();
        // 1.0 = misclassified, so the ADWIN estimate is this member's error rate
        // (the same convention OzaBagAdwin uses).
        double error = predicted == instance.classValue() ? 0.0 : 1.0;
        return detector.setInput(error) && detector.getEstimation() > errorBefore;
    }

    /**
     * Resets the member whose detector reports the highest error estimate, and gives it
     * a fresh detector. Does nothing if no estimate is above zero.
     */
    private void resetWorstMember() {
        double maxError = 0.0;
        int worst = -1;
        for (int i = 0; i < this.ensemble.size(); ++i) {
            double estimation = this.adwinEnsemble.get(i).getEstimation();
            if (maxError < estimation) {
                maxError = estimation;
                worst = i;
            }
        }
        if (worst != -1) {
            this.ensemble.get(worst).resetLearning();
            this.adwinEnsemble.set(worst, new ADWIN());
        }
    }

    /**
     * Sums the members' normalized vote vectors. Members whose votes sum to zero or
     * less are skipped. Members vote on a copy of {@code instance}.
     *
     * @param instance the instance to classify
     * @return the combined (unnormalized) vote array
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        DoubleVector combinedVote = new DoubleVector();
        for (Classifier member : this.ensemble) {
            DoubleVector vote = new DoubleVector(member.getVotesForInstance(testInstance));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // No model description is produced for this ensemble.
    }

    /** This learner reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /**
     * Grows the ensemble so it has at least {@code nClasses} members. Each new member is
     * a fresh copy of the base learner, with a new ADWIN detector when drift detection
     * is on. Never shrinks the ensemble.
     *
     * @param nClasses the number of classes of the incoming instances
     */
    protected void adjustEnsembleSize(int nClasses) {
        while (this.nEstimators < nClasses) {
            this.ensemble.add(this.baseLearner.copy());
            if (this.driftDetection) {
                this.adwinEnsemble.add(new ADWIN());
            }
            ++this.nEstimators;
        }
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (this.getClass() == OnlineUnderOverBagging.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
}

