/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.bayesian.MultivariateNormals;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/**
 * Classifier that fits one {@link MultivariateDistribution} per class and
 * scores a data point by evaluating each class's density at the point's
 * numerical values, optionally multiplied by the class prior taken from the
 * training set.
 * <p>
 * Every per-class distribution is a clone of the base distribution handed to
 * the constructor. Classes that have no training samples are stored as
 * {@code null} entries and are skipped during classification.
 */
public class BestClassDistribution
implements Classifier,
Parameterized {
    private static final long serialVersionUID = -1746145372146154228L;
    /** Template distribution; cloned once per non-empty class when training. */
    private MultivariateDistribution baseDist;
    /** Fitted distribution per class index; {@code null} for a class without samples. */
    private List<MultivariateDistribution> dists;
    /** Class priors obtained from the training data set, indexed by class. */
    private double[] priors;
    /** Whether {@link #classify(DataPoint)} multiplies each density by its prior. */
    private boolean usePriors;
    /** Default value of {@code usePriors} used by the single-argument constructor. */
    public static final boolean USE_PRIORS = true;

    /**
     * Creates a classifier that uses priors (see {@link #USE_PRIORS}).
     *
     * @param baseDist the distribution to clone and fit for each class
     */
    public BestClassDistribution(MultivariateDistribution baseDist) {
        this(baseDist, USE_PRIORS);
    }

    /**
     * Creates a classifier.
     *
     * @param baseDist the distribution to clone and fit for each class
     * @param usePriors {@code true} to weight each class density by its prior
     */
    public BestClassDistribution(MultivariateDistribution baseDist, boolean usePriors) {
        this.baseDist = baseDist;
        this.usePriors = usePriors;
    }

    /**
     * Copy constructor. Copies the priors array, clones the base distribution
     * and every fitted distribution, and carries over the {@code usePriors}
     * setting.
     *
     * @param toCopy the classifier to copy
     */
    public BestClassDistribution(BestClassDistribution toCopy) {
        // Must be copied explicitly; otherwise the copy silently falls back to
        // the field default (false) and classifies differently from the source.
        this.usePriors = toCopy.usePriors;
        if (toCopy.priors != null) {
            this.priors = Arrays.copyOf(toCopy.priors, toCopy.priors.length);
        }
        this.baseDist = toCopy.baseDist.clone();
        if (toCopy.dists != null) {
            this.dists = new ArrayList<MultivariateDistribution>(toCopy.dists.size());
            for (MultivariateDistribution md : toCopy.dists) {
                this.dists.add(md == null ? null : md.clone());
            }
        }
    }

    /**
     * Sets whether class priors are multiplied into the per-class density.
     *
     * @param usePriors {@code true} to use priors
     */
    public void setUsePriors(boolean usePriors) {
        this.usePriors = usePriors;
    }

    /**
     * @return {@code true} if class priors are multiplied into the per-class density
     */
    public boolean isUsePriors() {
        return this.usePriors;
    }

    /**
     * Scores the data point against every fitted class distribution.
     * <p>
     * For each class with a non-null distribution the density at the point's
     * numerical values is computed, multiplied by the class prior when
     * {@code usePriors} is set, and stored in the result. Classes without a
     * distribution are left untouched. The result is passed through
     * {@link CategoricalResults#normalize()} before being returned.
     *
     * @param data the data point to classify
     * @return the per-class results
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.dists.size());
        for (int i = 0; i < this.dists.size(); ++i) {
            MultivariateDistribution dist = this.dists.get(i);
            if (dist == null) {
                continue;
            }
            double p = 0.0;
            try {
                p = dist.pdf(data.getNumericalValues());
            }
            catch (ArithmeticException ignored) {
                // Deliberate best-effort: if the density cannot be evaluated for
                // this class, it keeps a score of zero instead of failing the
                // whole classification.
            }
            if (this.usePriors) {
                p *= this.priors[i];
            }
            cr.setProb(i, p);
        }
        cr.normalize();
        return cr;
    }

    /**
     * Fits one distribution per class, submitting each fit as a task to the
     * given executor and waiting for all of them to finish.
     * <p>
     * The fitted model is published only after every task has completed, so a
     * failed fit does not leave a partially populated distribution list behind.
     *
     * @param dataSet the training data
     * @param threadPool the executor the per-class fits are submitted to
     * @throws FailedToFitException if submitting or running any task fails, or
     *         if the calling thread is interrupted while waiting (the interrupt
     *         status is restored before throwing)
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        try {
            double[] newPriors = dataSet.getPriors();
            ArrayList<Future<MultivariateDistribution>> pending = new ArrayList<Future<MultivariateDistribution>>();
            final MultivariateDistribution sourceDist = this.baseDist;
            for (int i = 0; i < dataSet.getPredicting().getNumOfCategories(); ++i) {
                final List<DataPoint> list = dataSet.getSamples(i);
                Future<MultivariateDistribution> tmp = threadPool.submit(new Callable<MultivariateDistribution>(){

                    @Override
                    public MultivariateDistribution call() throws Exception {
                        // A class without samples has nothing to fit; mark it with null.
                        if (list.isEmpty()) {
                            return null;
                        }
                        MultivariateDistribution dist = sourceDist.clone();
                        dist.setUsingDataList(list);
                        return dist;
                    }
                });
                pending.add(tmp);
            }
            List<MultivariateDistribution> fitted = new ArrayList<MultivariateDistribution>(pending.size());
            for (Future<MultivariateDistribution> future : pending) {
                fitted.add(future.get());
            }
            // Publish only once everything succeeded.
            this.priors = newPriors;
            this.dists = fitted;
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag so callers up the stack can observe it.
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        catch (Exception ex) {
            Logger.getLogger(BestClassDistribution.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
    }

    /**
     * Fits one distribution per class on the calling thread.
     * <p>
     * The fitted model is published only after every class has been processed.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        double[] newPriors = dataSet.getPriors();
        int classCount = dataSet.getClassSize();
        List<MultivariateDistribution> fitted = new ArrayList<MultivariateDistribution>(classCount);
        for (int i = 0; i < classCount; ++i) {
            List<DataPoint> samp = dataSet.getSamples(i);
            if (samp.isEmpty()) {
                // Nothing to fit for this class; keep the index aligned with a null entry.
                fitted.add(null);
                continue;
            }
            MultivariateDistribution dist = this.baseDist.clone();
            dist.setUsingDataList(samp);
            fitted.add(dist);
        }
        this.priors = newPriors;
        this.dists = fitted;
    }

    /**
     * @return {@code false}; this classifier reports no support for weighted data
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * @return a copy of this classifier created through the copy constructor
     */
    @Override
    public Classifier clone() {
        return new BestClassDistribution(this);
    }

    /**
     * @return the parameters discovered by {@link Parameter#getParamsFromMethods(Object)}
     */
    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param paramName the name of the parameter
     * @return the matching parameter, or whatever the parameter map returns for an unknown name
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

