/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.core.DoubleVector;
import moa.core.Example;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.evaluation.BasicClassificationPerformanceEvaluator;
import moa.options.ClassOption;

/**
 * Streaming Random Patches (SRP) ensemble classifier.
 *
 * <p>Each ensemble member is trained either on a random subset of the input features
 * ("Random Subspaces"), on a Poisson(lambda)-weighted resample of the stream
 * ("Resampling (bagging)"), or on both ("Random Patches"). Unless disabled, every member feeds
 * its own correct/incorrect signal into a warning detector (which starts a background learner)
 * and a drift detector (which promotes the background learner, or resets the member).
 * Predictions are the sum of the members' normalized votes, optionally weighted by each
 * member's running accuracy.
 */
public class StreamingRandomPatches extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
    private static final long serialVersionUID = 1L;

    /** Training method index: members train on a feature subspace with weight 1. */
    public static final int TRAIN_RANDOM_SUBSPACES = 0;
    /** Training method index: members train on all features with Poisson(lambda) weights. */
    public static final int TRAIN_RESAMPLING = 1;
    /** Training method index: feature subspaces combined with Poisson(lambda) weights. */
    public static final int TRAIN_RANDOM_PATCHES = 2;

    /** Subspace mode index: m is used as given. */
    protected static final int FEATURES_M = 0;
    /** Subspace mode index: round(sqrt(M)) + 1 features. */
    protected static final int FEATURES_SQRT = 1;
    /** Subspace mode index: M - round(sqrt(M) + 1) features. */
    protected static final int FEATURES_SQRT_INV = 2;
    /** Subspace mode index: round(M * m / 100) features (negative m means (100 + m) percent). */
    protected static final int FEATURES_PERCENT = 3;

    /** Up to this many input features, subspaces come from the exhaustive list of k-combinations. */
    private static final int MAX_FEATURES_FOR_EXHAUSTIVE_SUBSPACES = 20;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train on instances.", Classifier.class, "trees.HoeffdingTree -g 50 -c 0.01");

    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
            "The number of models.", 100, 1, Integer.MAX_VALUE);

    public MultiChoiceOption subspaceModeOption = new MultiChoiceOption("subspaceMode", 'o',
            "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.",
            new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", "Percentage (M * (m / 100))"},
            new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, FEATURES_PERCENT);

    public IntOption subspaceSizeOption = new IntOption("subspaceSize", 'm',
            "# attributes per subset for each classifier. Negative values = totalAttributes - #attributes",
            60, Integer.MIN_VALUE, Integer.MAX_VALUE);

    public MultiChoiceOption trainingMethodOption = new MultiChoiceOption("trainingMethod", 't',
            "The training method to use: Random Patches, Random Subspaces or Bagging.",
            new String[]{"Random Subspaces", "Resampling (bagging)", "Random Patches"},
            new String[]{"RandomSubspaces", "Resampling", "RandomPatches"}, TRAIN_RANDOM_PATCHES);

    // Upper bound is Float.MAX_VALUE widened to double (3.4028234663852886E38).
    public FloatOption lambdaOption = new FloatOption("lambda", 'a',
            "The lambda parameter for bagging.", 6.0, 1.0, Float.MAX_VALUE);

    public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x',
            "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5");

    public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p',
            "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4");

    public FlagOption disableWeightedVote = new FlagOption("disableWeightedVote", 'w',
            "Should use weighted voting?");

    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'u',
            "Should use drift detection? If disabled, then the bkg learner is also disabled.");

    public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q',
            "Should use bkg learner? If disabled, then trees are reset immediately.");

    /** Ensemble members; {@code null} until the first instance is seen (lazy initialization). */
    protected StreamingRandomPatchesClassifier[] ensemble;
    /** Number of training instances processed since the last reset. */
    protected long instancesSeen;
    /**
     * Pool of candidate feature combinations (positions among the non-class features) that have
     * not been assigned to a member yet; {@code null} when subspaces are not used.
     */
    protected ArrayList<ArrayList<Integer>> subspaces;

    /**
     * Forgets everything learned so far; the ensemble is rebuilt lazily from the next instance.
     */
    @Override
    public void resetLearningImpl() {
        this.instancesSeen = 0L;
        this.ensemble = null;
        this.subspaces = null;
    }

    /**
     * Scores every member on {@code instance} (feeding its accuracy evaluator) and then trains it,
     * with weight 1 for Random Subspaces or a Poisson(lambda) weight otherwise.
     *
     * @param instance the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        ++this.instancesSeen;
        if (this.ensemble == null) {
            this.initEnsemble(instance);
        }
        for (StreamingRandomPatchesClassifier member : this.ensemble) {
            // Record the prediction made before this instance is learned; the evaluator's accuracy
            // later serves as the member's voting weight.
            double[] rawVote = member.getVotesForInstance(instance);
            member.evaluator.addResult(new InstanceExample(instance), rawVote);
            if (this.trainingMethodOption.getChosenIndex() == TRAIN_RANDOM_SUBSPACES) {
                member.trainOnInstance(instance, 1.0, this.instancesSeen, this.classifierRandom);
                continue;
            }
            int k = MiscUtils.poisson(this.lambdaOption.getValue(), this.classifierRandom);
            if (k > 0) {
                member.trainOnInstance(instance, k, this.instancesSeen, this.classifierRandom);
            }
        }
    }

    /**
     * Combines the members' votes: each non-empty vote is normalized and, unless weighted voting
     * is disabled, scaled by the member's accuracy measurement before being summed.
     *
     * @param instance the instance to classify (not modified; a copy with the class cleared is used)
     * @return the combined, unnormalized class votes
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        testInstance.setMissing(instance.classAttribute());
        testInstance.setClassValue(0.0);
        if (this.ensemble == null) {
            this.initEnsemble(testInstance);
        }
        DoubleVector combinedVote = new DoubleVector();
        for (StreamingRandomPatchesClassifier member : this.ensemble) {
            DoubleVector vote = new DoubleVector(member.getVotesForInstance(testInstance));
            // Written as !(x > 0) so that NaN sums are skipped too.
            if (!(vote.sumOfValues() > 0.0)) {
                continue;
            }
            vote.normalize();
            // NOTE(review): index 1 is presumably the evaluator's "classifications correct" measure.
            double acc = member.evaluator.getPerformanceMeasurements()[1].getValue();
            if (!this.disableWeightedVote.isSet() && acc > 0.0) {
                for (int v = 0; v < vote.numValues(); ++v) {
                    vote.setValue(v, vote.getValue(v) * acc);
                }
            }
            combinedVote.addValues(vote);
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /** No model-specific measurements are reported. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Builds the ensemble from the header of {@code instance}: resolves the subspace size, draws
     * the candidate subspaces (or falls back to plain resampling when no proper subspace exists)
     * and creates one member per ensemble slot.
     *
     * @param instance any instance of the stream; only its header (attributes, class index) is used
     */
    protected void initEnsemble(Instance instance) {
        int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new StreamingRandomPatchesClassifier[ensembleSize];
        BasicClassificationPerformanceEvaluator classificationEvaluator = new BasicClassificationPerformanceEvaluator();

        if (this.trainingMethodOption.getChosenIndex() != TRAIN_RESAMPLING) {
            int n = instance.numAttributes() - 1; // number of input features (class excluded)
            int k = this.resolveSubspaceSize(n);
            if (k > 0 && k < n) {
                this.subspaces = generateSubspaces(k, n, ensembleSize, this.classifierRandom);
            } else {
                // No proper subspace (empty, negative or covering every feature): train on all
                // features with resampling instead. NOTE: this overwrites the user's option.
                this.trainingMethodOption.setChosenIndex(TRAIN_RESAMPLING);
            }
        }

        Classifier baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        boolean useSubspaces = this.trainingMethodOption.getChosenIndex() != TRAIN_RESAMPLING;
        for (int i = 0; i < ensembleSize; ++i) {
            if (useSubspaces) {
                // Each member takes (and removes) one random entry from the candidate pool.
                ArrayList<Integer> combination = this.subspaces.remove(this.classifierRandom.nextInt(this.subspaces.size()));
                ArrayList<Integer> subsetOfFeatures = toAttributeIndexes(combination, instance.classIndex());
                this.ensemble[i] = new StreamingRandomPatchesClassifier(i, baseLearner.copy(),
                        (BasicClassificationPerformanceEvaluator) classificationEvaluator.copy(), this.instancesSeen,
                        this.disableBackgroundLearnerOption.isSet(), this.disableDriftDetectionOption.isSet(),
                        this.driftDetectionMethodOption, this.warningDetectionMethodOption,
                        subsetOfFeatures, instance, false);
            } else {
                this.ensemble[i] = new StreamingRandomPatchesClassifier(i, baseLearner.copy(),
                        (BasicClassificationPerformanceEvaluator) classificationEvaluator.copy(), this.instancesSeen,
                        this.disableBackgroundLearnerOption.isSet(), this.disableDriftDetectionOption.isSet(),
                        this.driftDetectionMethodOption, this.warningDetectionMethodOption, false);
            }
        }
    }

    /**
     * Translates the subspace options into a concrete number of features.
     *
     * @param n number of input (non-class) features
     * @return the subspace size; a negative intermediate value k is interpreted as {@code n + k},
     *         and the result may still be non-positive for extreme option values
     */
    private int resolveSubspaceSize(int n) {
        int k = this.subspaceSizeOption.getValue();
        switch (this.subspaceModeOption.getChosenIndex()) {
            case FEATURES_SQRT:
                k = (int) Math.round(Math.sqrt(n)) + 1;
                break;
            case FEATURES_SQRT_INV:
                k = n - (int) Math.round(Math.sqrt(n) + 1.0);
                break;
            case FEATURES_PERCENT: {
                double percent = k < 0 ? (100 + k) / 100.0 : k / 100.0;
                long rounded = Math.round(n * percent);
                // Very small percentages are bumped by one feature.
                k = rounded >= 2L ? (int) rounded : (int) rounded + 1;
                break;
            }
            default: // FEATURES_M: m is used as given
                break;
        }
        if (k < 0) {
            k = n + k;
        }
        return k;
    }

    /**
     * Produces at least {@code ensembleSize} candidate subspaces of size {@code k} over positions
     * {@code 0..n-1}.
     *
     * <p>With few features (or a one-feature subspace, which is widened to two features) every
     * combination is enumerated and the list is repeated cyclically until it has one entry per
     * member; otherwise {@code ensembleSize} random combinations are drawn.
     */
    private static ArrayList<ArrayList<Integer>> generateSubspaces(int k, int n, int ensembleSize, Random random) {
        if (n > MAX_FEATURES_FOR_EXHAUSTIVE_SUBSPACES && k >= 2) {
            return localRandomKCombinations(k, n, ensembleSize, random);
        }
        int size = (k == 1 && n > 1) ? 2 : k;
        ArrayList<ArrayList<Integer>> combinations = allKCombinations(size, n);
        for (int i = 0; combinations.size() < ensembleSize; ++i) {
            combinations.add(new ArrayList<>(combinations.get(i)));
        }
        return combinations;
    }

    /**
     * Maps positions among the non-class features to real attribute indexes (skipping the class
     * attribute wherever it sits) and appends the class index as the last entry.
     */
    private static ArrayList<Integer> toAttributeIndexes(ArrayList<Integer> combination, int classIndex) {
        ArrayList<Integer> attributeIndexes = new ArrayList<>(combination.size() + 1);
        for (int featurePosition : combination) {
            attributeIndexes.add(featurePosition >= classIndex ? featurePosition + 1 : featurePosition);
        }
        attributeIndexes.add(classIndex);
        return attributeIndexes;
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (this.getClass() == StreamingRandomPatches.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }

    /**
     * Returns the members' current base classifiers (not their background learners).
     *
     * @return one classifier per member, or an empty array before the ensemble is built
     */
    public Classifier[] getSublearners() {
        if (this.ensemble == null) {
            return new Classifier[0];
        }
        Classifier[] baseModels = new Classifier[this.ensemble.length];
        for (int i = 0; i < baseModels.length; ++i) {
            baseModels[i] = this.ensemble[i].classifier;
        }
        return baseModels;
    }

    /**
     * Draws {@code nCombinations} random k-subsets of {@code 0..length-1}, each listed in ascending
     * order, by removing {@code length - k} random elements from the full index list.
     */
    private static ArrayList<ArrayList<Integer>> localRandomKCombinations(int k, int length, int nCombinations, Random random) {
        ArrayList<ArrayList<Integer>> combinations = new ArrayList<>();
        for (int i = 0; i < nCombinations; ++i) {
            ArrayList<Integer> combination = new ArrayList<>(length);
            for (int j = 0; j < length; ++j) {
                combination.add(j);
            }
            for (int j = 0; j < length - k; ++j) {
                combination.remove(random.nextInt(combination.size()));
            }
            combinations.add(combination);
        }
        return combinations;
    }

    /**
     * Recursive helper for {@link #allKCombinations}: extends {@code combination} with every index
     * from {@code offset} that still leaves room for the remaining {@code k} picks.
     * Requires {@code k >= 0}; a negative k would never reach the base case.
     */
    private static void allKCombinationsInner(int offset, int k, ArrayList<Integer> combination, long originalSize, ArrayList<ArrayList<Integer>> combinations) {
        if (k == 0) {
            combinations.add(new ArrayList<>(combination));
            return;
        }
        for (int i = offset; (long) i <= originalSize - (long) k; ++i) {
            combination.add(i);
            allKCombinationsInner(i + 1, k - 1, combination, originalSize, combinations);
            combination.remove(combination.size() - 1);
        }
    }

    /** Enumerates every k-combination of {@code 0..length-1} in lexicographic order. */
    private static ArrayList<ArrayList<Integer>> allKCombinations(int k, int length) {
        ArrayList<ArrayList<Integer>> combinations = new ArrayList<>();
        allKCombinationsInner(0, k, new ArrayList<>(), length, combinations);
        return combinations;
    }

    /**
     * One ensemble member: a base classifier, the feature subset it sees ({@code null} subset
     * means all features), its accuracy evaluator, optional warning/drift detectors and an
     * optional background learner that replaces it after a detected drift.
     */
    protected class StreamingRandomPatchesClassifier {
        public int indexOriginal;
        public long createdOn;
        public Classifier classifier;
        /** Single-row buffer holding the current instance projected onto {@link #featureIndexes}; null when all features are used. */
        public Instances subset;
        /** Attribute indexes of the projection; the last entry is the class index. */
        public int[] featureIndexes;
        public boolean disableBkgLearner;
        public boolean disableDriftDetector;
        protected ChangeDetector driftDetectionMethod;
        protected ChangeDetector warningDetectionMethod;
        protected ClassOption driftOption;
        protected ClassOption warningOption;
        public StreamingRandomPatchesClassifier bkgLearner;
        public boolean isBackgroundLearner;
        public BasicClassificationPerformanceEvaluator evaluator;
        public int numberOfDriftsDetected;
        public int numberOfWarningsDetected;
        public int numberOfDriftsInduced;
        public int numberOfWarningsInduced;

        /** Shared constructor logic; detectors are only prepared when their feature is enabled. */
        private void init(int indexOriginal, Classifier instantiatedClassifier, BasicClassificationPerformanceEvaluator evaluatorInstantiated, long instancesSeen, boolean disableBkgLearner, boolean disableDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            this.indexOriginal = indexOriginal;
            this.createdOn = instancesSeen;
            this.classifier = instantiatedClassifier;
            this.evaluator = evaluatorInstantiated;
            this.disableBkgLearner = disableBkgLearner;
            this.disableDriftDetector = disableDriftDetector;
            if (!this.disableDriftDetector) {
                this.driftOption = driftOption;
                this.driftDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(driftOption)).copy();
            }
            if (!this.disableBkgLearner) {
                this.warningOption = warningOption;
                this.warningDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(warningOption)).copy();
            }
            this.numberOfDriftsInduced = 0;
            this.numberOfDriftsDetected = 0;
            this.numberOfWarningsInduced = 0;
            this.numberOfWarningsDetected = 0;
            this.isBackgroundLearner = isBackgroundLearner;
        }

        /** Creates a member that trains on all features. */
        public StreamingRandomPatchesClassifier(int indexOriginal, Classifier instantiatedClassifier, BasicClassificationPerformanceEvaluator evaluatorInstantiated, long instancesSeen, boolean disableBkgLearner, boolean disableDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            this.init(indexOriginal, instantiatedClassifier, evaluatorInstantiated, instancesSeen, disableBkgLearner, disableDriftDetector, driftOption, warningOption, isBackgroundLearner);
            this.featureIndexes = null;
            this.subset = null;
        }

        /**
         * Creates a member restricted to {@code featuresIndexes} (whose last entry must be the
         * class index) and builds the projected header from {@code instance}.
         */
        public StreamingRandomPatchesClassifier(int indexOriginal, Classifier instantiatedClassifier, BasicClassificationPerformanceEvaluator evaluatorInstantiated, long instancesSeen, boolean disableBkgLearner, boolean disableDriftDetector, ClassOption driftOption, ClassOption warningOption, ArrayList<Integer> featuresIndexes, Instance instance, boolean isBackgroundLearner) {
            this.init(indexOriginal, instantiatedClassifier, evaluatorInstantiated, instancesSeen, disableBkgLearner, disableDriftDetector, driftOption, warningOption, isBackgroundLearner);
            this.featureIndexes = new int[featuresIndexes.size()];
            ArrayList<Attribute> attSub = new ArrayList<>();
            for (int i = 0; i < featuresIndexes.size(); ++i) {
                attSub.add(instance.attribute(featuresIndexes.get(i)));
                this.featureIndexes[i] = featuresIndexes.get(i);
            }
            this.subset = new Instances("Subsets Candidate Instances", attSub, 100);
            this.subset.setClassIndex(this.subset.numAttributes() - 1);
            this.prepareRandomSubspaceInstance(instance, 1.0);
        }

        /**
         * Replaces the content of {@link #subset} with {@code instance} projected onto
         * {@link #featureIndexes}, carrying the class value in the last slot.
         *
         * @param weight weight assigned to the projected instance
         */
        public void prepareRandomSubspaceInstance(Instance instance, double weight) {
            while (this.subset.numInstances() > 0) {
                this.subset.delete(0);
            }
            double[] values = new double[this.subset.numAttributes()];
            for (int j = 0; j < this.subset.numAttributes(); ++j) {
                values[j] = instance.value(this.featureIndexes[j]);
            }
            values[values.length - 1] = instance.classValue();
            DenseInstance subInstance = new DenseInstance(1.0, values);
            subInstance.setWeight(weight);
            subInstance.setDataset(this.subset);
            this.subset.add(subInstance);
        }

        /**
         * Draws a fresh random feature subset of the same size as the current one (class excluded
         * from the draw, then appended last); returns {@code null} for all-feature members.
         */
        private ArrayList<Integer> applySubsetResetStrategy(Instance instance, Random random) {
            if (this.subset == null) {
                return null;
            }
            ArrayList<Integer> fIndexes = new ArrayList<>();
            for (int j = 0; j < instance.numAttributes(); ++j) {
                fIndexes.add(j);
            }
            // Position equals value here, so remove(int) drops the class attribute itself.
            fIndexes.remove(instance.classIndex());
            for (int j = 0; j < instance.numAttributes() - this.featureIndexes.length; ++j) {
                fIndexes.remove(random.nextInt(fIndexes.size()));
            }
            fIndexes.add(instance.classIndex());
            return fIndexes;
        }

        /**
         * Reacts to a detected drift: promotes the background learner when one exists, otherwise
         * resets the classifier in place (drawing a new feature subset for subspace members).
         */
        public void reset(Instance instance, long instancesSeen, Random random) {
            if (!this.disableBkgLearner && this.bkgLearner != null) {
                this.classifier = this.bkgLearner.classifier;
                this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod;
                this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod;
                this.evaluator = this.bkgLearner.evaluator;
                this.evaluator.reset();
                this.createdOn = this.bkgLearner.createdOn;
                this.subset = this.bkgLearner.subset;
                this.featureIndexes = this.bkgLearner.featureIndexes;
                // The promoted model now lives in this member; keeping the reference would train
                // the same classifier (and reuse the same subset buffer) twice per instance.
                this.bkgLearner = null;
            } else {
                this.classifier.resetLearning();
                this.evaluator.reset();
                this.createdOn = instancesSeen;
                this.driftDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(this.driftOption)).copy();
                if (this.subset != null) {
                    ArrayList<Integer> fIndexes = this.applySubsetResetStrategy(instance, random);
                    for (int i = 0; i < fIndexes.size(); ++i) {
                        this.featureIndexes[i] = fIndexes.get(i);
                    }
                    ArrayList<Attribute> attSub = new ArrayList<>();
                    for (int featureIndex : this.featureIndexes) {
                        attSub.add(instance.attribute(featureIndex));
                    }
                    this.subset = new Instances("Subsets Candidate Instances", attSub, 100);
                    this.subset.setClassIndex(this.subset.numAttributes() - 1);
                    this.prepareRandomSubspaceInstance(instance, 1.0);
                }
            }
        }

        /**
         * Trains the member (and its background learner, if any) on {@code instance} scaled by
         * {@code weight}, then feeds the post-training correctness into the warning and drift
         * detectors of non-background members.
         */
        public void trainOnInstance(Instance instance, double weight, long instancesSeen, Random random) {
            boolean correctlyClassifies;
            if (this.subset != null) {
                // Combine with the instance's own weight, exactly like the all-features branch.
                this.prepareRandomSubspaceInstance(instance, instance.weight() * weight);
                Instance projected = this.subset.get(0);
                this.classifier.trainOnInstance(projected);
                correctlyClassifies = this.classifier.correctlyClassifies(projected);
            } else {
                Instance weightedInstance = instance.copy();
                weightedInstance.setWeight(instance.weight() * weight);
                this.classifier.trainOnInstance(weightedInstance);
                correctlyClassifies = this.classifier.correctlyClassifies(instance);
            }
            if (this.bkgLearner != null) {
                this.bkgLearner.trainOnInstance(instance, weight, instancesSeen, random);
            }
            if (!this.disableDriftDetector && !this.isBackgroundLearner) {
                double error = correctlyClassifies ? 0.0 : 1.0;
                if (!this.disableBkgLearner) {
                    this.warningDetectionMethod.input(error);
                    if (this.warningDetectionMethod.getChange()) {
                        ++this.numberOfWarningsDetected;
                        this.triggerWarning(instance, instancesSeen, random);
                    }
                }
                this.driftDetectionMethod.input(error);
                if (this.driftDetectionMethod.getChange()) {
                    ++this.numberOfDriftsDetected;
                    this.reset(instance, instancesSeen, random);
                }
            }
        }

        /**
         * Starts (or restarts) a background learner: a reset copy of the classifier with a fresh
         * evaluator and, for subspace members, a newly drawn feature subset. Also restarts the
         * warning detector.
         */
        public void triggerWarning(Instance instance, long instancesSeen, Random random) {
            Classifier bkgClassifier = this.classifier.copy();
            bkgClassifier.resetLearning();
            BasicClassificationPerformanceEvaluator bkgEvaluator = (BasicClassificationPerformanceEvaluator) this.evaluator.copy();
            bkgEvaluator.reset();
            if (this.subset == null) {
                this.bkgLearner = new StreamingRandomPatchesClassifier(this.indexOriginal, bkgClassifier, bkgEvaluator, instancesSeen, this.disableBkgLearner, this.disableDriftDetector, this.driftOption, this.warningOption, true);
            } else {
                ArrayList<Integer> fIndexes = this.applySubsetResetStrategy(instance, random);
                this.bkgLearner = new StreamingRandomPatchesClassifier(this.indexOriginal, bkgClassifier, bkgEvaluator, instancesSeen, this.disableBkgLearner, this.disableDriftDetector, this.driftOption, this.warningOption, fIndexes, instance, true);
            }
            this.warningDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(this.warningOption)).copy();
        }

        /** Returns a copy of the classifier's votes for {@code instance} (projected when needed). */
        public double[] getVotesForInstance(Instance instance) {
            Instance input = instance;
            if (this.subset != null) {
                this.prepareRandomSubspaceInstance(instance, 1.0);
                input = this.subset.get(0);
            }
            DoubleVector vote = new DoubleVector(this.classifier.getVotesForInstance(input));
            return vote.getArrayRef();
        }
    }
}

