/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.lazy;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceImpl;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeSet;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.clusterers.kmeanspm.CoresetKMeans;
import moa.core.Measurement;

/**
 * Self Adjusting Memory kNN (SAMkNN) classifier.
 *
 * <p>The model keeps two instance memories:
 * <ul>
 *   <li>a short-term memory (STM) holding the most recent training instances, oldest first.
 *       After every training step its size is re-chosen from a halving sequence of candidate
 *       window sizes so as to minimise the interleaved test-then-train error measured on the
 *       STM itself;</li>
 *   <li>a long-term memory (LTM) that receives instances evicted from the STM after they have
 *       been cleaned against it (evicted instances that contradict the STM are dropped). When
 *       the total memory limit is exceeded, the LTM is compressed by per-class k-means.</li>
 * </ul>
 * A third, combined memory (CM) votes over the union of STM and LTM. Each prediction is made by
 * the memory with the most correct predictions in the tracked history (ties favour STM, then
 * LTM).
 */
public class SAMkNN extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {

    private static final long serialVersionUID = 1L;

    /** Lower bound applied to neighbour distances when weighting votes (avoids division by zero). */
    private static final double MIN_VOTE_DISTANCE = 1.0E-9;

    public IntOption kOption = new IntOption("k", 'k', "The number of neighbors", 5, 1, Integer.MAX_VALUE);
    public IntOption limitOption = new IntOption("limit", 'w', "The maximum number of instances to store", 5000, 1, Integer.MAX_VALUE);
    public IntOption minSTMSizeOption = new IntOption("minSTMSize", 'm', "The minimum number of instances in the STM", 50, 1, Integer.MAX_VALUE);
    public FloatOption relativeLTMSizeOption = new FloatOption("relativeLTMSize", 'p', "The allowed LTM size relative to the total limit.", 0.4, 0.0, 1.0);
    public FlagOption recalculateSTMErrorOption = new FlagOption("recalculateError", 'r', "Recalculates the error rate of the STM for size adaption (Costly operation). Otherwise, an approximation is used.");

    /** Largest class index seen during training; vote vectors have length maxClassValue + 1. */
    private int maxClassValue = 0;
    /** Short-term memory: the most recent training instances, oldest first. */
    private Instances stm;
    /** Long-term memory: cleaned instances evicted from the STM, plus k-means centroids. */
    private Instances ltm;
    private int maxLTMSize;
    private int maxSTMSize;
    /** Per-prediction 1/0 flags: whether the STM prediction matched the label. */
    private List<Integer> stmHistory;
    /** Per-prediction 1/0 flags: whether the LTM prediction matched the label. */
    private List<Integer> ltmHistory;
    /** Per-prediction 1/0 flags: whether the combined-memory prediction matched the label. */
    private List<Integer> cmHistory;
    /** Row i holds distances from STM instance i to STM instances 0..i (lower triangle is used). */
    private double[][] distanceMatrixSTM;
    /** Cached test-then-train correctness histories, keyed by the STM start index of a window. */
    private Map<Integer, List<Integer>> predictionHistories;
    private Random random;

    @Override
    public String getPurposeString() {
        return "SAMkNN: Self Adjusting Memory kNN classifier with an adaptive short-term memory and a compressed long-term memory.";
    }

    /**
     * Derives the STM/LTM capacities from the options and allocates all per-model state.
     * The STM distance matrix holds (limit + 1)^2 doubles, which is about 200 MB for the
     * default limit of 5000.
     */
    protected void init() {
        int limit = this.limitOption.getValue();
        this.maxLTMSize = (int) (this.relativeLTMSizeOption.getValue() * (double) limit);
        this.maxSTMSize = limit - this.maxLTMSize;
        this.stmHistory = new ArrayList<Integer>();
        this.ltmHistory = new ArrayList<Integer>();
        this.cmHistory = new ArrayList<Integer>();
        this.distanceMatrixSTM = new double[limit + 1][limit + 1];
        this.predictionHistories = new HashMap<Integer, List<Integer>>();
        this.random = new Random();
    }

