/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.lazy;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import java.util.Arrays;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.Regressor;
import moa.classifiers.lazy.neighboursearch.KDTree;
import moa.classifiers.lazy.neighboursearch.LinearNNSearch;
import moa.classifiers.lazy.neighboursearch.NearestNeighbourSearch;
import moa.core.Measurement;

/**
 * k-nearest-neighbour learner that keeps a sliding window of the most recent
 * training instances.
 *
 * <p>For a nominal class, {@link #getVotesForInstance(Instance)} returns one vote
 * per neighbour, indexed by the neighbour's class value. For a numeric class it
 * returns a single-element array holding the mean of the neighbours' class values,
 * or their median when {@code medianOption} is set.
 */
public class kNN
extends AbstractClassifier
implements MultiClassClassifier,
Regressor {
    private static final long serialVersionUID = 1L;
    public IntOption kOption = new IntOption("k", 'k', "The number of neighbors", 10, 1, Integer.MAX_VALUE);
    public FlagOption medianOption = new FlagOption("median", 'm', "median or mean");
    public IntOption limitOption = new IntOption("limit", 'w', "The maximum number of instances to store", 1000, 1, Integer.MAX_VALUE);
    public MultiChoiceOption nearestNeighbourSearchOption = new MultiChoiceOption("nearestNeighbourSearch", 'n', "Nearest Neighbour Search to use", new String[]{"LinearNN", "KDTree"}, new String[]{"Brute force search algorithm for nearest neighbour search. ", "KDTree search algorithm for nearest neighbour search"}, 0);
    /**
     * Largest nominal class index seen during training. The vote array returned for
     * nominal classes has {@code C + 1} entries. It is never updated from numeric
     * (regression) targets, which would otherwise size that array from arbitrary values.
     */
    int C = 0;
    /** Sliding window of the most recent training instances, at most {@code limitOption} of them. */
    protected Instances window;

    @Override
    public String getPurposeString() {
        return "kNN: special.";
    }

    /**
     * Creates an empty window that uses the given header and its class index.
     *
     * @param context header describing the stream's attributes
     * @throws IllegalStateException if the window cannot be created from {@code context}
     */
    @Override
    public void setModelContext(InstancesHeader context) {
        try {
            this.window = new Instances(context, 0);
            this.window.setClassIndex(context.classIndex());
        }
        catch (Exception e) {
            // Library code must not terminate the JVM; surface the failure to the caller instead.
            throw new IllegalStateException("Error: no Model Context available.", e);
        }
    }

    /** Discards the stored window and the tracked nominal class range. */
    @Override
    public void resetLearningImpl() {
        this.window = null;
        this.C = 0;
    }

    /**
     * Adds {@code inst} to the sliding window and evicts the oldest instances so the
     * window never holds more than {@code limitOption} instances.
     *
     * @param inst the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        // Only nominal class indices determine the vote-array size; regression targets must not.
        if (!inst.classAttribute().isNumeric() && inst.classValue() > (double)this.C) {
            this.C = (int)inst.classValue();
        }
        if (this.window == null) {
            // Copy only the header: copying the whole dataset could seed the window past the limit.
            this.window = new Instances(inst.dataset(), 0);
            this.window.setClassIndex(inst.dataset().classIndex());
        }
        // A loop (not a single delete) keeps the bound even if the limit was lowered mid-stream.
        while (this.window.numInstances() > 0 && this.window.numInstances() >= this.limitOption.getValue()) {
            this.window.delete(0);
        }
        this.window.add(inst);
    }

    /**
     * Predicts from the {@code k} nearest instances in the window.
     *
     * @param inst the instance to predict for
     * @return for a nominal class, vote counts indexed by class value (length {@code C + 1});
     *     for a numeric class, a one-element array with the mean or median neighbour target.
     *     If the window is empty, an all-zero array of length {@code C + 1}. If the search
     *     fails, an all-zero array of length {@code inst.numClasses()}.
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        try {
            NearestNeighbourSearch search = this.buildSearch();
            if (this.window.numInstances() == 0) {
                return new double[this.C + 1];
            }
            int k = Math.min(this.kOption.getValue(), this.window.numInstances());
            Instances neighbours = search.kNearestNeighbours(inst, k);
            if (inst.classAttribute().isNumeric()) {
                return predictNumeric(neighbours, this.medianOption.isSet());
            }
            return this.countVotes(neighbours);
        }
        catch (Exception ignored) {
            // Deliberate best-effort: a failed search yields "no opinion" rather than an error.
            return new double[inst.numClasses()];
        }
    }

    /** Builds the nearest-neighbour search selected by {@code nearestNeighbourSearchOption} over the window. */
    private NearestNeighbourSearch buildSearch() throws Exception {
        if (this.nearestNeighbourSearchOption.getChosenIndex() == 0) {
            return new LinearNNSearch(this.window);
        }
        NearestNeighbourSearch search = new KDTree();
        search.setInstances(this.window);
        return search;
    }

    /** Counts one vote per neighbour at the index given by its (nominal) class value. */
    private double[] countVotes(Instances neighbours) {
        double[] votes = new double[this.C + 1];
        for (int i = 0; i < neighbours.numInstances(); ++i) {
            int classIndex = (int)neighbours.instance(i).classValue();
            votes[classIndex] += 1.0;
        }
        return votes;
    }

    /** Returns a one-element array with the median or mean of the neighbours' numeric class values. */
    private static double[] predictNumeric(Instances neighbours, boolean useMedian) {
        int num = neighbours.numInstances();
        double[] classValues = new double[num];
        for (int i = 0; i < num; ++i) {
            classValues[i] = neighbours.instance(i).classValue();
        }
        double prediction;
        if (useMedian) {
            Arrays.sort(classValues);
            int mid = num / 2;
            prediction = num % 2 == 1 ? classValues[mid] : (classValues[mid - 1] + classValues[mid]) / 2.0;
        } else {
            double sum = 0.0;
            for (double value : classValues) {
                sum += value;
            }
            prediction = sum / (double)num;
        }
        return new double[]{prediction};
    }

    /** No model-specific measurements are reported. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}

