/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.classify.ScoredClassifier;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.Features;
import com.aliasi.matrix.EuclideanDistance;
import com.aliasi.matrix.Vector;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Compilable;
import com.aliasi.util.Distance;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Proximity;
import com.aliasi.util.ScoredObject;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A k-nearest-neighbor classifier over vectors produced by a feature extractor.
 *
 * <p>Training instances are supplied through {@link #handle(Classified)}. Each one is
 * converted into a {@link Vector}, adding previously unseen feature names to the feature
 * symbol table. The vector is stored alongside the symbol id of the instance's best category.
 *
 * <p>{@link #classify(Object)} compares the input's vector against every stored training
 * vector with the configured {@link Proximity} and keeps at most {@code k} neighbors in a
 * bounded priority queue. For each category it then sums either one vote per retained
 * neighbor or, when weighting by proximity is enabled, the neighbor's proximity score.
 *
 * <p>Both Java serialization and {@link #compileTo(ObjectOutput)} write a {@link Serializer}
 * proxy.
 *
 * @param <E> the type of objects being classified
 */
public class KnnClassifier<E>
implements ScoredClassifier<E>,
ObjectHandler<Classified<E>>,
Compilable,
Serializable {
    static final long serialVersionUID = 5692985587478284405L;

    /**
     * Dimensionality argument passed to the {@code Features} vector-conversion helpers.
     * It equals the original literal {@code 0x7FFFFFFE}. Presumably it was chosen so the
     * feature space is effectively unbounded.
     */
    private static final int NUM_DIMENSIONS = Integer.MAX_VALUE - 1;

    /** Converts training and input objects into feature maps. */
    final FeatureExtractor<? super E> mFeatureExtractor;
    /** Maximum number of neighbors consulted per classification. */
    final int mK;
    /** Similarity between vectors; higher means nearer. */
    final Proximity<Vector> mProximity;
    /** If true, neighbors vote with their proximity score; otherwise each votes 1.0. */
    final boolean mWeightByProximity;
    /** Category symbol id of each training instance, parallel to {@link #mTrainingVectors}. */
    final List<Integer> mTrainingCategories;
    /** Feature vector of each training instance, parallel to {@link #mTrainingCategories}. */
    final List<Vector> mTrainingVectors;
    /** Maps feature names to vector dimensions. */
    final MapSymbolTable mFeatureSymbolTable;
    /** Maps category names to the ids stored in {@link #mTrainingCategories}. */
    final MapSymbolTable mCategorySymbolTable;

    /**
     * Constructs a classifier that uses {@code EuclideanDistance.DISTANCE} and
     * one-vote-per-neighbor scoring.
     *
     * @param featureExtractor extractor converting objects to feature maps
     * @param k number of neighbors to consult; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k) {
        this(featureExtractor, k, EuclideanDistance.DISTANCE);
    }

    /**
     * Constructs a classifier that uses the given distance and one-vote-per-neighbor scoring.
     *
     * <p>The distance {@code d} is converted to the proximity {@code 1 / (1 + d)}.
     *
     * @param featureExtractor extractor converting objects to feature maps
     * @param k number of neighbors to consult; must be at least 1
     * @param distance distance between feature vectors
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Distance<Vector> distance) {
        this(featureExtractor, k, new ProximityWrapper(distance), false);
    }

    /**
     * Constructs a classifier with an explicit proximity and weighting mode.
     *
     * @param extractor extractor converting objects to feature maps
     * @param k number of neighbors to consult; must be at least 1
     * @param proximity similarity between feature vectors (higher is nearer)
     * @param weightByProximity if true, each neighbor adds its proximity score to its
     *     category; otherwise it adds 1.0
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KnnClassifier(FeatureExtractor<? super E> extractor, int k, Proximity<Vector> proximity, boolean weightByProximity) {
        this(extractor, checkK(k), proximity, weightByProximity, new ArrayList<Integer>(), new ArrayList<Vector>(), new MapSymbolTable(), new MapSymbolTable());
    }

    /**
     * Field-wise constructor, used directly by deserialization. The supplied collections are
     * stored as-is, without copying.
     */
    KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Proximity<Vector> proximity, boolean weightByProximity, List<Integer> trainingCategories, List<Vector> trainingVectors, MapSymbolTable featureSymbolTable, MapSymbolTable categorySymbolTable) {
        this.mFeatureExtractor = featureExtractor;
        this.mK = k;
        this.mProximity = proximity;
        this.mWeightByProximity = weightByProximity;
        this.mTrainingCategories = trainingCategories;
        this.mTrainingVectors = trainingVectors;
        this.mFeatureSymbolTable = featureSymbolTable;
        this.mCategorySymbolTable = categorySymbolTable;
    }

    /**
     * Validates the neighbor count, failing fast instead of deferring the error to the first
     * {@link #classify(Object)} call.
     *
     * @param k the requested neighbor count
     * @return {@code k} unchanged
     * @throws IllegalArgumentException if {@code k < 1}
     */
    private static int checkK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Require k >= 1. Found k=" + k);
        }
        return k;
    }

    /**
     * Returns the feature extractor used by this classifier.
     *
     * @return the feature extractor
     */
    public FeatureExtractor<? super E> featureExtractor() {
        return this.mFeatureExtractor;
    }

    /**
     * Returns the proximity used to compare vectors.
     *
     * @return the proximity
     */
    public Proximity<Vector> proximity() {
        return this.mProximity;
    }

    /**
     * Returns a fresh list holding the category label of every training instance handled so
     * far, in the order the instances were handled. There is one entry per training instance,
     * so labels may repeat.
     *
     * <p>NOTE(review): the name suggests a list of distinct categories. Confirm whether callers
     * rely on the per-instance form before changing it.
     *
     * @return per-training-instance category labels
     */
    public List<String> categories() {
        List<String> catList = new ArrayList<String>();
        for (Integer catId : this.mTrainingCategories) {
            catList.add(this.mCategorySymbolTable.idToSymbol(catId));
        }
        return catList;
    }

    /**
     * Returns whether neighbors vote with their proximity score rather than with 1.0.
     *
     * @return true if votes are weighted by proximity
     */
    public boolean weightByProximity() {
        return this.mWeightByProximity;
    }

    /**
     * Returns the maximum number of neighbors consulted per classification.
     *
     * @return the neighbor count {@code k}
     */
    public int k() {
        return this.mK;
    }

    /**
     * Stores one training instance under the best category of {@code classification}.
     * Previously unseen features and categories are added to the symbol tables.
     */
    void handle(E trainingInstance, Classification classification) {
        String category = classification.bestCategory();
        Map<String, Number> featureMap = this.mFeatureExtractor.features(trainingInstance);
        Vector vector = Features.toVectorAddSymbols(featureMap, this.mFeatureSymbolTable, NUM_DIMENSIONS, false);
        this.mTrainingCategories.add(this.mCategorySymbolTable.getOrAddSymbolInteger(category));
        this.mTrainingVectors.add(vector);
    }

    /**
     * Adds a classified object as a training instance.
     *
     * @param classifiedObject the training object and its classification
     */
    @Override
    public void handle(Classified<E> classifiedObject) {
        this.handle(classifiedObject.getObject(), classifiedObject.getClassification());
    }

    /**
     * Classifies the input by voting among its nearest training instances.
     *
     * <p>Every training vector is scored against the input with the configured proximity. The
     * results are offered to a {@code BoundedPriorityQueue} of capacity {@code k}, which
     * presumably retains the highest-proximity neighbors. Each retained neighbor adds either
     * its proximity (when weighting by proximity) or 1.0 to its category's score. Every known
     * category appears in the result; categories with no retained neighbor score 0.0.
     *
     * @param in the object to classify
     * @return a scored classification covering every category seen in training
     */
    @Override
    public ScoredClassification classify(E in) {
        Map<String, Number> featureMap = this.mFeatureExtractor.features(in);
        // Unlike handle(), this uses toVector rather than toVectorAddSymbols, presumably so
        // that classifying does not grow the feature symbol table.
        Vector inputVector = Features.toVector(featureMap, this.mFeatureSymbolTable, NUM_DIMENSIONS, false);
        BoundedPriorityQueue<ScoredObject<Integer>> neighbors = new BoundedPriorityQueue<ScoredObject<Integer>>(ScoredObject.comparator(), this.mK);
        for (int i = 0; i < this.mTrainingCategories.size(); ++i) {
            Integer catId = this.mTrainingCategories.get(i);
            Vector trainingVector = this.mTrainingVectors.get(i);
            double proximity = this.mProximity.proximity(inputVector, trainingVector);
            neighbors.offer(new ScoredObject<Integer>(catId, proximity));
        }
        int numCats = this.mCategorySymbolTable.numSymbols();
        double[] scores = new double[numCats];
        for (ScoredObject<Integer> neighbor : neighbors) {
            int catId = neighbor.getObject();
            scores[catId] += this.mWeightByProximity ? neighbor.score() : 1.0;
        }
        List<ScoredObject<String>> categoryScores = new ArrayList<ScoredObject<String>>(numCats);
        for (int catId = 0; catId < numCats; ++catId) {
            categoryScores.add(new ScoredObject<String>(this.mCategorySymbolTable.idToSymbol(catId), scores[catId]));
        }
        return ScoredClassification.create(categoryScores);
    }

    /** Serialization hook: the classifier is always written through a {@link Serializer} proxy. */
    Object writeReplace() {
        return new Serializer<E>(this);
    }

    /**
     * Writes this classifier to {@code out} through the same {@link Serializer} proxy that Java
     * serialization uses.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(this.writeReplace());
    }

    /**
     * Adapts a {@link Distance} to a {@link Proximity}. A distance {@code d} becomes
     * {@code 1 / (1 + d)}, and any negative distance maps to {@link Double#MAX_VALUE}.
     */
    static class ProximityWrapper
    implements Proximity<Vector>,
    Serializable {
        static final long serialVersionUID = -1410999733708772109L;
        /** The wrapped distance; must be non-null before {@link #proximity} is called. */
        Distance<Vector> mDistance;

        public ProximityWrapper() {
        }

        public ProximityWrapper(Distance<Vector> distance) {
            this.mDistance = distance;
        }

        @Override
        public double proximity(Vector v1, Vector v2) {
            double d = this.mDistance.distance(v1, v2);
            return d < 0.0 ? Double.MAX_VALUE : 1.0 / (1.0 + d);
        }
    }

    /**
     * Externalizable serialization proxy for {@link KnnClassifier}.
     *
     * <p>The stream layout is:
     * <ol>
     * <li>feature extractor</li>
     * <li>{@code k}</li>
     * <li>proximity</li>
     * <li>weighting flag</li>
     * <li>instance count {@code n}</li>
     * <li>{@code n} category ids</li>
     * <li>{@code n} vectors</li>
     * <li>feature symbol table</li>
     * <li>category symbol table</li>
     * </ol>
     * Objects are written with {@code AbstractExternalizable.serializeOrCompile}.
     *
     * @param <F> the classified object type
     */
    static class Serializer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = 4951969636521202268L;
        final KnnClassifier<F> mClassifier;

        /** No-argument constructor required for externalization; used only when reading. */
        public Serializer() {
            this(null);
        }

        public Serializer(KnnClassifier<F> classifier) {
            this.mClassifier = classifier;
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            AbstractExternalizable.serializeOrCompile(this.mClassifier.mFeatureExtractor, out);
            out.writeInt(this.mClassifier.mK);
            AbstractExternalizable.serializeOrCompile(this.mClassifier.mProximity, out);
            out.writeBoolean(this.mClassifier.mWeightByProximity);
            int numInstances = this.mClassifier.mTrainingCategories.size();
            out.writeInt(numInstances);
            for (int i = 0; i < numInstances; ++i) {
                out.writeInt(this.mClassifier.mTrainingCategories.get(i));
            }
            for (int i = 0; i < numInstances; ++i) {
                AbstractExternalizable.serializeOrCompile(this.mClassifier.mTrainingVectors.get(i), out);
            }
            AbstractExternalizable.serializeOrCompile(this.mClassifier.mFeatureSymbolTable, out);
            AbstractExternalizable.serializeOrCompile(this.mClassifier.mCategorySymbolTable, out);
        }

        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            // Generic types are erased in the stream; these casts trust the layout
            // produced by writeExternal.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super F> featureExtractor = (FeatureExtractor<? super F>) in.readObject();
            int k = in.readInt();
            @SuppressWarnings("unchecked")
            Proximity<Vector> proximity = (Proximity<Vector>) in.readObject();
            boolean weightByProximity = in.readBoolean();
            int numInstances = in.readInt();
            List<Integer> categoryList = new ArrayList<Integer>(numInstances);
            for (int i = 0; i < numInstances; ++i) {
                categoryList.add(in.readInt());
            }
            List<Vector> vectorList = new ArrayList<Vector>(numInstances);
            for (int i = 0; i < numInstances; ++i) {
                vectorList.add((Vector) in.readObject());
            }
            MapSymbolTable featureSymbolTable = (MapSymbolTable) in.readObject();
            MapSymbolTable categorySymbolTable = (MapSymbolTable) in.readObject();
            return new KnnClassifier<F>(featureExtractor, k, proximity, weightByProximity, categoryList, vectorList, featureSymbolTable, categorySymbolTable);
        }
    }

    /**
     * Simple category/vector pair. It is not referenced anywhere else in this class.
     * NOTE(review): possibly unused; check other callers in the package before removing.
     */
    static class TrainingInstance {
        final String mCategory;
        final Vector mVector;

        TrainingInstance(String category, Vector vector) {
            this.mCategory = category;
            this.mVector = vector;
        }
    }
}