    /**
     * Creates empty STM and LTM containers with the given header and (re)initialises all model
     * state.
     *
     * @param context header describing the stream's attributes and class index
     * @throws IllegalArgumentException if the memories cannot be created from {@code context}
     */
    @Override
    public void setModelContext(InstancesHeader context) {
        try {
            this.stm = new Instances(context, 0);
            this.stm.setClassIndex(context.classIndex());
            this.ltm = new Instances(context, 0);
            this.ltm.setClassIndex(context.classIndex());
            this.init();
        } catch (Exception e) {
            // Report the failure to the caller instead of terminating the whole JVM.
            throw new IllegalArgumentException("Error: no Model Context available.", e);
        }
    }

    /**
     * Discards all learned state. The memories stay {@code null} until
     * {@link #setModelContext(InstancesHeader)} is called again.
     */
    @Override
    public void resetLearningImpl() {
        this.stm = null;
        this.ltm = null;
        this.stmHistory = null;
        this.ltmHistory = null;
        this.cmHistory = null;
        this.distanceMatrixSTM = null;
        this.predictionHistories = null;
        this.maxClassValue = 0;
    }

    /**
     * Adds {@code inst} to the STM, enforces the memory limit, cleans the LTM against the new
     * instance and then shrinks the STM to the window size with the lowest estimated error.
     * Instances cut from the STM are cleaned against the remaining STM and the survivors move
     * to the LTM.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (inst.classValue() > (double) this.maxClassValue) {
            this.maxClassValue = (int) inst.classValue();
        }
        this.stm.add(inst);
        this.memorySizeCheck();
        // Drop LTM instances that contradict the newest STM instance.
        this.clean(this.stm, this.ltm, true);

        // Store the newest instance's distances as the last row of the STM distance matrix.
        double[] distancesSTM = this.get1ToNDistances(inst, this.stm);
        System.arraycopy(distancesSTM, 0, this.distanceMatrixSTM[this.stm.numInstances() - 1], 0, distancesSTM.length);

        int oldWindowSize = this.stm.numInstances();
        int newWindowSize = this.getNewSTMSize(this.recalculateSTMErrorOption.isSet());
        if (newWindowSize >= oldWindowSize) {
            return;
        }
        int diff = oldWindowSize - newWindowSize;
        Instances discardedSTMInstances = new Instances(this.stm, 0);
        for (int i = 0; i < diff; ++i) {
            discardedSTMInstances.add(this.stm.get(0).copy());
            this.stm.delete(0);
        }
        this.shiftDistanceMatrix(diff);
        this.removeOldestHistoryEntries(diff);
        // Only evicted instances consistent with the remaining STM are transferred to the LTM.
        this.clean(this.stm, discardedSTMInstances, false);
        for (int i = 0; i < discardedSTMInstances.numInstances(); ++i) {
            this.ltm.add(discardedSTMInstances.get(i).copy());
        }
        this.memorySizeCheck();
    }

    /**
     * Returns distance-weighted kNN votes from whichever memory (STM, LTM or combined) has the
     * most correct predictions in its history.
     *
     * <p>Side effect: appends one correctness flag per memory to the histories, comparing each
     * memory's prediction with {@code inst.classValue()}. This assumes the instance is labelled
     * (interleaved test-then-train); NOTE(review): confirm behaviour for unlabelled streams.
     * Any exception (e.g. the model context has not been set) yields an all-zero vote vector.
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        double[] votes;
        int predClassSTM = 0;
        int predClassLTM = 0;
        int predClassCM = 0;
        try {
            if (this.stm.numInstances() > 0) {
                int k = this.kOption.getValue();

                double[] distancesSTM = this.get1ToNDistances(inst, this.stm);
                int[] nnIndicesSTM = this.nArgMin(Math.min(distancesSTM.length, k), distancesSTM);
                double[] vSTM = this.getDistanceWeightedVotes(distancesSTM, nnIndicesSTM, this.stm);
                predClassSTM = this.getClassFromVotes(vSTM);

                double[] distancesLTM = this.get1ToNDistances(inst, this.ltm);
                double[] vCM = this.getCMVotes(distancesSTM, this.stm, distancesLTM, this.ltm);
                predClassCM = this.getClassFromVotes(vCM);

                double[] vLTM;
                if (this.ltm.numInstances() > 0) {
                    int[] nnIndicesLTM = this.nArgMin(Math.min(distancesLTM.length, k), distancesLTM);
                    vLTM = this.getDistanceWeightedVotes(distancesLTM, nnIndicesLTM, this.ltm);
                    predClassLTM = this.getClassFromVotes(vLTM);
                } else {
                    // Empty LTM: it abstains with an all-zero vote vector.
                    vLTM = new double[inst.numClasses()];
                }

                int correctSTM = this.historySum(this.stmHistory);
                int correctLTM = this.historySum(this.ltmHistory);
                int correctCM = this.historySum(this.cmHistory);
                if (correctSTM >= correctLTM && correctSTM >= correctCM) {
                    votes = vSTM;
                } else if (correctLTM > correctSTM && correctLTM >= correctCM) {
                    votes = vLTM;
                } else {
                    votes = vCM;
                }
            } else {
                votes = new double[inst.numClasses()];
            }
            this.stmHistory.add((double) predClassSTM == inst.classValue() ? 1 : 0);
            this.ltmHistory.add((double) predClassLTM == inst.classValue() ? 1 : 0);
            this.cmHistory.add((double) predClassCM == inst.classValue() ? 1 : 0);
        } catch (Exception e) {
            return new double[inst.numClasses()];
        }
        return votes;
    }

    /** This classifier reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /**
     * Reports the model as not randomizable. Note that k-means++ seeding during LTM compression
     * uses an unseeded {@link Random}, so runs are not bit-for-bit reproducible.
     */
    @Override
    public boolean isRandomizable() {
        return false;
    }

