/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.ContinuousDistribution;
import jsat.distributions.DistributionSearch;
import jsat.distributions.Normal;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;

/**
 * A Naive Bayes classifier.
 * <p>
 * Each numerical feature is modeled per class by a {@link ContinuousDistribution}.
 * The distribution is fitted by the configured {@link NumericalHandeling} strategy.
 * Each categorical feature is modeled per class by a frequency table with
 * add-one smoothing: every count starts at 1.0 before the training samples are
 * added, and the table is then normalized to sum to one.
 * <p>
 * When {@link #isSparceInput() sparse input} is enabled, zero-valued numerical
 * features are excluded when fitting and ignored when classifying.
 */
public class NaiveBayes
implements Classifier,
Parameterized {
    private static final long serialVersionUID = -2437775653277531182L;
    /**
     * Log-density substituted for a numerical feature whose class distribution is
     * missing (fitting failed) or whose log-pdf is infinite.
     */
    private static final double LOG_FLOOR = Math.log(1.0E-16);
    /** apriori[class][categoricalVar][value] = smoothed P(value | class). */
    private double[][][] apriori;
    /** distributions[class][numericVar]; an entry is null if fitting threw ArithmeticException. */
    private ContinuousDistribution[][] distributions;
    private NumericalHandeling numericalHandling;
    /** Class prior probabilities, as returned by ClassificationDataSet.getPriors(). */
    private double[] priors;
    private boolean sparceInput = true;
    public static final NumericalHandeling defaultHandling = NumericalHandeling.NORMAL;

    /**
     * Creates a new Naive Bayes classifier.
     *
     * @param numericalHandling strategy used to fit distributions to numerical features
     */
    public NaiveBayes(NumericalHandeling numericalHandling) {
        this.numericalHandling = numericalHandling;
    }

    /** Creates a new Naive Bayes classifier that uses {@link #defaultHandling}. */
    public NaiveBayes() {
        this(defaultHandling);
    }

    /**
     * Sets the strategy used to fit numerical features. It takes effect on the next training call.
     *
     * @param numericalHandling the fitting strategy
     */
    public void setNumericalHandling(NumericalHandeling numericalHandling) {
        this.numericalHandling = numericalHandling;
    }

    /** @return the strategy used to fit numerical features */
    public NumericalHandeling getNumericalHandling() {
        return this.numericalHandling;
    }

    /** @return {@code true} if zero-valued numerical features are ignored */
    public boolean isSparceInput() {
        return this.sparceInput;
    }

    /**
     * Controls whether zero-valued numerical features are treated as absent.
     * If set, they are excluded both when fitting and when classifying.
     *
     * @param sparceInput {@code true} to ignore zero-valued numerical features
     */
    public void setSparceInput(boolean sparceInput) {
        this.sparceInput = sparceInput;
    }

    /**
     * Computes the posterior class probabilities for a data point.
     * <p>
     * Per-class log-likelihoods are summed over features, combined with the log prior,
     * and normalized with log-sum-exp. If every class has a log-probability of
     * negative infinity, a uniform distribution is returned instead.
     *
     * @param data the point to classify
     * @return the class probability estimates
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        int nClasses = this.distributions.length;
        CategoricalResults results = new CategoricalResults(nClasses);
        double[] logProbs = new double[nClasses];
        Vec numVals = data.getNumericalValues();
        double maxLogProb = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < nClasses; ++i) {
            double logProb = 0.0;
            if (this.sparceInput) {
                // Only non-zero numerical values contribute in sparse mode.
                Iterator<IndexValue> iter = numVals.getNonZeroIterator();
                while (iter.hasNext()) {
                    IndexValue indexValue = iter.next();
                    logProb += logPdfOrFloor(this.distributions[i][indexValue.getIndex()], indexValue.getValue());
                }
            } else {
                for (int j = 0; j < this.distributions[i].length; ++j) {
                    logProb += logPdfOrFloor(this.distributions[i][j], numVals.get(j));
                }
            }
            for (int j = 0; j < this.apriori[i].length; ++j) {
                logProb += Math.log(this.apriori[i][j][data.getCategoricalValue(j)]);
            }
            logProb += Math.log(this.priors[i]);
            logProbs[i] = logProb;
            maxLogProb = Math.max(maxLogProb, logProb);
        }
        if (maxLogProb == Double.NEGATIVE_INFINITY) {
            // Every class is impossible under the model, so fall back to uniform.
            for (int i = 0; i < results.size(); ++i) {
                results.setProb(i, 1.0 / (double)results.size());
            }
            return results;
        }
        double denom = MathTricks.logSumExp(logProbs, maxLogProb);
        for (int i = 0; i < results.size(); ++i) {
            results.setProb(i, Math.exp(logProbs[i] - denom));
        }
        results.normalize();
        return results;
    }

    /**
     * Returns the log-density of {@code x} under {@code dist}. Returns {@link #LOG_FLOOR}
     * if the distribution is missing or the log-density is infinite.
     */
    private static double logPdfOrFloor(ContinuousDistribution dist, double x) {
        if (dist == null) {
            return LOG_FLOOR;
        }
        double logPdf = dist.logPdf(x);
        return Double.isInfinite(logPdf) ? LOG_FLOOR : logPdf;
    }

    /**
     * Trains the model using a {@link FakeExecutor}. Presumably this runs every task on
     * the calling thread; confirm against FakeExecutor.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * @return an unmodifiable list of the parameters discovered by
     *         {@code Parameter.getParamsFromMethods(this)}
     */
    @Override
    public List<Parameter> getParameters() {
        return Collections.unmodifiableList(Parameter.getParamsFromMethods(this));
    }

    /**
     * @param paramName the parameter name to look up
     * @return the matching parameter, or {@code null} if the map has no such name
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Creates a deep copy of this classifier, including its configuration and any
     * trained state. Distributions that failed to fit (null entries) stay null in the copy.
     *
     * @return an independent copy
     */
    @Override
    public Classifier clone() {
        NaiveBayes copy = new NaiveBayes(this.numericalHandling);
        copy.sparceInput = this.sparceInput;
        if (this.apriori != null) {
            copy.apriori = new double[this.apriori.length][][];
            for (int i = 0; i < this.apriori.length; ++i) {
                copy.apriori[i] = new double[this.apriori[i].length][];
                // Bound by the number of categorical variables, not by the size of one table.
                for (int j = 0; j < this.apriori[i].length; ++j) {
                    if (this.apriori[i][j] != null) {
                        copy.apriori[i][j] = Arrays.copyOf(this.apriori[i][j], this.apriori[i][j].length);
                    }
                }
            }
        }
        if (this.distributions != null) {
            copy.distributions = new ContinuousDistribution[this.distributions.length][];
            for (int i = 0; i < this.distributions.length; ++i) {
                copy.distributions[i] = new ContinuousDistribution[this.distributions[i].length];
                for (int j = 0; j < this.distributions[i].length; ++j) {
                    ContinuousDistribution dist = this.distributions[i][j];
                    // Null marks a distribution whose fit failed during training.
                    copy.distributions[i][j] = dist == null ? null : dist.clone();
                }
            }
        }
        if (this.priors != null) {
            copy.priors = Arrays.copyOf(this.priors, this.priors.length);
        }
        return copy;
    }

    /** @return {@code false}; this implementation does not use data point weights */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Gets the values of numerical variable {@code j} for the samples of {@code category}.
     * Zeros are removed when sparse input is enabled.
     */
    private Vec getSampleVariableVector(ClassificationDataSet dataSet, int category, int j) {
        Vec vals = dataSet.getSampleVariableVector(category, j);
        if (this.sparceInput) {
            DoubleList nonZeroVals = new DoubleList();
            for (int i = 0; i < vals.length(); ++i) {
                double val = vals.get(i);
                if (val != 0.0) {
                    nonZeroVals.add(Double.valueOf(val));
                }
            }
            vals = new DenseVector(nonZeroVals);
        }
        return vals;
    }

    /**
     * Trains the model. One task is submitted to {@code threadPool} per
     * (class, numerical variable) and per (class, categorical variable). The method
     * blocks until every task has finished.
     * <p>
     * If the calling thread is interrupted while waiting, its interrupt status is
     * restored and the method returns. The model may then be only partially trained.
     *
     * @param dataSet    the training data
     * @param threadPool executor that runs the fitting and counting tasks
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        int nCat = dataSet.getPredicting().getNumOfCategories();
        int nNumeric = dataSet.getNumNumericalVars();
        int nCategorical = dataSet.getNumCategoricalVars();
        this.apriori = new double[nCat][nCategorical][];
        this.distributions = new ContinuousDistribution[nCat][nNumeric];
        this.priors = dataSet.getPriors();
        CountDownLatch latch = new CountDownLatch(nCat * (nNumeric + nCategorical));
        for (int i = 0; i < nCat; ++i) {
            for (int j = 0; j < nNumeric; ++j) {
                threadPool.submit(new DistributionSelectRunable(i, j, this.getSampleVariableVector(dataSet, i, j), latch));
            }
            List<DataPoint> dataSamples = dataSet.getSamples(i);
            for (int j = 0; j < nCategorical; ++j) {
                this.apriori[i][j] = new double[dataSet.getCategories()[j].getNumOfCategories()];
                // Add-one smoothing: every value starts with a count of 1.
                Arrays.fill(this.apriori[i][j], 1.0);
                threadPool.submit(new AprioriCounterRunable(i, j, dataSamples, latch));
            }
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            // Restore the interrupt status so callers can see it.
            Thread.currentThread().interrupt();
            ex.printStackTrace();
        }
    }

    /** Counts categorical values for one (class, variable) pair and normalizes the table. */
    private class AprioriCounterRunable
    implements Runnable {
        int i;
        int j;
        List<DataPoint> dataSamples;
        CountDownLatch latch;

        public AprioriCounterRunable(int i, int j, List<DataPoint> dataSamples, CountDownLatch latch) {
            this.i = i;
            this.j = j;
            this.dataSamples = dataSamples;
            this.latch = latch;
        }

        @Override
        public void run() {
            try {
                double[] counts = NaiveBayes.this.apriori[this.i][this.j];
                for (DataPoint point : this.dataSamples) {
                    counts[point.getCategoricalValue(this.j)] += 1.0;
                }
                double sum = 0.0;
                for (double count : counts) {
                    sum += count;
                }
                for (int z = 0; z < counts.length; ++z) {
                    counts[z] /= sum;
                }
            } finally {
                // Always release the latch so trainC cannot hang if this task fails.
                this.latch.countDown();
            }
        }
    }

    /**
     * Fits a distribution to one numerical variable for one class. Stores null if
     * fitting throws ArithmeticException.
     */
    private class DistributionSelectRunable
    implements Runnable {
        int i;
        int j;
        Vec v;
        CountDownLatch countDown;

        public DistributionSelectRunable(int i, int j, Vec v, CountDownLatch countDown) {
            this.i = i;
            this.j = j;
            this.v = v;
            this.countDown = countDown;
        }

        @Override
        public void run() {
            try {
                NaiveBayes.this.distributions[this.i][this.j] = NaiveBayes.this.numericalHandling.fit(this.v);
            }
            catch (ArithmeticException e) {
                NaiveBayes.this.distributions[this.i][this.j] = null;
            }
            finally {
                // Always release the latch so trainC cannot hang if this task fails.
                this.countDown.countDown();
            }
        }
    }

    /** Strategies for fitting a distribution to a numerical feature's values. */
    public static enum NumericalHandeling {
        /** Calls DistributionSearch.getBestDistribution with only a Normal candidate. */
        NORMAL{

            @Override
            protected ContinuousDistribution fit(Vec v) {
                return DistributionSearch.getBestDistribution(v, new Normal(0.0, 1.0));
            }
        }
        ,
        /** Calls DistributionSearch.getBestDistribution with its default candidates. */
        BEST_FIT{

            @Override
            protected ContinuousDistribution fit(Vec v) {
                return DistributionSearch.getBestDistribution(v);
            }
        }
        ,
        /**
         * Calls DistributionSearch.getBestDistribution with a cut-off of 0.9. Presumably it
         * falls back to a kernel density estimate when no parametric fit passes the
         * cut-off; verify against DistributionSearch.
         */
        BEST_FIT_KDE{
            private double cutOff = 0.9;

            @Override
            protected ContinuousDistribution fit(Vec v) {
                return DistributionSearch.getBestDistribution(v, this.cutOff);
            }
        };


        /**
         * Fits a distribution to the given values.
         *
         * @param v the observed values
         * @return the fitted distribution
         */
        protected abstract ContinuousDistribution fit(Vec v);
    }
}

