/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers;

import java.util.LinkedList;
import moa.AbstractMOAObject;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.WEKAClassifier;
import moa.core.Measurement;
import moa.options.ClassOption;
import moa.options.MultiChoiceOption;
import weka.core.Instance;
import weka.core.Utils;

/**
 * Wraps a single base learner with a concept-drift detector (DDM or EDDM, selected by
 * {@link #driftDetectionMethodOption}).
 *
 * <p>For every training instance the current model is evaluated first. Whether its prediction
 * was correct is fed to the detector:
 * <ul>
 *   <li>warning level: a background learner ({@link #newclassifier}) is trained alongside the
 *       current model;</li>
 *   <li>change level: the background learner replaces the current model, and a fresh background
 *       learner is prepared from {@link #baseLearnerOption};</li>
 *   <li>in-control level: the background learner is marked for restart at the next warning.</li>
 * </ul>
 * The current model is always trained on the instance afterwards.
 */
public class SingleClassifierDrift
extends AbstractClassifier {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "NaiveBayes");
    public MultiChoiceOption driftDetectionMethodOption = new MultiChoiceOption("driftDetectionMethod", 'd', "Drift detection method to use.", new String[]{"DDM", "EDDM"}, new String[]{"DDM: Joao Gama Drift Detection Method", "EDDM: Early Drift Detection Method"}, 0);
    /** Model used for prediction and trained on every instance. */
    protected Classifier classifier;
    /** Background learner trained during warning periods; promoted to {@link #classifier} on change. */
    protected Classifier newclassifier;
    /** Detector fed with the current model's per-instance correctness. */
    protected DriftDetectionMethod driftDetectionMethod;
    /** Set once the detector is back in control; the next warning then restarts {@link #newclassifier}. */
    protected boolean newClassifierReset;
    /** Level returned by the detector for the most recent training instance. */
    protected int ddmLevel;
    /** Change signals seen since the last call to {@link #getModelMeasurementsImpl()}. */
    protected int changeDetected = 0;
    /** Warning signals seen since the last call to {@link #getModelMeasurementsImpl()}. */
    protected int warningDetected = 0;

    /**
     * Reports whether the detector was at warning level for the last training instance.
     *
     * @return {@code true} if the last detector level was the warning level
     */
    public boolean isWarningDetected() {
        return this.ddmLevel == DriftDetectionMethod.DDM_WARNING_LEVEL;
    }

    /**
     * Reports whether the detector signalled a change for the last training instance.
     *
     * @return {@code true} if the last detector level was the out-of-control level
     */
    public boolean isChangeDetected() {
        return this.ddmLevel == DriftDetectionMethod.DDM_OUTCONTROL_LEVEL;
    }

    /**
     * Resets the learner.
     *
     * <p>Prepares the base learner from {@link #baseLearnerOption}, copies it as the background
     * learner, resets both, and creates a new detector.
     */
    public void resetLearningImpl() {
        this.classifier = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
        this.newclassifier = this.classifier.copy();
        this.classifier.resetLearning();
        this.newclassifier.resetLearning();
        this.driftDetectionMethod = this.newDriftDetectionMethod();
        this.newClassifierReset = false;
    }

    /**
     * Evaluates the current model on {@code inst}, updates the detector, reacts to its level,
     * and finally trains the current model on {@code inst}.
     *
     * @param inst the labelled training instance
     */
    public void trainOnInstanceImpl(Instance inst) {
        int trueClass = (int)inst.classValue();
        // The detector monitors the current model's error stream, so evaluate before training.
        boolean prediction = Utils.maxIndex((double[])this.classifier.getVotesForInstance(inst)) == trueClass;
        this.ddmLevel = this.driftDetectionMethod.computeNextVal(prediction);
        switch (this.ddmLevel) {
            case DriftDetectionMethod.DDM_WARNING_LEVEL: {
                ++this.warningDetected;
                // A warning after an in-control period starts a fresh background learner.
                if (this.newClassifierReset) {
                    this.newclassifier.resetLearning();
                    this.newClassifierReset = false;
                }
                this.newclassifier.trainOnInstance(inst);
                break;
            }
            case DriftDetectionMethod.DDM_OUTCONTROL_LEVEL: {
                ++this.changeDetected;
                // NOTE(review): if the detector jumps straight from in-control to change,
                // newclassifier still holds what it learned in an earlier, already-ended warning
                // period (it is only reset when the next warning starts). Confirm that promoting
                // it in that case is intended.
                this.classifier = this.newclassifier;
                // WEKA-backed learners get an explicit build before serving predictions —
                // presumably they buffer training data; confirm against WEKAClassifier.
                if (this.classifier instanceof WEKAClassifier) {
                    ((WEKAClassifier)this.classifier).buildClassifier();
                }
                this.newclassifier = ((Classifier)this.getPreparedClassOption(this.baseLearnerOption)).copy();
                this.newclassifier.resetLearning();
                break;
            }
            case DriftDetectionMethod.DDM_INCONTROL_LEVEL: {
                this.newClassifierReset = true;
                break;
            }
            default: {
                // Any other level leaves both learners untouched.
                break;
            }
        }
        this.classifier.trainOnInstance(inst);
    }

    /**
     * Delegates prediction to the current model.
     *
     * @param inst the instance to classify
     * @return the current model's class votes
     */
    public double[] getVotesForInstance(Instance inst) {
        return this.classifier.getVotesForInstance(inst);
    }

    /**
     * Always returns {@code true}.
     *
     * <p>NOTE(review): no random seed is forwarded to the base learner within this class;
     * confirm that {@code true} is intended.
     *
     * @return {@code true}
     */
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Appends the current model's description.
     *
     * <p>The base-learner option accepts any {@link Classifier}, so the
     * {@code AbstractClassifier}-specific description is only used when available. Other learners
     * are described by their class name instead of failing with a {@code ClassCastException}.
     *
     * @param out    destination buffer
     * @param indent indentation level passed through to the base learner
     */
    public void getModelDescription(StringBuilder out, int indent) {
        if (this.classifier instanceof AbstractClassifier) {
            ((AbstractClassifier)this.classifier).getModelDescription(out, indent);
        } else if (this.classifier != null) {
            out.append(this.classifier.getClass().getName());
        }
    }

    /**
     * Returns the change and warning counts since the previous call, followed by the current
     * model's own measurements (when it is an {@code AbstractClassifier} and reports any).
     *
     * <p>Side effect: both counters are reset to zero, so each call reports one interval.
     *
     * @return the measurements, never {@code null}
     */
    protected Measurement[] getModelMeasurementsImpl() {
        LinkedList<Measurement> measurementList = new LinkedList<Measurement>();
        measurementList.add(new Measurement("Change detected", this.changeDetected));
        measurementList.add(new Measurement("Warning detected", this.warningDetected));
        // Guard the cast: the base learner may be any Classifier implementation.
        if (this.classifier instanceof AbstractClassifier) {
            Measurement[] modelMeasurements = ((AbstractClassifier)this.classifier).getModelMeasurementsImpl();
            if (modelMeasurements != null) {
                for (Measurement measurement : modelMeasurements) {
                    measurementList.add(measurement);
                }
            }
        }
        this.changeDetected = 0;
        this.warningDetected = 0;
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }

    /**
     * Creates the detector selected by {@link #driftDetectionMethodOption}.
     *
     * <p>Index 0 ("DDM") maps to {@link JGamaMethod} and index 1 ("EDDM") to {@link EDDM}. Any
     * other index yields the no-op {@link DriftDetectionMethod}.
     *
     * @return a new detector instance
     */
    protected DriftDetectionMethod newDriftDetectionMethod() {
        switch (this.driftDetectionMethodOption.getChosenIndex()) {
            case 0: {
                return new JGamaMethod();
            }
            case 1: {
                return new EDDM();
            }
        }
        return new DriftDetectionMethod();
    }

    /**
     * Early Drift Detection Method.
     *
     * <p>Tracks the running mean and standard deviation of the distance, in instances, between
     * consecutive errors. It compares {@code mean + 2 * std} against the largest such value seen
     * once more than {@code FDDM_MINNUMINSTANCES} instances have been processed.
     */
    public class EDDM
    extends DriftDetectionMethod {
        private static final double FDDM_OUTCONTROL = 0.9;
        private static final double FDDM_WARNING = 0.95;
        private static final double FDDM_MINNUMINSTANCES = 30.0;
        private double m_numErrors;
        private int m_minNumErrors;
        private int m_n;
        private int m_d;
        private int m_lastd;
        private double m_mean;
        private double m_stdTemp;
        private double m_m2smax;
        private int m_lastLevel;

        public EDDM() {
            this.m_minNumErrors = 30;
            this.initialize();
        }

        /** Clears all running statistics; also called after a change is signalled. */
        private void initialize() {
            this.m_n = 1;
            this.m_numErrors = 0.0;
            this.m_d = 0;
            this.m_lastd = 0;
            this.m_mean = 0.0;
            this.m_stdTemp = 0.0;
            this.m_m2smax = 0.0;
            this.m_lastLevel = DDM_INCONTROL_LEVEL;
        }

        /**
         * Feeds one prediction outcome to the detector.
         *
         * @param prediction {@code true} if the monitored model classified the instance correctly
         * @return the detector level; a correct prediction repeats the previous level
         */
        @Override
        public int computeNextVal(boolean prediction) {
            ++this.m_n;
            if (prediction) {
                // Only errors produce a new distance sample.
                return this.m_lastLevel;
            }
            this.m_numErrors += 1.0;
            this.m_lastd = this.m_d;
            this.m_d = this.m_n - 1;
            int distance = this.m_d - this.m_lastd;
            // Incremental (Welford-style) update of the mean and sum of squared deviations.
            double oldmean = this.m_mean;
            this.m_mean += ((double)distance - this.m_mean) / this.m_numErrors;
            this.m_stdTemp += ((double)distance - this.m_mean) * ((double)distance - oldmean);
            double std = Math.sqrt(this.m_stdTemp / this.m_numErrors);
            double m2s = this.m_mean + 2.0 * std;
            if (m2s > this.m_m2smax) {
                // The maximum is only recorded after the warm-up period.
                if (this.m_n > FDDM_MINNUMINSTANCES) {
                    this.m_m2smax = m2s;
                }
                this.m_lastLevel = DDM_INCONTROL_LEVEL;
                return this.m_lastLevel;
            }
            // Ratio of the current spread of error distances to its recorded maximum.
            double p = m2s / this.m_m2smax;
            boolean enoughEvidence = this.m_n > FDDM_MINNUMINSTANCES && this.m_numErrors > this.m_minNumErrors;
            if (enoughEvidence && p < FDDM_OUTCONTROL) {
                this.initialize();
                return DDM_OUTCONTROL_LEVEL;
            }
            if (enoughEvidence && p < FDDM_WARNING) {
                this.m_lastLevel = DDM_WARNING_LEVEL;
                return DDM_WARNING_LEVEL;
            }
            this.m_lastLevel = DDM_INCONTROL_LEVEL;
            return DDM_INCONTROL_LEVEL;
        }
    }

    /**
     * Drift Detection Method (Gama et al.).
     *
     * <p>Tracks the running error rate {@code p} and its standard deviation {@code s}, remembers
     * the point where {@code p + s} was smallest, and signals warning or change when
     * {@code p + s} exceeds that minimum by 2 or 3 times its {@code s}.
     */
    public class JGamaMethod
    extends DriftDetectionMethod {
        private static final int JGAMAMETHOD_MINNUMINST = 30;
        /** Multiple of the minimum's standard deviation that triggers a warning. */
        private static final double WARNING_FACTOR = 2.0;
        /** Multiple of the minimum's standard deviation that triggers a change. */
        private static final double DRIFT_FACTOR = 3.0;
        private int m_n;
        private double m_p;
        private double m_s;
        private double m_psmin;
        private double m_pmin;
        private double m_smin;

        public JGamaMethod() {
            this.initialize();
        }

        /** Clears all running statistics; also called after a change is signalled. */
        private void initialize() {
            this.m_n = 1;
            this.m_p = 1.0;
            this.m_s = 0.0;
            this.m_psmin = Double.MAX_VALUE;
            this.m_pmin = Double.MAX_VALUE;
            this.m_smin = Double.MAX_VALUE;
        }

        /**
         * Feeds one prediction outcome to the detector.
         *
         * @param prediction {@code true} if the monitored model classified the instance correctly
         * @return the detector level; always in-control during the warm-up period
         */
        @Override
        public int computeNextVal(boolean prediction) {
            // Incremental mean of the error indicator over the m_n instances seen so far.
            if (prediction) {
                this.m_p -= this.m_p / (double)this.m_n;
            } else {
                this.m_p += (1.0 - this.m_p) / (double)this.m_n;
            }
            this.m_s = Math.sqrt(this.m_p * (1.0 - this.m_p) / (double)this.m_n);
            ++this.m_n;
            if (this.m_n < JGAMAMETHOD_MINNUMINST) {
                return DDM_INCONTROL_LEVEL;
            }
            if (this.m_p + this.m_s <= this.m_psmin) {
                this.m_pmin = this.m_p;
                this.m_smin = this.m_s;
                this.m_psmin = this.m_p + this.m_s;
            }
            if (this.m_n > JGAMAMETHOD_MINNUMINST && this.m_p + this.m_s > this.m_pmin + DRIFT_FACTOR * this.m_smin) {
                this.initialize();
                return DDM_OUTCONTROL_LEVEL;
            }
            if (this.m_p + this.m_s > this.m_pmin + WARNING_FACTOR * this.m_smin) {
                return DDM_WARNING_LEVEL;
            }
            return DDM_INCONTROL_LEVEL;
        }
    }

    /**
     * Base detector type and level constants.
     *
     * <p>This default implementation never signals anything: it always reports
     * {@link #DDM_INCONTROL_LEVEL}.
     */
    public class DriftDetectionMethod
    extends AbstractMOAObject {
        private static final long serialVersionUID = 1L;
        public static final int DDM_INCONTROL_LEVEL = 0;
        public static final int DDM_WARNING_LEVEL = 1;
        public static final int DDM_OUTCONTROL_LEVEL = 2;

        /**
         * Feeds one prediction outcome to the detector.
         *
         * @param prediction {@code true} if the monitored model classified the instance correctly
         * @return the detector level (this base implementation: always in-control)
         */
        public int computeNextVal(boolean prediction) {
            return DDM_INCONTROL_LEVEL;
        }

        public void getDescription(StringBuilder sb, int indent) {
        }
    }
}

