/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.options.ClassOption;
import moa.options.FloatOption;
import moa.options.IntOption;
import moa.tasks.TaskMonitor;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

/**
 * Accuracy-weighted ensemble classifier for data streams.
 *
 * <p>Training instances are buffered into chunks of {@code chunkSize} instances. When a chunk is
 * complete:
 * <ul>
 *   <li>a fresh candidate learner is scored on it by cross-validation;</li>
 *   <li>every stored learner is re-scored on it;</li>
 *   <li>the candidate is kept if the store has room or if it beats the weakest stored learner;</li>
 *   <li>the (at most) {@code memberCount} highest-weighted stored learners become the voting
 *       ensemble.</li>
 * </ul>
 *
 * <p>A learner's weight is {@code max(MSE_r - MSE_i, 0)}:
 * <ul>
 *   <li>{@code MSE_i} is the mean of {@code (1 - f)^2} over the evaluated instances, where
 *       {@code f} is the normalised vote the learner gives to the true class;</li>
 *   <li>{@code MSE_r = sum_c p(c) * (1 - p(c))^2}, computed from the current chunk's class
 *       distribution.</li>
 * </ul>
 * NOTE(review): this appears to follow Wang et al., "Mining concept-drifting data streams using
 * ensemble classifiers" (KDD 2003); confirm before citing.
 */
public class AccuracyWeightedEnsemble extends AbstractClassifier {
    private static final long serialVersionUID = 1L;