    /** Returns the number of correct (1) entries in a 1/0 correctness history. */
    private int historySum(List<Integer> history) {
        int sum = 0;
        for (int flag : history) {
            sum += flag;
        }
        return sum;
    }

    /** Runs k-means++-seeded k-means on {@code points} and returns the {@code k} centroids. */
    private List<double[]> kMeans(List<double[]> points, int k) {
        List<double[]> centroids = CoresetKMeans.generatekMeansPlusPlusCentroids(k, points, this.random);
        CoresetKMeans.kMeans(centroids, points);
        return centroids;
    }

    /**
     * Compresses the LTM. For each class, all of its LTM instances are replaced by
     * max(n / 2, 1) k-means centroids labelled with that class.
     */
    private void clusterDown() {
        int classIndex = this.ltm.classIndex();
        for (int c = 0; c <= this.maxClassValue; ++c) {
            List<double[]> classSamples = new ArrayList<double[]>();
            // Walk backwards so deletions do not shift indices that are still to be visited.
            for (int i = this.ltm.numInstances() - 1; i >= 0; --i) {
                if (this.ltm.get(i).classValue() == (double) c) {
                    classSamples.add(this.ltm.get(i).toDoubleArray());
                    this.ltm.delete(i);
                }
            }
            if (classSamples.isEmpty()) {
                continue;
            }
            // Overwrite the label slot with attribute 0 and put 1.0 in slot 0.
            // NOTE(review): slot 0 presumably serves as the sample weight for CoresetKMeans; confirm.
            for (double[] sample : classSamples) {
                if (classIndex != 0) {
                    sample[classIndex] = sample[0];
                }
                sample[0] = 1.0;
            }
            List<double[]> centroids = this.kMeans(classSamples, Math.max(classSamples.size() / 2, 1));
            for (double[] centroid : centroids) {
                // Assumes centroids omit the weight slot, so centroid[j] maps to attribute j + 1.
                double[] attributes = new double[this.ltm.numAttributes()];
                System.arraycopy(centroid, 0, attributes, 1, this.ltm.numAttributes() - 1);
                // Undo the swap: restore attribute 0 and write the class label back.
                if (classIndex != 0) {
                    attributes[0] = attributes[classIndex];
                }
                attributes[classIndex] = c;
                InstanceImpl inst = new InstanceImpl(1.0, attributes);
                inst.setDataset(this.ltm);
                this.ltm.add(inst);
            }
        }
    }

