/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.dca;

import com.aliasi.dca.DiscreteChooser;
import com.aliasi.dca.DiscreteObjectChooser;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.FeatureExtractor;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Simulation test for {@link DiscreteObjectChooser}.
 *
 * <p>Choice data is generated from a {@link DiscreteChooser} with known
 * coefficients, a {@link DiscreteObjectChooser} is estimated from that data,
 * and the estimated coefficients are compared with the generating ones, both
 * directly and after a serialization round trip.
 */
public class DiscreteObjectChooserTest {

    /** Absolute tolerance used when comparing estimated and simulated coefficients. */
    private static final double COEFF_TOLERANCE = 0.1;

    /**
     * Simulates 1000 choice problems, estimates an object chooser from them,
     * and checks the recovered coefficients against the simulated ones.
     *
     * @throws IOException declared because the serialization round trip may raise it
     */
    @Test
    public void testSim() throws IOException {
        int numSamples = 1000;
        double[] simCoeffs = new double[]{0.0, 3.0, -2.0, 1.0};
        int numDims = simCoeffs.length;
        DenseVector simCoeffVector = new DenseVector(simCoeffs);
        DiscreteChooser simChooser = new DiscreteChooser(simCoeffVector);

        // Fixed seed keeps the simulated data, and therefore the test, reproducible.
        Random random = new Random(42L);
        Vector[][] alternativess = new Vector[numSamples][];
        int[] choices = new int[numSamples];
        for (int i = 0; i < numSamples; ++i) {
            // Between 1 and 8 alternatives per sample.
            int numChoices = 1 + random.nextInt(8);
            alternativess[i] = new Vector[numChoices];
            for (int k = 0; k < numChoices; ++k) {
                double[] xs = new double[numDims];
                xs[0] = 1.0; // dimension 0 is held constant at 1.0
                for (int d = 1; d < numDims; ++d) {
                    xs[d] = 2.0 * random.nextGaussian();
                }
                alternativess[i][k] = new DenseVector(xs);
            }
            // Draw the chosen alternative from the simulated chooser's distribution.
            double[] choiceProbs = simChooser.choiceProbs(alternativess[i]);
            choices[i] = sampleChoice(choiceProbs, random.nextDouble());
        }

        double priorVariance = 5.0;
        boolean nonInformativeIntercept = true;
        RegressionPrior prior = RegressionPrior.gaussian(priorVariance, nonInformativeIntercept);
        int priorBlockSize = 100;
        double initialLearningRate = 0.1;
        double decayBase = 0.99;
        AnnealingSchedule annealingSchedule = AnnealingSchedule.exponential(initialLearningRate, decayBase);
        double minImprovement = 1.0E-5;
        int minEpochs = 5;
        int maxEpochs = 500;
        Reporter reporter = null;

        // Give every alternative vector a unique Integer id; the feature
        // extractor maps the id back to its vector.
        Map<Integer, Vector> vectorMap = new HashMap<Integer, Vector>();
        List<List<Integer>> alternativeObjectss = new ArrayList<List<Integer>>(alternativess.length);
        int count = 0;
        for (Vector[] alternatives : alternativess) {
            List<Integer> alternativeObjects = new ArrayList<Integer>(alternatives.length);
            alternativeObjectss.add(alternativeObjects);
            for (Vector alternative : alternatives) {
                Integer obj = count++;
                vectorMap.put(obj, alternative);
                alternativeObjects.add(obj);
            }
        }

        MapFeatureExtractor featureExtractor = new MapFeatureExtractor(vectorMap);
        int minFeatureCount = 5;
        DiscreteObjectChooser<Integer> objectChooser = DiscreteObjectChooser.estimate(
                featureExtractor, alternativeObjectss, choices, minFeatureCount, prior,
                priorBlockSize, annealingSchedule, minImprovement, minEpochs, maxEpochs, reporter);

        assertCoefficientsMatch(simCoeffVector,
                objectChooser.chooser().coefficients(),
                objectChooser.featureSymbolTable());

        // The estimated chooser must survive a serialization round trip.
        DiscreteObjectChooser<?> deserChooser =
                (DiscreteObjectChooser<?>) AbstractExternalizable.serializeDeserialize(objectChooser);
        assertCoefficientsMatch(simCoeffVector,
                deserChooser.chooser().coefficients(),
                deserChooser.featureSymbolTable());
    }

    /**
     * Picks an index by walking the cumulative sum of {@code probs} until it
     * exceeds {@code u}.
     *
     * @param probs probabilities of the alternatives; expected to be non-empty
     * @param u     threshold compared against the running cumulative sum
     * @return the first index whose cumulative probability exceeds {@code u},
     *         or the last index if none does
     */
    private static int sampleChoice(double[] probs, double u) {
        double cumProb = 0.0;
        int last = probs.length - 1;
        for (int k = 0; k < last; ++k) {
            cumProb += probs[k];
            if (u < cumProb) {
                return k;
            }
        }
        // Falling through to the last index guards against the cumulative sum
        // ending slightly below u.
        return last;
    }

    /**
     * Asserts that every dimension of {@code estimated} is within
     * {@link #COEFF_TOLERANCE} of the corresponding entry of {@code expected}.
     *
     * <p>Features are named by the decimal string of their original dimension,
     * so {@code symbolTable} is used to translate dimension {@code d} into the
     * index it occupies in {@code estimated}.
     *
     * @param expected    coefficients used to simulate the data
     * @param estimated   coefficients of the fitted chooser
     * @param symbolTable feature symbol table of the fitted chooser
     */
    private static void assertCoefficientsMatch(Vector expected, Vector estimated, SymbolTable symbolTable) {
        for (int d = 0; d < estimated.numDimensions(); ++d) {
            int featureId = symbolTable.symbolToID(Integer.toString(d));
            Assert.assertEquals(expected.value(d), estimated.value(featureId), COEFF_TOLERANCE);
        }
    }

    /**
     * Feature extractor that looks up a vector by Integer id and exposes each
     * dimension as a feature named by the dimension index.
     */
    static class MapFeatureExtractor
    implements FeatureExtractor<Integer>,
    Serializable {
        private static final long serialVersionUID = 1L;

        /** Backing map from object id to its feature vector. */
        final Map<Integer, Vector> mMap;

        /**
         * @param map mapping from object id to feature vector; stored by reference
         */
        MapFeatureExtractor(Map<Integer, Vector> map) {
            this.mMap = map;
        }

        /**
         * Returns the features of the vector registered under {@code i}.
         *
         * @param i object id to look up
         * @return map from dimension index (as a decimal string) to that dimension's value
         * @throws IllegalArgumentException if no vector is registered for {@code i}
         */
        @Override
        public Map<String, Double> features(Integer i) {
            Vector v = this.mMap.get(i);
            if (v == null) {
                throw new IllegalArgumentException("No vector registered for object id " + i);
            }
            Map<String, Double> result = new HashMap<String, Double>(5);
            for (int d = 0; d < v.numDimensions(); ++d) {
                result.put(Integer.toString(d), v.value(d));
            }
            return result;
        }
    }
}

