/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.KnnClassifier;
import com.aliasi.classify.RankedClassification;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.matrix.EuclideanDistance;
import com.aliasi.matrix.Vector;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.TokenFeatureExtractor;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Distance;
import com.aliasi.util.Proximity;
import java.io.IOException;
import java.io.Serializable;
import junit.framework.Assert;
import org.junit.Test;

public class KnnClassifierTest {
    static final TokenFeatureExtractor FEATURE_EXTRACTOR = new TokenFeatureExtractor(IndoEuropeanTokenizerFactory.INSTANCE);
    static final Distance<Vector> DISTANCE = EuclideanDistance.DISTANCE;

    static void handle(KnnClassifier classifier, String input, Classification c) {
        classifier.handle(new Classified<String>(input, c));
    }

    @Test
    public void testOne() throws ClassNotFoundException, IOException {
        String[] train = new String[]{"a a b", "a b b"};
        String[] cats = new String[]{"A", "B"};
        KnnClassifier<CharSequence> classifier = new KnnClassifier<CharSequence>(FEATURE_EXTRACTOR, 1);
        int i = 0;
        while (i < train.length) {
            KnnClassifierTest.handle(classifier, train[i], new Classification(cats[i]));
            ++i;
        }
        Classification classification = classifier.classify((Object)"a a a a b b");
        Assert.assertEquals((String)"A", (String)classification.bestCategory());
        Assert.assertEquals((String)"A", (String)((RankedClassification)classification).category(0));
        Assert.assertEquals((String)"B", (String)((RankedClassification)classification).category(1));
        Assert.assertEquals((Object)1.0, (Object)((ScoredClassification)classification).score(0));
        Assert.assertEquals((Object)0.0, (Object)((ScoredClassification)classification).score(1));
        KnnClassifier classifier2 = (KnnClassifier)AbstractExternalizable.serializeDeserialize(classifier);
        classifier2.classify("a a a a b b");
        Assert.assertEquals((String)"A", (String)classification.bestCategory());
        Assert.assertEquals((String)"A", (String)((RankedClassification)classification).category(0));
        Assert.assertEquals((String)"B", (String)((RankedClassification)classification).category(1));
        Assert.assertEquals((Object)1.0, (Object)((ScoredClassification)classification).score(0));
        Assert.assertEquals((Object)0.0, (Object)((ScoredClassification)classification).score(1));
    }

    @Test
    public void testTwo() throws ClassNotFoundException, IOException {
        String[] train = new String[]{"a a b", "a b b", "a a a", "a a a a a b", "a a b b"};
        String[] cats = new String[]{"A", "B", "A", "A", "B"};
        KnnClassifier<CharSequence> classifier = new KnnClassifier<CharSequence>(FEATURE_EXTRACTOR, 3);
        int i = 0;
        while (i < train.length) {
            KnnClassifierTest.handle(classifier, train[i], new Classification(cats[i]));
            ++i;
        }
        Classification classification = classifier.classify((Object)"a a b");
        Assert.assertEquals((String)"A", (String)classification.bestCategory());
        Assert.assertEquals((String)"A", (String)((RankedClassification)classification).category(0));
        Assert.assertEquals((String)"B", (String)((RankedClassification)classification).category(1));
        Assert.assertEquals((Object)2.0, (Object)((ScoredClassification)classification).score(0));
        Assert.assertEquals((Object)1.0, (Object)((ScoredClassification)classification).score(1));
        KnnClassifier classifier2 = (KnnClassifier)AbstractExternalizable.serializeDeserialize(classifier);
        classification = classifier2.classify("a a b");
        Assert.assertEquals((String)"A", (String)classification.bestCategory());
        Assert.assertEquals((String)"A", (String)((RankedClassification)classification).category(0));
        Assert.assertEquals((String)"B", (String)((RankedClassification)classification).category(1));
        Assert.assertEquals((Object)2.0, (Object)((ScoredClassification)classification).score(0));
        Assert.assertEquals((Object)1.0, (Object)((ScoredClassification)classification).score(1));
    }

    @Test
    public void testThree() {
        String[] train = new String[]{"a a b", "a b b", "a a a", "b b b"};
        String[] cats = new String[]{"A", "B", "A", "B"};
        KnnClassifier<CharSequence> classifier = new KnnClassifier<CharSequence>(FEATURE_EXTRACTOR, Integer.MAX_VALUE, new TestProximity(), true);
        int i = 0;
        while (i < train.length) {
            KnnClassifierTest.handle(classifier, train[i], new Classification(cats[i]));
            ++i;
        }
        double prox01 = 1.0 / (1.0 + Math.sqrt(KnnClassifierTest.sqrDiff(2.0, 1.0) + KnnClassifierTest.sqrDiff(1.0, 2.0)));
        double prox02 = 1.0 / (1.0 + Math.sqrt(KnnClassifierTest.sqrDiff(2.0, 3.0) + KnnClassifierTest.sqrDiff(1.0, 0.0)));
        double prox03 = 1.0 / (1.0 + Math.sqrt(KnnClassifierTest.sqrDiff(2.0, 0.0) + KnnClassifierTest.sqrDiff(1.0, 3.0)));
        double prox12 = 1.0 / (1.0 + Math.sqrt(KnnClassifierTest.sqrDiff(1.0, 3.0) + KnnClassifierTest.sqrDiff(1.0, 0.0)));
        double prox13 = 1.0 / (1.0 + Math.sqrt(KnnClassifierTest.sqrDiff(1.0, 0.0) + KnnClassifierTest.sqrDiff(2.0, 3.0)));
        double prox23 = 1.0 / (1.0 + Math.sqrt(KnnClassifierTest.sqrDiff(3.0, 0.0) + KnnClassifierTest.sqrDiff(0.0, 3.0)));
        ScoredClassification[] classifications = new ScoredClassification[train.length];
        int i2 = 0;
        while (i2 < train.length) {
            classifications[i2] = classifier.classify((Object)train[i2]);
            ++i2;
        }
        i2 = 0;
        while (i2 < train.length) {
            Assert.assertEquals((String)cats[i2], (String)classifications[i2].bestCategory());
            ++i2;
        }
    }

    static double sqrDiff(double x1, double x2) {
        double diff = x1 - x2;
        return diff * diff;
    }

    static class TestProximity
    implements Proximity<Vector>,
    Serializable {
        TestProximity() {
        }

        @Override
        public double proximity(Vector v1, Vector v2) {
            return 1.0 / (1.0 + DISTANCE.distance(v1, v2));
        }
    }
}