    /**
     * Enforces the total memory limit. If the LTM is over its share, it is compressed.
     * Otherwise the oldest STM instances are moved into the LTM until the LTM is one over its
     * share, and then the LTM is compressed.
     */
    private void memorySizeCheck() {
        if (this.stm.numInstances() + this.ltm.numInstances() <= this.maxSTMSize + this.maxLTMSize) {
            return;
        }
        if (this.ltm.numInstances() > this.maxLTMSize) {
            this.clusterDown();
            return;
        }
        int numShifts = this.maxLTMSize - this.ltm.numInstances() + 1;
        for (int i = 0; i < numShifts; ++i) {
            this.ltm.add(this.stm.get(0).copy());
            this.stm.delete(0);
        }
        // Histories may be shorter than the STM (they only grow in getVotesForInstance).
        this.removeOldestHistoryEntries(numShifts);
        this.clusterDown();
        this.predictionHistories.clear();
        this.shiftDistanceMatrix(numShifts);
    }

    /**
     * Removes up to {@code count} oldest entries from each correctness history. The bound
     * matters because the histories can be shorter than the number of evicted STM instances.
     */
    private void removeOldestHistoryEntries(int count) {
        dropOldest(this.stmHistory, count);
        dropOldest(this.ltmHistory, count);
        dropOldest(this.cmHistory, count);
    }

    /** Removes the first min(count, size) elements of {@code history}. */
    private static void dropOldest(List<Integer> history, int count) {
        history.subList(0, Math.min(count, history.size())).clear();
    }

    /**
     * Realigns the STM distance matrix after the {@code offset} oldest STM instances were
     * removed. Entry (offset + i, offset + j) moves to (i, j) for the current STM size.
     */
    private void shiftDistanceMatrix(int offset) {
        int size = this.stm.numInstances();
        // Ascending rows: the source row offset + i has not been overwritten yet.
        for (int i = 0; i < size; ++i) {
            System.arraycopy(this.distanceMatrixSTM[offset + i], offset, this.distanceMatrixSTM[i], 0, size);
        }
    }

    /**
     * Deletes from {@code toClean} those of the reference instance's k nearest neighbours that
     * have a different label and lie within the radius of the reference's farthest same-label
     * neighbour among its k nearest neighbours in {@code cleanAgainst}.
     */
    private void cleanSingle(Instances cleanAgainst, int cleanAgainstindex, Instances toClean) {
        int k = this.kOption.getValue();
        Instances cleanAgainstTmp = new Instances(cleanAgainst);
        cleanAgainstTmp.delete(cleanAgainstindex);
        Instance reference = cleanAgainst.get(cleanAgainstindex);
        double referenceClass = reference.classValue();

        double[] distancesAgainst = this.get1ToNDistances(reference, cleanAgainstTmp);
        int[] nnIndicesAgainst = this.nArgMin(Math.min(k, distancesAgainst.length), distancesAgainst);
        double[] distancesToClean = this.get1ToNDistances(reference, toClean);
        int[] nnIndicesToClean = this.nArgMin(Math.min(k, distancesToClean.length), distancesToClean);

        double distThreshold = 0.0;
        for (int nnIdx : nnIndicesAgainst) {
            if (cleanAgainstTmp.get(nnIdx).classValue() == referenceClass && distancesAgainst[nnIdx] > distThreshold) {
                distThreshold = distancesAgainst[nnIdx];
            }
        }

        List<Integer> delIndices = new ArrayList<Integer>();
        for (int nnIdx : nnIndicesToClean) {
            if (toClean.get(nnIdx).classValue() != referenceClass && distancesToClean[nnIdx] <= distThreshold) {
                delIndices.add(nnIdx);
            }
        }
        // Delete highest index first so the remaining indices stay valid.
        Collections.sort(delIndices, Collections.reverseOrder());
        for (int idx : delIndices) {
            toClean.delete(idx);
        }
    }

