/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.dca;

import com.aliasi.dca.DiscreteChooser;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.test.unit.Asserts;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

public class DiscreteChooserTest {
    @Test
    public void testSim() throws IOException {
        int numSamples = 1000;
        double[] simCoeffs = new double[]{0.0, 3.0, -2.0, 1.0};
        int numDims = simCoeffs.length;
        DenseVector simCoeffVector = new DenseVector(simCoeffs);
        DiscreteChooser simChooser = new DiscreteChooser(simCoeffVector);
        Random random = new Random(42L);
        Vector[][] alternativess = new Vector[numSamples][];
        int[] choices = new int[numSamples];
        int i = 0;
        while (i < numSamples) {
            int numChoices = 1 + random.nextInt(8);
            alternativess[i] = new Vector[numChoices];
            int k = 0;
            while (k < numChoices) {
                double[] xs = new double[numDims];
                xs[0] = 1.0;
                int d = 1;
                while (d < numDims) {
                    xs[d] = 2.0 * random.nextGaussian();
                    ++d;
                }
                alternativess[i][k] = new DenseVector(xs);
                ++k;
            }
            double[] choiceProbs = simChooser.choiceProbs(alternativess[i]);
            double choiceProb = random.nextDouble();
            double cumProb = 0.0;
            int k2 = 0;
            while (k2 < numChoices) {
                if (choiceProb < (cumProb += choiceProbs[k2]) || k2 == numChoices - 1) {
                    choices[i] = k2;
                    break;
                }
                ++k2;
            }
            ++i;
        }
        double priorVariance = 5.0;
        boolean nonInformativeIntercept = true;
        RegressionPrior prior = RegressionPrior.gaussian(priorVariance, nonInformativeIntercept);
        int priorBlockSize = 100;
        double initialLearningRate = 0.1;
        double decayBase = 0.99;
        AnnealingSchedule annealingSchedule = AnnealingSchedule.exponential(initialLearningRate, decayBase);
        double minImprovement = 1.0E-5;
        int minEpochs = 5;
        int maxEpochs = 500;
        Reporter reporter = null;
        DiscreteChooser chooser = DiscreteChooser.estimate(alternativess, choices, prior, priorBlockSize, annealingSchedule, minImprovement, minEpochs, maxEpochs, reporter);
        Vector coeffVector = chooser.coefficients();
        int d = 0;
        while (d < coeffVector.numDimensions()) {
            Assert.assertEquals((double)simCoeffVector.value(d), (double)coeffVector.value(d), (double)0.1);
            ++d;
        }
        DiscreteChooser deserChooser = (DiscreteChooser)AbstractExternalizable.serializeDeserialize(chooser);
        Vector deserCoeffVector = deserChooser.coefficients();
        int d2 = 0;
        while (d2 < coeffVector.numDimensions()) {
            Assert.assertEquals((double)coeffVector.value(d2), (double)deserCoeffVector.value(d2), (double)1.0E-5);
            ++d2;
        }
    }

    @Test
    public void testChoice() throws IOException {
        this.assertChoice(new double[0], new double[]{0.2, 0.8}, new double[0][]);
        this.assertChoice(new double[0], new double[]{0.2, 0.8}, new double[][]{{-1.0, 1.0}});
        this.assertChoice(new double[0], new double[]{0.2, -1.2, 0.8}, {-1.0, 1.0, 1.0}, {2.0, 1.0, -1.0}, {-1.0, -1.0, -21.0}, {-1.0, 2.0, 1.0}, {1.0, -2.0, -1.0});
    }

    void assertChoice(double[] expectedBases, double[] coeffs, double[] ... inputs) throws IOException {
        DenseVector coeffVector = new DenseVector(coeffs);
        DiscreteChooser chooser = new DiscreteChooser(coeffVector);
        this.assertChoice(coeffVector, chooser, expectedBases, coeffs, inputs);
        DiscreteChooser serDeserChooser = (DiscreteChooser)AbstractExternalizable.serializeDeserialize(chooser);
        this.assertChoice(coeffVector, serDeserChooser, expectedBases, coeffs, inputs);
    }

    void assertChoice(Vector coeffVector, DiscreteChooser chooser, double[] expectedBases, double[] coeffs, double[][] inputs) {
        Vector[] inputVecs = new Vector[inputs.length];
        int i = 0;
        while (i < inputs.length) {
            inputVecs[i] = new DenseVector(inputs[i]);
            ++i;
        }
        if (inputVecs.length == 0) {
            try {
                chooser.choose(inputVecs);
                Assert.fail();
            }
            catch (IllegalArgumentException e) {
                Asserts.succeed();
            }
            try {
                chooser.choiceProbs(inputVecs);
                Assert.fail();
            }
            catch (IllegalArgumentException e) {
                Asserts.succeed();
            }
            try {
                chooser.choiceLogProbs(inputVecs);
                Assert.fail();
            }
            catch (IllegalArgumentException e) {
                Asserts.succeed();
            }
            return;
        }
        int choice = chooser.choose(inputVecs);
        double[] choiceProbs = chooser.choiceProbs(inputVecs);
        double[] choiceLogProbs = chooser.choiceLogProbs(inputVecs);
        double[] bases = new double[inputs.length];
        int i2 = 0;
        while (i2 < bases.length) {
            bases[i2] = inputVecs[i2].dotProduct(coeffVector);
            ++i2;
        }
        double[] expBases = new double[inputs.length];
        int i3 = 0;
        while (i3 < expBases.length) {
            expBases[i3] = java.lang.Math.exp(bases[i3]);
            ++i3;
        }
        double Z = 0.0;
        int i4 = 0;
        while (i4 < expBases.length) {
            Z += expBases[i4];
            ++i4;
        }
        double[] expProbs = new double[inputs.length];
        int i5 = 0;
        while (i5 < expProbs.length) {
            expProbs[i5] = expBases[i5] / Z;
            ++i5;
        }
        double[] expLogProbs = new double[inputs.length];
        int i6 = 0;
        while (i6 < expLogProbs.length) {
            expLogProbs[i6] = java.lang.Math.log(expProbs[i6]);
            ++i6;
        }
        int expChoice = 0;
        int i7 = 1;
        while (i7 < expBases.length) {
            if (expBases[i7] > expBases[expChoice]) {
                expChoice = i7;
            }
            ++i7;
        }
        Assert.assertEquals((int)expChoice, (int)choice);
        Asserts.assertEqualsArray(expProbs, choiceProbs, 0.001);
        Asserts.assertEqualsArray(expLogProbs, choiceLogProbs, 0.001);
        Assert.assertEquals((double)Math.sum(choiceProbs), (double)1.0, (double)0.001);
    }
}

