/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.bayesian.ODE;
import jsat.exceptions.FailedToFitException;

/**
 * Updateable classifier that combines one {@link ODE} per categorical
 * attribute (presumably "Averaged One-Dependence Estimators" — confirm against
 * project docs). For classification, the exponentiated log-probabilities
 * reported by the individual ODEs are summed per class and then normalized.
 */
public class AODE
extends BaseUpdateableClassifier {
    private static final long serialVersionUID = 8386506277969540732L;
    /** The target attribute, as handed to {@link #setUp}. */
    protected CategoricalData predicting;
    /** One ODE per categorical attribute; {@code null} until {@link #setUp} has run. */
    protected ODE[] odes;
    /**
     * Threshold applied in {@link #classify}: an ODE contributes to a class
     * only if its {@code priors} entry for that class and the observed value of
     * its dependent attribute is not below this value.
     */
    private double m = 20.0;

    /**
     * Creates an untrained instance with the default threshold of 20.
     */
    public AODE() {
    }

    /**
     * Copy constructor used by {@link #clone()}.
     * The ODE array and the target attribute are cloned only when the source
     * has ODEs; otherwise both stay {@code null}. The threshold is always copied.
     *
     * @param toClone the instance to copy
     */
    protected AODE(AODE toClone) {
        if (toClone.odes != null) {
            this.odes = new ODE[toClone.odes.length];
            for (int i = 0; i < this.odes.length; ++i) {
                this.odes[i] = toClone.odes[i].clone();
            }
            this.predicting = toClone.predicting.clone();
        }
        this.m = toClone.m;
    }

    /**
     * Returns a copy of this classifier built with the copy constructor.
     *
     * @return a new {@code AODE} with cloned ODEs and the same threshold
     */
    @Override
    public AODE clone() {
        return new AODE(this);
    }

    /**
     * Prepares the classifier for incremental updates by creating and setting
     * up one {@link ODE} for each categorical attribute index.
     *
     * @param categoricalAttributes the categorical attributes of the data
     * @param numericAttributes the number of numeric attributes, forwarded to each ODE
     * @param predicting the target attribute
     * @throws FailedToFitException if {@code categoricalAttributes} is empty
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        // NOTE(review): the check only rejects an empty array while the message
        // asks for at least 2 attributes — confirm which one is intended.
        if (categoricalAttributes.length < 1) {
            throw new FailedToFitException("At least 2 categorical varaibles are needed for AODE");
        }
        this.predicting = predicting;
        this.odes = new ODE[categoricalAttributes.length];
        for (int i = 0; i < this.odes.length; ++i) {
            this.odes[i] = new ODE(i);
            this.odes[i].setUp(categoricalAttributes, numericAttributes, predicting);
        }
    }

    /**
     * Trains on the whole data set, submitting one task per ODE to the given
     * executor. Each task feeds every data point of the set into its ODE, and
     * this method blocks until all tasks have finished.
     * <p>
     * If a task fails with a {@link RuntimeException}, that exception is
     * rethrown here once all tasks have finished. If the calling thread is
     * interrupted while waiting, training falls back to the single-argument
     * {@code trainC(dataSet)} overload and the interrupt status is restored
     * afterwards.
     *
     * @param dataSet the training data
     * @param threadPool the executor that runs the per-ODE training tasks
     */
    @Override
    public void trainC(final ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.setUp(dataSet.getCategories(), dataSet.getNumNumericalVars(), dataSet.getPredicting());
        final CountDownLatch latch = new CountDownLatch(this.odes.length);
        // Holds a failure from any worker. The write happens before countDown(),
        // so it is visible to this thread once await() returns. If several
        // workers fail, whichever write lands last is reported.
        final RuntimeException[] failure = new RuntimeException[1];
        for (int i = 0; i < this.odes.length; ++i) {
            final ODE ode = this.odes[i];
            threadPool.submit(new Runnable(){

                @Override
                public void run() {
                    try {
                        for (int j = 0; j < dataSet.getSampleSize(); ++j) {
                            ode.update(dataSet.getDataPoint(j), dataSet.getDataPointCategory(j));
                        }
                    }
                    catch (RuntimeException ex) {
                        failure[0] = ex;
                    }
                    finally {
                        // Always count down, otherwise a failing task would
                        // leave the caller blocked in await() forever.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
            if (failure[0] != null) {
                throw failure[0];
            }
        }
        catch (InterruptedException ex) {
            try {
                // Fall back to the single-argument overload, as before.
                this.trainC(dataSet);
            }
            finally {
                // Preserve the interrupt so callers further up can react to it.
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Forwards a single training example to every ODE.
     *
     * @param dataPoint the example to learn from
     * @param targetClass the class index of the example
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        for (ODE ode : this.odes) {
            ode.update(dataPoint, targetClass);
        }
    }

    /**
     * Classifies a data point by summing, for each class, the exponentiated
     * log-probability of every ODE whose {@code priors} entry for that class
     * and the observed value of its dependent attribute is at least the
     * threshold {@code m}; the sums are then normalized.
     *
     * @param data the data point to classify
     * @return the normalized per-class results
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        int[] catVals = data.getCategoricalValues();
        for (int c = 0; c < cr.size(); ++c) {
            double prob = 0.0;
            for (ODE ode : this.odes) {
                // Skip ODEs whose prior entry for this class/value is below the threshold.
                if (ode.priors[c][catVals[ode.dependent]] < this.m) continue;
                prob += Math.exp(ode.getLogPrb(catVals, c));
            }
            cr.setProb(c, prob);
        }
        // NOTE(review): if every ODE is skipped for every class, all entries are 0
        // here; verify how CategoricalResults.normalize() handles a zero sum.
        cr.normalize();
        return cr;
    }

    /**
     * Reports whether weighted data is supported.
     *
     * @return always {@code true}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Sets the threshold used by {@link #classify} to decide whether an ODE
     * contributes to a class.
     *
     * @param m the new threshold; must be finite and non-negative
     * @throws ArithmeticException if {@code m} is negative, infinite, or NaN
     */
    public void setM(double m) {
        if (m < 0.0 || Double.isInfinite(m) || Double.isNaN(m)) {
            throw new ArithmeticException("The minimum count must be a non negative number");
        }
        this.m = m;
    }

    /**
     * Returns the threshold used by {@link #classify}.
     *
     * @return the current threshold
     */
    public double getM() {
        return this.m;
    }
}