    /**
     * Cleans {@code toClean} against either the last instance of {@code cleanAgainst}
     * ({@code onlyLast}) or against every instance of it. This does nothing unless
     * {@code cleanAgainst} has more than k instances and {@code toClean} is non-empty.
     */
    private void clean(Instances cleanAgainst, Instances toClean, boolean onlyLast) {
        if (cleanAgainst.numInstances() <= this.kOption.getValue() || toClean.numInstances() <= 0) {
            return;
        }
        if (onlyLast) {
            this.cleanSingle(cleanAgainst, cleanAgainst.numInstances() - 1, toClean);
        } else {
            for (int i = 0; i < cleanAgainst.numInstances(); ++i) {
                this.cleanSingle(cleanAgainst, i, toClean);
            }
        }
    }

    /** Sums inverse-distance votes per class over the given neighbour indices of {@code instances}. */
    private double[] getDistanceWeightedVotes(double[] distances, int[] nnIndices, Instances instances) {
        double[] v = new double[this.maxClassValue + 1];
        for (int nnIdx : nnIndices) {
            int label = (int) instances.instance(nnIdx).classValue();
            v[label] += 1.0 / Math.max(distances[nnIdx], MIN_VOTE_DISTANCE);
        }
        return v;
    }

    /**
     * Same as {@link #getDistanceWeightedVotes} but over the concatenation of STM and LTM.
     * Indices below stm.numInstances() refer to the STM; the rest are offset into the LTM.
     */
    private double[] getDistanceWeightedVotesCM(double[] distances, int[] nnIndices, Instances stm, Instances ltm) {
        double[] v = new double[this.maxClassValue + 1];
        int stmSize = stm.numInstances();
        for (int nnIdx : nnIndices) {
            Instance neighbour = nnIdx < stmSize ? stm.instance(nnIdx) : ltm.instance(nnIdx - stmSize);
            int label = (int) neighbour.classValue();
            v[label] += 1.0 / Math.max(distances[nnIdx], MIN_VOTE_DISTANCE);
        }
        return v;
    }

    /** Computes kNN votes over STM and LTM together (the combined memory). */
    private double[] getCMVotes(double[] distancesSTM, Instances stm, double[] distancesLTM, Instances ltm) {
        double[] distancesCM = new double[distancesSTM.length + distancesLTM.length];
        System.arraycopy(distancesSTM, 0, distancesCM, 0, distancesSTM.length);
        System.arraycopy(distancesLTM, 0, distancesCM, distancesSTM.length, distancesLTM.length);
        int[] nnIndicesCM = this.nArgMin(Math.min(distancesCM.length, this.kOption.getValue()), distancesCM);
        return this.getDistanceWeightedVotesCM(distancesCM, nnIndicesCM, stm, ltm);
    }

    /** Returns the index of the largest vote (first one on ties), or -1 for an empty array. */
    private int getClassFromVotes(double[] votes) {
        double maxVote = -1.0;
        int maxVoteClass = -1;
        for (int i = 0; i < votes.length; ++i) {
            if (votes[i] > maxVote) {
                maxVote = votes[i];
                maxVoteClass = i;
            }
        }
        return maxVoteClass;
    }

    /**
     * Predicts a label from a precomputed distance row. Only neighbours with indices in
     * [startIdx, endIdx] of {@code instances} are considered.
     */
    private int getLabelFct(double[] distances, Instances instances, int startIdx, int endIdx) {
        int[] nnIndices = this.nArgMin(Math.min(this.kOption.getValue(), distances.length), distances, startIdx, endIdx);
        double[] votes = this.getDistanceWeightedVotes(distances, nnIndices, instances);
        return this.getClassFromVotes(votes);
    }