    /** Orders {weight, index} rows of {@link #storedWeights} by ascending weight. */
    protected static Comparator<double[]> weightComparator = new ClassifierWeightComparator();

    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Classifier.class, "HoeffdingTreeNB -e 1000 -g 100 -c 0.01");
    public FloatOption memberCountOption = new FloatOption("memberCount", 'n', "The maximum number of classifier in an ensemble.", 15.0, 1.0, 2.147483647E9);
    public FloatOption storedCountOption = new FloatOption("storedCount", 'r', "The maximum number of classifiers to store and choose from when creating an ensemble.", 30.0, 1.0, 2.147483647E9);
    public IntOption chunkSizeOption = new IntOption("chunkSize", 'c', "The chunk size used for classifier creation and evaluation.", 500, 1, Integer.MAX_VALUE);
    public IntOption numFoldsOption = new IntOption("numFolds", 'f', "Number of cross-validation folds for candidate classifier testing.", 10, 1, Integer.MAX_VALUE);

    /** Per-class instance counts of the chunk currently being buffered (indexed by class value). */
    protected long[] classDistributions;
    /** Learners that currently vote, ordered from highest to lowest weight. */
    protected Classifier[] ensemble;
    /** All retained learners (at most {@link #maxStoredCount}); ensemble members are drawn from here. */
    protected Classifier[] storedLearners;
    /** Voting weight of each ensemble member, parallel to {@link #ensemble}. */
    protected double[] ensembleWeights;
    /**
     * One {@code {weight, index into storedLearners}} row per stored learner. The rows are
     * re-sorted by weight; the index column is what ties a row back to its learner.
     */
    protected double[][] storedWeights;
    /** Number of training instances seen since the last reset. */
    protected int processedInstances;
    protected int chunkSize;
    protected int numFolds;
    protected int maxMemberCount;
    protected int maxStoredCount;
    /** Untrained learner that will be evaluated (and possibly stored) at the end of the current chunk. */
    protected Classifier candidateClassifier;
    /** Instances buffered for the current chunk; {@code null} until the first instance arrives. */
    protected Instances currentChunk;

    /**
     * Reads the option values and prepares a fresh candidate learner.
     *
     * @param monitor task monitor passed through to the superclass
     * @param repository object repository passed through to the superclass
     */
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        this.maxMemberCount = (int) this.memberCountOption.getValue();
        // The store must be able to hold at least one full ensemble.
        this.maxStoredCount = Math.max((int) this.storedCountOption.getValue(), this.maxMemberCount);
        this.chunkSize = this.chunkSizeOption.getValue();
        this.numFolds = this.numFoldsOption.getValue();
        this.candidateClassifier = (Classifier) this.getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
        super.prepareForUseImpl(monitor, repository);
    }

    /** Discards all learned state: buffered chunk, stored learners, ensemble and their weights. */
    public void resetLearningImpl() {
        this.currentChunk = null;
        this.classDistributions = null;
        this.processedInstances = 0;
        this.ensemble = new Classifier[0];
        // Clear the weight arrays too, so measurements and removal never see stale values.
        this.ensembleWeights = new double[0];
        this.storedLearners = new Classifier[0];
        this.storedWeights = new double[0][];
        this.candidateClassifier = (Classifier) this.getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
    }

    /**
     * Buffers one training instance and processes the chunk once it holds {@link #chunkSize}
     * instances.
     *
     * @param inst the training instance
     */
    public void trainOnInstanceImpl(Instance inst) {
        this.initVariables();
        // NOTE(review): a missing class value (NaN) casts to 0 and is counted as class 0.
        int classIndex = (int) inst.classValue();
        this.classDistributions[classIndex]++;
        this.currentChunk.add(inst);
        ++this.processedInstances;
        // Use the buffered size rather than processedInstances % chunkSize. The int counter
        // overflows after 2^31 instances, which would misalign chunk boundaries.
        if (this.currentChunk.numInstances() >= this.chunkSize) {
            this.processChunk();
        }
    }

    /** Lazily creates the chunk buffer and the per-class counters for a new chunk. */
    private void initVariables() {
        if (this.currentChunk == null) {
            // Header only (no rows copied), pre-sized for one chunk.
            this.currentChunk = new Instances(this.getModelContext(), this.chunkSize);
        }
        if (this.classDistributions == null) {
            // Java zero-initialises new arrays.
            this.classDistributions = new long[this.getModelContext().classAttribute().numValues()];
        }
    }

    /**
     * Evaluates the candidate and stored learners on the completed chunk and updates the store.
     * It then rebuilds the voting ensemble and starts a new chunk with a fresh candidate.
     */
    protected void processChunk() {
        double candidateClassifierWeight = this.computeCandidateWeight(this.candidateClassifier, this.currentChunk, this.numFolds);
        // Re-score every stored learner on the newest chunk.
        for (int i = 0; i < this.storedLearners.length; ++i) {
            int learnerIndex = (int) this.storedWeights[i][1];
            this.storedWeights[i][0] = this.computeWeight(this.storedLearners[learnerIndex], this.currentChunk);
        }
        if (this.storedLearners.length < this.maxStoredCount) {
            trainLearnerOnChunk(this.candidateClassifier, this.currentChunk);
            this.addToStored(this.candidateClassifier, candidateClassifierWeight);
        } else {
            // Row 0 holds the weakest stored learner after an ascending sort.
            Arrays.sort(this.storedWeights, weightComparator);
            if (this.storedWeights[0][0] < candidateClassifierWeight) {
                trainLearnerOnChunk(this.candidateClassifier, this.currentChunk);
                this.storedWeights[0][0] = candidateClassifierWeight;
                this.storedLearners[(int) this.storedWeights[0][1]] = this.candidateClassifier.copy();
            }
        }
        this.selectEnsembleMembers();
        this.classDistributions = null;
        this.currentChunk = null;
        this.candidateClassifier = (Classifier) this.getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
    }

    /**
     * Rebuilds {@link #ensemble} and {@link #ensembleWeights} from the highest-weighted stored
     * learners.
     */
    private void selectEnsembleMembers() {
        int storeSize = this.storedLearners.length;
        int ensembleSize = Math.min(storeSize, this.maxMemberCount);
        this.ensemble = new Classifier[ensembleSize];
        this.ensembleWeights = new double[ensembleSize];
        Arrays.sort(this.storedWeights, weightComparator);
        for (int i = 0; i < ensembleSize; ++i) {
            // Walk the ascending order from the end, so the best learner comes first.
            double[] row = this.storedWeights[storeSize - i - 1];
            this.ensembleWeights[i] = row[0];
            this.ensemble[i] = this.storedLearners[(int) row[1]];
        }
    }

    /** Trains {@code learner} on every instance of {@code chunk}, in order. */
    private static void trainLearnerOnChunk(Classifier learner, Instances chunk) {
        for (int i = 0; i < chunk.numInstances(); ++i) {
            learner.trainOnInstance(chunk.instance(i));
        }
    }

    /**
     * Estimates the weight the candidate would earn on {@code chunk} by k-fold cross-validation
     * on copies of the candidate. The candidate itself is not trained.
     *
     * @param candidate the untrained candidate learner (only copies are trained)
     * @param chunk the completed chunk
     * @param numFolds requested number of folds; clamped to the chunk size
     * @return the mean fold weight, or {@link Double#MAX_VALUE} if that mean is infinite
     */
    protected double computeCandidateWeight(Classifier candidate, Instances chunk, int numFolds) {
        // Weka's trainCV/testCV reject fewer than two folds and more folds than instances.
        // numFoldsOption allows 1, and a chunk can be smaller than numFolds.
        int folds = Math.min(numFolds, chunk.numInstances());
        if (folds < 2) {
            // Cross-validation is impossible: score a copy trained on the whole chunk against
            // that same chunk.
            Classifier learner = candidate.copy();
            trainLearnerOnChunk(learner, chunk);
            return this.computeWeight(learner, chunk);
        }
        double candidateWeight = 0.0;
        // A fixed seed keeps the fold assignment reproducible.
        Random random = new Random(1L);
        Instances randData = new Instances(chunk);
        randData.randomize(random);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }
        for (int n = 0; n < folds; ++n) {
            Instances train = randData.trainCV(folds, n, random);
            Instances test = randData.testCV(folds, n);
            Classifier learner = candidate.copy();
            trainLearnerOnChunk(learner, train);
            candidateWeight += this.computeWeight(learner, test);
        }
        double resultWeight = candidateWeight / (double) folds;
        return Double.isInfinite(resultWeight) ? Double.MAX_VALUE : resultWeight;
    }

    /**
     * Computes {@code max(MSE_r - MSE_i, 0)} for {@code learner} on {@code chunk}.
     *
     * @param learner the learner to score
     * @param chunk instances to score on (a whole chunk or a cross-validation test fold)
     * @return a non-negative, non-NaN weight
     */
    protected double computeWeight(Classifier learner, Instances chunk) {
        int numInstances = chunk.numInstances();
        double mseI = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            mseI += squaredTrueClassError(learner, chunk.instance(i));
        }
        // Average over the instances actually evaluated. Dividing by chunkSize would shrink the
        // error of learners scored on a smaller cross-validation fold.
        if (numInstances > 0) {
            mseI /= (double) numInstances;
        }
        double weight = this.computeMseR() - mseI;
        // "weight > 0.0" is false for NaN too, so a degenerate learner gets weight 0 rather than
        // an unorderable NaN.
        return weight > 0.0 ? weight : 0.0;
    }

    /**
     * Returns {@code (1 - f)^2}, where {@code f} is the normalised vote {@code learner} gives the
     * instance's true class. Returns 1 (maximal error) when the votes sum to zero, the class
     * index is out of range, or prediction throws.
     */
    private static double squaredTrueClassError(Classifier learner, Instance instance) {
        try {
            double[] votes = learner.getVotesForInstance(instance);
            double voteSum = 0.0;
            for (double vote : votes) {
                voteSum += vote;
            }
            int classIndex = (int) instance.classValue();
            if (voteSum > 0.0 && classIndex >= 0 && classIndex < votes.length) {
                double fci = votes[classIndex] / voteSum;
                return (1.0 - fci) * (1.0 - fci);
            }
            return 1.0;
        } catch (Exception ignored) {
            // Deliberately best-effort: a learner that fails to predict is scored as maximally
            // wrong on this instance instead of aborting the chunk.
            return 1.0;
        }
    }

    /**
     * Computes {@code MSE_r = sum_c p(c) * (1 - p(c))^2} from the current chunk's class counts.
     *
     * @return the reference error for the current chunk
     */
    protected double computeMseR() {
        double mseR = 0.0;
        for (long count : this.classDistributions) {
            double pc = (double) count / (double) this.chunkSize;
            mseR += pc * ((1.0 - pc) * (1.0 - pc));
        }
        return mseR;
    }

    /**
     * Combines the normalised votes of all ensemble members with positive weight, scaled by
     * their weights.
     *
     * @param inst the instance to classify
     * @return the normalised combined vote (empty before any training)
     */
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0) {
            for (int i = 0; i < this.ensemble.length; ++i) {
                if (!(this.ensembleWeights[i] > 0.0)) {
                    continue;
                }
                DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
                if (!(vote.sumOfValues() > 0.0)) {
                    continue;
                }
                vote.normalize();
                // The shared divisor does not change the final normalised result.
                vote.scaleValues(this.ensembleWeights[i] / ((double) this.ensemble.length + 1.0));
                combinedVote.addValues(vote);
            }
        }
        combinedVote.normalize();
        return combinedVote.getArrayRef();
    }

    /** Produces no textual model description. */
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /**
     * Reports one weight per store slot, highest weight first. The first
     * {@code ensemble.length} are labelled as members; unused slots report -1.
     */
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurements = new Measurement[this.maxStoredCount];
        for (int m = 0; m < this.maxMemberCount; ++m) {
            measurements[m] = new Measurement("Member weight " + (m + 1), -1.0);
        }
        for (int s = this.maxMemberCount; s < this.maxStoredCount; ++s) {
            measurements[s] = new Measurement("Stored member weight " + (s + 1), -1.0);
        }
        if (this.storedWeights != null) {
            int storeSize = this.storedWeights.length;
            int memberCount = this.ensemble == null ? 0 : this.ensemble.length;
            // Guard against a store larger than the current maxStoredCount.
            int shown = Math.min(storeSize, measurements.length);
            for (int i = 0; i < shown; ++i) {
                double weight = this.storedWeights[storeSize - i - 1][0];
                measurements[i] = i < memberCount
                        ? new Measurement("Member weight " + (i + 1), weight)
                        : new Measurement("Stored member weight " + (i + 1), weight);
            }
        }
        return measurements;
    }

    public boolean isRandomizable() {
        return false;
    }

    /** Returns a shallow copy of the current ensemble array. */
    public Classifier[] getSubClassifiers() {
        return this.ensemble.clone();
    }

    /**
     * Appends a copy of {@code newClassifier} to the store with the given weight.
     *
     * @param newClassifier learner to copy into the store
     * @param newClassifiersWeight its weight
     * @return the stored copy
     */
    protected Classifier addToStored(Classifier newClassifier, double newClassifiersWeight) {
        int oldSize = this.storedLearners.length;
        Classifier addedClassifier = newClassifier.copy();
        Classifier[] newStored = Arrays.copyOf(this.storedLearners, oldSize + 1);
        newStored[oldSize] = addedClassifier;
        double[][] newStoredWeights = new double[oldSize + 1][];
        for (int i = 0; i < oldSize; ++i) {
            newStoredWeights[i] = new double[] {this.storedWeights[i][0], this.storedWeights[i][1]};
        }
        newStoredWeights[oldSize] = new double[] {newClassifiersWeight, oldSize};
        this.storedLearners = newStored;
        this.storedWeights = newStoredWeights;
        return addedClassifier;
    }

    /**
     * Removes the lowest-weighted ensemble member and returns its measured byte size.
     *
     * <p>NOTE(review): the learner is still referenced from {@link #storedLearners}, so its
     * memory is not actually released. The next {@link #processChunk()} may also re-select it.
     */
    protected int removePoorestModelBytes() {
        int poorestIndex = Utils.minIndex(this.ensembleWeights);
        int byteSize = this.ensemble[poorestIndex].measureByteSize();
        this.discardModel(poorestIndex);
        return byteSize;
    }

    /**
     * Removes the ensemble member at {@code index} (and its weight), preserving the order of the
     * rest.
     */
    protected void discardModel(int index) {
        Classifier[] newEnsemble = new Classifier[this.ensemble.length - 1];
        double[] newEnsembleWeights = new double[newEnsemble.length];
        int oldPos = 0;
        for (int i = 0; i < newEnsemble.length; ++i) {
            if (oldPos == index) {
                ++oldPos;
            }
            newEnsemble[i] = this.ensemble[oldPos];
            newEnsembleWeights[i] = this.ensembleWeights[oldPos];
            ++oldPos;
        }
        this.ensemble = newEnsemble;
        this.ensembleWeights = newEnsembleWeights;
    }

    /**
     * Orders {@code {weight, index}} rows by ascending weight. {@link Double#compare} is a total
     * order (including NaN), which is what {@link Arrays#sort} requires.
     */
    private static final class ClassifierWeightComparator
    implements Comparator<double[]> {
        private ClassifierWeightComparator() {
        }

        @Override
        public int compare(double[] o1, double[] o2) {
            return Double.compare(o1[0], o2[0]);
        }
    }
}

