/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta.imbalanced;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.SamoaToWekaInstanceConverter;
import com.yahoo.labs.samoa.instances.WekaToSamoaInstanceConverter;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.classifiers.lazy.neighboursearch.LinearNNSearch;
import moa.classifiers.lazy.neighboursearch.NearestNeighbourSearch;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;
import weka.core.Attribute;

public class CSMOTE
extends AbstractClassifier
implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "meta.AdaptiveRandomForest");
    public IntOption neighborsOption = new IntOption("neighbors", 'k', "Number of neighbors for SMOTE.", 5, 1, Integer.MAX_VALUE);
    public FloatOption thresholdOption = new FloatOption("threshold", 't', "Minority class samples threshold.", 0.5, 0.1, 0.5);
    public IntOption minSizeAllowedOption = new IntOption("minSizeAllowed", 'm', "Minimum number of samples in the minority class for appling SMOTE.", 100, 10, Integer.MAX_VALUE);
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'd', "Should use ADWIN as drift detector?");
    protected Classifier learner;
    protected int neighbors;
    protected double threshold;
    protected int minSizeAllowed;
    protected boolean driftDetection;
    protected ADWIN adwin;
    protected ADWIN adwinDriftDetector;
    protected ArrayList<Instance> W = new ArrayList();
    protected Instances min = null;
    protected Instances maj = null;
    protected int nMinorityTotal;
    protected int nMajorityTotal;
    protected int nGeneratedMinorityTotal;
    protected int nGeneratedMajorityTotal;
    protected HashMap<Instance, Integer> instanceGenerated = new HashMap();
    protected ArrayList<Integer> alreadyUsed = new ArrayList();
    protected SamoaToWekaInstanceConverter samoaToWeka = new SamoaToWekaInstanceConverter();
    protected WekaToSamoaInstanceConverter wekaToSamoa = new WekaToSamoaInstanceConverter();
    protected int[] indexValues;

    @Override
    public String getPurposeString() {
        return "OnlineSMOTE strategy that saves the data in a sliding window and when the minority class ratio is less than a threshold it generates some synthetic new samples using SMOTE";
    }

    @Override
    public void resetLearningImpl() {
        this.learner = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
        this.neighbors = this.neighborsOption.getValue();
        this.threshold = this.thresholdOption.getValue();
        this.minSizeAllowed = this.minSizeAllowedOption.getValue();
        this.driftDetection = !this.disableDriftDetectionOption.isSet();
        this.learner.resetLearning();
        this.nMinorityTotal = 0;
        this.nMajorityTotal = 0;
        this.nGeneratedMinorityTotal = 0;
        this.nGeneratedMajorityTotal = 0;
        this.alreadyUsed.clear();
        this.instanceGenerated.clear();
        this.indexValues = null;
        this.adwin = new ADWIN();
        this.adwinDriftDetector = new ADWIN();
        this.min = null;
        this.maj = null;
        this.W.clear();
        this.classifierRandom = new Random(this.randomSeed);
    }

    @Override
    public double[] getVotesForInstance(Instance instance) {
        double[] prediction = this.learner.getVotesForInstance(instance);
        return prediction;
    }

    @Override
    public void trainOnInstanceImpl(Instance instance) {
        this.learner.trainOnInstance(instance);
        this.fillBatches(instance);
        this.adwin.setInput(instance.classValue());
        this.checkADWINWidth();
        boolean allowSMOTE = false;
        if (this.min != null && this.maj != null) {
            if (this.min.numInstances() <= this.maj.numInstances()) {
                if (this.min.numInstances() > this.minSizeAllowed) {
                    allowSMOTE = true;
                }
            } else if (this.maj.numInstances() > this.minSizeAllowed) {
                allowSMOTE = true;
            }
        }
        if (allowSMOTE) {
            while (this.threshold > this.calculateRatio()) {
                Instance newInstance = this.onlineSMOTE();
                if (newInstance == null) continue;
                this.learner.trainOnInstance(newInstance);
            }
            this.alreadyUsed.clear();
        }
        if (this.driftDetection) {
            double pred = Utils.maxIndex(this.learner.getVotesForInstance(instance));
            double errorEstimation = this.adwinDriftDetector.getEstimation();
            double inputValue = pred == instance.classValue() ? 1.0 : 0.0;
            boolean resInput = this.adwinDriftDetector.setInput(inputValue);
            if (resInput && this.adwinDriftDetector.getEstimation() > errorEstimation) {
                this.learner.resetLearning();
                this.adwinDriftDetector = new ADWIN();
            }
        }
    }

    private void fillBatches(Instance instance) {
        this.W.add(instance);
        if (instance.classValue() == 1.0) {
            if (this.maj == null) {
                this.maj = instance.dataset();
                this.maj.setClassIndex(this.maj.numAttributes() - 1);
            }
            ++this.nMajorityTotal;
            this.maj.add(instance);
        } else {
            if (this.min == null) {
                this.min = instance.dataset();
                this.min.setClassIndex(this.min.numAttributes() - 1);
            }
            ++this.nMinorityTotal;
            this.min.add(instance);
        }
    }

    private void checkADWINWidth() {
        if (this.adwin.getChange()) {
            int newWidth = this.adwin.getWidth();
            int windowSize = this.W.size();
            int diff = windowSize - newWidth;
            for (int i = 0; i < diff; ++i) {
                Instance instanceRemoved = this.W.remove(0);
                if (instanceRemoved.classValue() == 1.0) {
                    this.maj.delete(0);
                    --this.nMajorityTotal;
                    if (this.instanceGenerated.get(instanceRemoved) == null) continue;
                    this.nGeneratedMajorityTotal -= this.instanceGenerated.get(instanceRemoved).intValue();
                    this.instanceGenerated.remove(instanceRemoved);
                    continue;
                }
                this.min.delete(0);
                --this.nMinorityTotal;
                if (this.instanceGenerated.get(instanceRemoved) == null) continue;
                this.nGeneratedMinorityTotal -= this.instanceGenerated.get(instanceRemoved).intValue();
                this.instanceGenerated.remove(instanceRemoved);
            }
        }
    }

    private double calculateRatio() {
        double ratio = 0.0;
        ratio = this.nMinorityTotal + this.nGeneratedMinorityTotal <= this.nMajorityTotal + this.nGeneratedMajorityTotal ? ((double)this.nMinorityTotal + (double)this.nGeneratedMinorityTotal) / ((double)this.nMinorityTotal + (double)this.nGeneratedMinorityTotal + (double)this.nGeneratedMajorityTotal + (double)this.nMajorityTotal) : ((double)this.nMajorityTotal + (double)this.nGeneratedMajorityTotal) / ((double)this.nMinorityTotal + (double)this.nGeneratedMinorityTotal + (double)this.nGeneratedMajorityTotal + (double)this.nMajorityTotal);
        return ratio;
    }

    private Instance onlineSMOTE() {
        Instance newInstance;
        if (this.nMinorityTotal + this.nGeneratedMinorityTotal < this.nMajorityTotal + this.nGeneratedMajorityTotal) {
            newInstance = this.generateNewInstance(this.min);
            if (newInstance != null) {
                ++this.nGeneratedMinorityTotal;
            }
        } else {
            newInstance = this.generateNewInstance(this.maj);
            if (newInstance != null) {
                ++this.nGeneratedMajorityTotal;
            }
        }
        return newInstance;
    }

    private Instance generateNewInstance(Instances minoritySamples) {
        int pos = this.classifierRandom.nextInt(minoritySamples.numInstances());
        while (this.alreadyUsed.contains(pos)) {
            pos = this.classifierRandom.nextInt(minoritySamples.numInstances());
        }
        this.alreadyUsed.add(pos);
        if (this.alreadyUsed.size() == minoritySamples.numInstances()) {
            this.alreadyUsed.clear();
        }
        Instance instanceI = minoritySamples.instance(pos);
        LinearNNSearch search = new LinearNNSearch(minoritySamples);
        try {
            Instances neighbours = ((NearestNeighbourSearch)search).kNearestNeighbours(instanceI, Math.min(this.neighbors, minoritySamples.numInstances() - 1));
            double[] values = new double[minoritySamples.numAttributes()];
            int nn = this.classifierRandom.nextInt(neighbours.numInstances());
            Enumeration attrEnum = this.samoaToWeka.wekaInstance(minoritySamples.instance(0)).enumerateAttributes();
            while (attrEnum.hasMoreElements()) {
                int iVal;
                Attribute attr = (Attribute)attrEnum.nextElement();
                if (attr.equals((Object)this.samoaToWeka.wekaInstance(minoritySamples.instance(0)).classAttribute())) continue;
                if (attr.isNumeric()) {
                    double dif = this.samoaToWeka.wekaInstance(neighbours.instance(nn)).value(attr) - this.samoaToWeka.wekaInstance(instanceI).value(attr);
                    double gap = this.classifierRandom.nextDouble();
                    values[attr.index()] = this.samoaToWeka.wekaInstance(instanceI).value(attr) + gap * dif;
                    continue;
                }
                if (attr.isDate()) {
                    double dif = this.samoaToWeka.wekaInstance(neighbours.instance(nn)).value(attr) - this.samoaToWeka.wekaInstance(instanceI).value(attr);
                    double gap = this.classifierRandom.nextDouble();
                    values[attr.index()] = (long)(this.samoaToWeka.wekaInstance(instanceI).value(attr) + gap * dif);
                    continue;
                }
                int[] valueCounts = new int[attr.numValues()];
                int n = iVal = (int)this.samoaToWeka.wekaInstance(instanceI).value(attr);
                valueCounts[n] = valueCounts[n] + 1;
                for (int nnEx = 0; nnEx < neighbours.numInstances(); ++nnEx) {
                    int val;
                    int n2 = val = (int)this.samoaToWeka.wekaInstance(neighbours.instance(nnEx)).value(attr);
                    valueCounts[n2] = valueCounts[n2] + 1;
                }
                int maxIndex = 0;
                int max = Integer.MIN_VALUE;
                for (int index = 0; index < attr.numValues(); ++index) {
                    if (valueCounts[index] <= max) continue;
                    max = valueCounts[index];
                    maxIndex = index;
                }
                values[attr.index()] = maxIndex;
            }
            values[minoritySamples.classIndex()] = instanceI.classValue();
            if (this.indexValues == null) {
                this.indexValues = new int[instanceI.numAttributes()];
                for (int i = 0; i < instanceI.numAttributes(); ++i) {
                    this.indexValues[i] = i;
                }
            }
            Instance synthetic = instanceI.copy();
            synthetic.addSparseValues(this.indexValues, values, instanceI.numAttributes());
            if (this.instanceGenerated.get(instanceI) != null) {
                this.instanceGenerated.replace(instanceI, this.instanceGenerated.get(instanceI) + 1);
            } else {
                this.instanceGenerated.put(instanceI, 1);
            }
            return synthetic;
        }
        catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    @Override
    public boolean isRandomizable() {
        if (this.learner != null) {
            return this.learner.isRandomizable();
        }
        return false;
    }

    @Override
    public void getModelDescription(StringBuilder arg0, int arg1) {
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    @Override
    public String toString() {
        return "SMOTE online stategy using " + this.learner + " and ADWIN as sliding window";
    }
}