    /** Euclidean distance over the input (non-class) attributes. */
    private double getDistance(Instance sample, Instance sample2) {
        double sum = 0.0;
        for (int i = 0; i < sample.numInputAttributes(); ++i) {
            double diff = sample.valueInputAttribute(i) - sample2.valueInputAttribute(i);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /** Returns the distances from {@code sample} to every instance of {@code samples}, in order. */
    private double[] get1ToNDistances(Instance sample, Instances samples) {
        double[] distances = new double[samples.numInstances()];
        for (int i = 0; i < samples.numInstances(); ++i) {
            distances[i] = this.getDistance(sample, samples.get(i));
        }
        return distances;
    }

    /**
     * Returns the indices of the {@code n} smallest values in values[startIdx..endIdx]
     * (inclusive), smallest first. If fewer than n values are below Double.MAX_VALUE, the
     * remaining slots stay 0.
     */
    private int[] nArgMin(int n, double[] values, int startIdx, int endIdx) {
        int[] indices = new int[n];
        for (int i = 0; i < n; ++i) {
            double minValue = Double.MAX_VALUE;
            for (int j = startIdx; j <= endIdx; ++j) {
                if (values[j] < minValue && !containsIndex(indices, i, j)) {
                    indices[i] = j;
                    minValue = values[j];
                }
            }
        }
        return indices;
    }

    /** Returns whether {@code candidate} occurs among the first {@code count} entries of {@code indices}. */
    private static boolean containsIndex(int[] indices, int count, int candidate) {
        for (int k = 0; k < count; ++k) {
            if (indices[k] == candidate) {
                return true;
            }
        }
        return false;
    }

    /** Returns the indices of the {@code n} smallest values of the whole array. */
    private int[] nArgMin(int n, double[] values) {
        return this.nArgMin(n, values, 0, values.length - 1);
    }

    /**
     * Re-keys the cached prediction histories after the STM shrank. The process is repeated
     * {@code numberOfDeletions} times: drop the smallest key, then shift every remaining key
     * down by the new smallest key.
     */
    private void adaptHistories(int numberOfDeletions) {
        for (int i = 0; i < numberOfDeletions; ++i) {
            TreeSet<Integer> keys = new TreeSet<Integer>(this.predictionHistories.keySet());
            this.predictionHistories.remove(keys.pollFirst());
            if (keys.isEmpty()) {
                break;
            }
            int offset = keys.first();
            // Ascending order: each new key (key - offset) is below every key not yet moved.
            for (Integer key : keys) {
                List<Integer> predHistory = this.predictionHistories.remove(key);
                this.predictionHistories.put(key - offset, predHistory);
            }
        }
    }

    /**
     * Extends {@code predictionHistory} in place with 1/0 flags for STM instances from
     * startIdx + k + predictionHistory.size() onwards. Each such instance is predicted from
     * the window instances preceding it (interleaved test-then-train).
     *
     * @return the same list, extended
     */
    private List<Integer> getIncrementalTestTrainPredHistory(Instances instances, int startIdx, List<Integer> predictionHistory) {
        for (int i = startIdx + this.kOption.getValue() + predictionHistory.size(); i < instances.numInstances(); ++i) {
            int predicted = this.getLabelFct(this.distanceMatrixSTM[i], instances, startIdx, i - 1);
            predictionHistory.add((double) predicted == instances.get(i).classValue() ? 1 : 0);
        }
        return predictionHistory;
    }

    /** Computes a fresh test-then-train correctness history for the window starting at {@code startIdx}. */
    private List<Integer> getTestTrainPredHistory(Instances instances, int startIdx) {
        return this.getIncrementalTestTrainPredHistory(instances, startIdx, new ArrayList<Integer>());
    }

    /**
     * Returns the candidate STM window sizes: numSamples, then repeated halvings while the
     * last size is at least 2 * minSTMSize. The product is computed in long to avoid overflow.
     */
    private List<Integer> candidateWindowSizes(int numSamples) {
        long twiceMin = 2L * this.minSTMSizeOption.getValue();
        List<Integer> sizes = new ArrayList<Integer>();
        int size = numSamples;
        sizes.add(size);
        while (size >= twiceMin) {
            size /= 2;
            sizes.add(size);
        }
        return sizes;
    }

    /**
     * Chooses the STM size with the lowest test-then-train error among the candidate sizes.
     * Cached histories are reused where possible, and cache entries for windows that are no
     * longer candidates are evicted.
     */
    private int getMinErrorRateWindowSize() {
        int numSamples = this.stm.numInstances();
        if (numSamples < 2L * this.minSTMSizeOption.getValue()) {
            return numSamples;
        }
        List<Integer> numSamplesRange = this.candidateWindowSizes(numSamples);
        Iterator<Integer> it = this.predictionHistories.keySet().iterator();
        while (it.hasNext()) {
            if (!numSamplesRange.contains(numSamples - it.next())) {
                it.remove();
            }
        }
        List<Double> errorRates = new ArrayList<Double>();
        for (int candidate : numSamplesRange) {
            int idx = numSamples - candidate;
            List<Integer> predHistory = this.predictionHistories.containsKey(idx)
                    ? this.getIncrementalTestTrainPredHistory(this.stm, idx, this.predictionHistories.get(idx))
                    : this.getTestTrainPredHistory(this.stm, idx);
            this.predictionHistories.put(idx, predHistory);
            errorRates.add(this.getHistoryErrorRate(predHistory));
        }
        int minErrorRateIdx = errorRates.indexOf(Collections.min(errorRates));
        int windowSize = numSamplesRange.get(minErrorRateIdx);
        if (windowSize < numSamples) {
            this.adaptHistories(minErrorRateIdx);
        }
        return windowSize;
    }

    /**
     * Returns the error rate (1 - fraction correct) of a 1/0 history. An empty history gives
     * NaN, which {@link Double#compareTo} orders above every other value.
     */
    private double getHistoryErrorRate(List<Integer> predHistory) {
        double sumCorrect = this.historySum(predHistory);
        return 1.0 - sumCorrect / (double) predHistory.size();
    }

    /**
     * Approximate variant of {@link #getMinErrorRateWindowSize()}. A history cached for the
     * window that started one instance earlier is reused, minus its oldest entry. Exact
     * histories are only recomputed for candidates that appear to beat the full window.
     */
    private int getMinErrorRateWindowSizeIncremental() {
        int numSamples = this.stm.numInstances();
        if (numSamples < 2L * this.minSTMSizeOption.getValue()) {
            return numSamples;
        }
        List<Integer> numSamplesRange = this.candidateWindowSizes(numSamples);
        List<Double> errorRates = new ArrayList<Double>();
        for (int candidate : numSamplesRange) {
            int idx = numSamples - candidate;
            List<Integer> predHistory;
            if (this.predictionHistories.containsKey(idx)) {
                predHistory = this.getIncrementalTestTrainPredHistory(this.stm, idx, this.predictionHistories.get(idx));
            } else if (this.predictionHistories.containsKey(idx - 1)) {
                predHistory = this.predictionHistories.remove(idx - 1);
                // The history can be empty (e.g. minSTMSize <= k); nothing to drop then.
                if (!predHistory.isEmpty()) {
                    predHistory.remove(0);
                }
                predHistory = this.getIncrementalTestTrainPredHistory(this.stm, idx, predHistory);
                this.predictionHistories.put(idx, predHistory);
            } else {
                predHistory = this.getTestTrainPredHistory(this.stm, idx);
                this.predictionHistories.put(idx, predHistory);
            }
            errorRates.add(this.getHistoryErrorRate(predHistory));
        }
        int minErrorRateIdx = errorRates.indexOf(Collections.min(errorRates));
        if (minErrorRateIdx > 0) {
            // Recompute exact histories for every candidate that seems to beat the full window.
            double fullWindowError = errorRates.get(0);
            for (int i = 1; i < errorRates.size(); ++i) {
                if (errorRates.get(i) < fullWindowError) {
                    int idx = numSamples - numSamplesRange.get(i);
                    List<Integer> predHistory = this.getTestTrainPredHistory(this.stm, idx);
                    errorRates.set(i, this.getHistoryErrorRate(predHistory));
                    this.predictionHistories.put(idx, predHistory);
                }
            }
            minErrorRateIdx = errorRates.indexOf(Collections.min(errorRates));
        }
        int windowSize = numSamplesRange.get(minErrorRateIdx);
        if (windowSize < numSamples) {
            this.adaptHistories(minErrorRateIdx);
        }
        return windowSize;
    }

    /** Picks the new STM size, either exactly ({@code recalculateErrors}) or incrementally. */
    private int getNewSTMSize(boolean recalculateErrors) {
        return recalculateErrors ? this.getMinErrorRateWindowSize() : this.getMinErrorRateWindowSizeIncremental();
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (this.getClass() == SAMkNN.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
}

