/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.DataPoint;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.parameters.DoubleParameter;
import jsat.parameters.IntParameter;
import jsat.parameters.KernelFunctionParameter;
import jsat.parameters.MetricParameter;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/**
 * A multivariate kernel density estimator that supports an arbitrary
 * {@link DistanceMetric}.
 * <p>
 * Training vectors are stored in a {@link VectorCollection} created by a
 * {@link VectorCollectionFactory}. To evaluate the density at a point, only the
 * training vectors within {@code bandwidth * kernel.cutOff()} of it are
 * retrieved, and the density is computed as
 * <pre>
 *     pdf(x) = sum_i K(d(x, x_i) / h) / (n * h^D)
 * </pre>
 * where {@code K} is the kernel function, {@code h} the bandwidth, {@code n}
 * the number of training vectors and {@code D} their dimension.
 * <p>
 * The bandwidth is either supplied explicitly or estimated from the training
 * data as the mean plus a multiple of the standard deviation of nearest
 * neighbor distances.
 */
public class MetricKDE
extends MultivariateKDE
implements Parameterized {
    private static final long serialVersionUID = -2084039950938740815L;
    /** Kernel applied to the bandwidth-scaled distances. */
    private KernelFunction kf;
    /** Current bandwidth; stays 0.0 until one is set or estimated. */
    private double bandwidth;
    private DistanceMetric distanceMetric;
    /** Factory used to build {@link #vecCollection} when training. */
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcf;
    /** Training vectors paired with their index; {@code null} until trained. */
    private VectorCollection<VecPaired<Vec, Integer>> vecCollection;
    /** Neighbor count used when the bandwidth is estimated without an explicit k. */
    private int defaultK;
    /** Standard-deviation multiplier used when estimating the bandwidth. */
    private double defaultStndDev;
    private static final VectorCollectionFactory<VecPaired<Vec, Integer>> defaultVCF = new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>();
    public static final int DEFAULT_K = 3;
    public static final double DEFAULT_STND_DEV = 2.0;
    public static final KernelFunction DEFAULT_KF = EpanechnikovKF.getInstance();
    /*
     * Tunable parameters: kernel, metric, default k and default standard
     * deviation multiplier. The structure of these anonymous classes is kept
     * unchanged on purpose: their compiler-generated binary names are part of
     * the serialized form of this class.
     */
    private final List<Parameter> parameters = Collections.unmodifiableList(new ArrayList<Parameter>(){
        private static final long serialVersionUID = -2830924861210733734L;
        {
            this.add(new KernelFunctionParameter(){
                private static final long serialVersionUID = 560041843101841185L;

                @Override
                public KernelFunction getObject() {
                    return MetricKDE.this.getKernelFunction();
                }

                @Override
                public boolean setObject(KernelFunction obj) {
                    MetricKDE.this.setKernelFunction(obj);
                    return true;
                }
            });
            this.add(new MetricParameter(){
                private static final long serialVersionUID = 1506569342529820853L;

                @Override
                public boolean setMetric(DistanceMetric val) {
                    MetricKDE.this.setDistanceMetric(val);
                    return true;
                }

                @Override
                public DistanceMetric getMetric() {
                    return MetricKDE.this.getDistanceMetric();
                }
            });
            this.add(new IntParameter(){
                private static final long serialVersionUID = 2109791176169136850L;

                @Override
                public int getValue() {
                    return MetricKDE.this.getDefaultK();
                }

                @Override
                public boolean setValue(int val) {
                    // reject instead of letting setDefaultK throw
                    if (val < 1) {
                        return false;
                    }
                    MetricKDE.this.setDefaultK(val);
                    return true;
                }

                @Override
                public String getASCIIName() {
                    return "k Neighbors for Bandwidth Estimation";
                }
            });
            this.add(new DoubleParameter(){
                private static final long serialVersionUID = 685333554755596799L;

                @Override
                public double getValue() {
                    return MetricKDE.this.getDefaultStndDev();
                }

                @Override
                public boolean setValue(double val) {
                    // setDefaultStndDev signals invalid values with ArithmeticException
                    try {
                        MetricKDE.this.setDefaultStndDev(val);
                        return true;
                    }
                    catch (ArithmeticException e) {
                        return false;
                    }
                }

                @Override
                public String getASCIIName() {
                    return "Standard Deviations for Bandwidth Estimation";
                }
            });
        }
    });
    private final Map<String, Parameter> paramMap = Parameter.toParameterMap(this.parameters);

    /**
     * Creates a KDE using the default kernel, the Euclidean distance and the
     * default vector collection factory.
     */
    public MetricKDE() {
        this(DEFAULT_KF, new EuclideanDistance(), defaultVCF);
    }

    /**
     * Creates a KDE using the default kernel and default vector collection factory.
     *
     * @param distanceMetric the metric used to measure distances between vectors
     */
    public MetricKDE(DistanceMetric distanceMetric) {
        this(DEFAULT_KF, distanceMetric, defaultVCF);
    }

    /**
     * Creates a KDE using the default kernel.
     *
     * @param distanceMetric the metric used to measure distances between vectors
     * @param vcf            the factory used to build the vector collection
     */
    public MetricKDE(DistanceMetric distanceMetric, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        this(DEFAULT_KF, distanceMetric, vcf);
    }

    /**
     * Creates a KDE with a new {@link DefaultVectorCollectionFactory}.
     *
     * @param kf             the kernel function
     * @param distanceMetric the metric used to measure distances between vectors
     */
    public MetricKDE(KernelFunction kf, DistanceMetric distanceMetric) {
        this(kf, distanceMetric, new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>());
    }

    /**
     * Creates a KDE using {@link #DEFAULT_K} and {@link #DEFAULT_STND_DEV} for
     * bandwidth estimation.
     *
     * @param kf             the kernel function
     * @param distanceMetric the metric used to measure distances between vectors
     * @param vcf            the factory used to build the vector collection
     */
    public MetricKDE(KernelFunction kf, DistanceMetric distanceMetric, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        this(kf, distanceMetric, vcf, DEFAULT_K, DEFAULT_STND_DEV);
    }

    /**
     * Creates a fully specified KDE.
     *
     * @param kf             the kernel function
     * @param distanceMetric the metric used to measure distances between vectors
     * @param vcf            the factory used to build the vector collection
     * @param defaultK       neighbor count used when estimating the bandwidth
     * @param defaultStndDev standard-deviation multiplier used when estimating the bandwidth
     * @throws ArithmeticException if {@code defaultK} or {@code defaultStndDev} is invalid
     */
    public MetricKDE(KernelFunction kf, DistanceMetric distanceMetric, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf, int defaultK, double defaultStndDev) {
        this.setKernelFunction(kf);
        this.distanceMetric = distanceMetric;
        this.vcf = vcf;
        this.setDefaultK(defaultK);
        this.setDefaultStndDev(defaultStndDev);
    }

    /**
     * Sets the bandwidth directly.
     *
     * @param bandwidth a finite, strictly positive bandwidth
     * @throws ArithmeticException if {@code bandwidth} is not positive, NaN or infinite
     */
    public void setBandwith(double bandwidth) {
        if (bandwidth <= 0.0 || Double.isNaN(bandwidth) || Double.isInfinite(bandwidth)) {
            throw new ArithmeticException("Invalid bandwith given, bandwith must be a positive number, not " + bandwidth);
        }
        this.bandwidth = bandwidth;
    }

    /**
     * @return the current bandwidth (0.0 if none has been set yet)
     */
    public double getBandwith() {
        return this.bandwidth;
    }

    /**
     * Sets the neighbor count used when the bandwidth is estimated without an
     * explicit k.
     *
     * @param defaultK a positive number of neighbors
     * @throws ArithmeticException if {@code defaultK <= 0}
     */
    public void setDefaultK(int defaultK) {
        if (defaultK <= 0) {
            throw new ArithmeticException("At least one neighbor must be taken into acount, " + defaultK + " is invalid");
        }
        this.defaultK = defaultK;
    }

    /**
     * @return the default neighbor count for bandwidth estimation
     */
    public int getDefaultK() {
        return this.defaultK;
    }

    /**
     * Sets the number of standard deviations added to the mean neighbor distance
     * when estimating the bandwidth without an explicit multiplier.
     *
     * @param defaultStndDev a finite, strictly positive multiplier
     * @throws ArithmeticException if the value is not positive, NaN or infinite
     */
    public void setDefaultStndDev(double defaultStndDev) {
        if (Double.isInfinite(defaultStndDev) || Double.isNaN(defaultStndDev) || defaultStndDev <= 0.0) {
            throw new ArithmeticException("The number of standard deviations to remove must bea postive number, not " + defaultStndDev);
        }
        this.defaultStndDev = defaultStndDev;
    }

    /**
     * @return the default standard-deviation multiplier for bandwidth estimation
     */
    public double getDefaultStndDev() {
        return this.defaultStndDev;
    }

    /**
     * @return the distance metric in use
     */
    public DistanceMetric getDistanceMetric() {
        return this.distanceMetric;
    }

    /**
     * Sets the distance metric. It takes effect the next time the model is trained.
     *
     * @param distanceMetric the metric to use
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.distanceMetric = distanceMetric;
    }

    /**
     * Returns a copy of this KDE. The distance metric, collection factory and
     * (if trained) vector collection are cloned; the kernel function instance is
     * shared.
     */
    @Override
    public MetricKDE clone() {
        MetricKDE clone = new MetricKDE(this.kf, this.distanceMetric.clone(), this.vcf.clone(), this.defaultK, this.defaultStndDev);
        clone.bandwidth = this.bandwidth;
        if (this.vecCollection != null) {
            clone.vecCollection = this.vecCollection.clone();
        }
        return clone;
    }

    /**
     * Returns the training vectors within the kernel's support of {@code x}.
     * Each result's paired value is the kernel weight {@code K(d / h)}.
     *
     * @param x the query point
     * @return neighbors paired with their kernel weight
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> getNearby(Vec x) {
        if (this.vecCollection == null) {
            throw new UntrainedModelException("Model has not yet been created");
        }
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearBy = this.getNearbyRaw(x);
        for (VecPaired<VecPaired<Vec, Integer>, Double> vecPaired : nearBy) {
            vecPaired.setPair(this.kf.k(vecPaired.getPair()));
        }
        return nearBy;
    }

    /**
     * Returns the training vectors within {@code bandwidth * kf.cutOff()} of
     * {@code x}. Each result's paired value (the distance reported by the range
     * search) is divided by the bandwidth.
     *
     * @param x the query point
     * @return neighbors paired with their bandwidth-scaled distance
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> getNearbyRaw(Vec x) {
        if (this.vecCollection == null) {
            throw new UntrainedModelException("Model has not yet been created");
        }
        List<VecPaired<VecPaired<Vec, Integer>, Double>> nearBy = this.vecCollection.search(x, this.bandwidth * this.kf.cutOff());
        for (VecPaired<VecPaired<Vec, Integer>, Double> result : nearBy) {
            result.setPair(result.getPair() / this.bandwidth);
        }
        return nearBy;
    }

    /**
     * Estimates the density at {@code x}.
     *
     * @param x the query point
     * @return the estimated density, or 0.0 if no training vector lies within
     *         the kernel's support of {@code x}
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double pdf(Vec x) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearBy = this.getNearby(x);
        if (nearBy.isEmpty()) {
            return 0.0;
        }
        double PDF = 0.0;
        for (VecPaired<VecPaired<Vec, Integer>, Double> vecPaired : nearBy) {
            PDF += vecPaired.getPair().doubleValue();
        }
        // normalize by n * h^D, where D is taken from the neighbor's dimension
        return PDF / ((double)this.vecCollection.size() * Math.pow(this.bandwidth, nearBy.get(0).length()));
    }

    /**
     * Trains the model with an explicitly given bandwidth.
     *
     * @param dataSet  the training vectors
     * @param bandwith the bandwidth to use
     * @return always {@code true}
     * @throws ArithmeticException if {@code bandwith} is invalid
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, double bandwith) {
        return this.setUsingData(dataSet, bandwith, null);
    }

    /**
     * Trains the model with an explicitly given bandwidth.
     *
     * @param dataSet    the training vectors
     * @param bandwith   the bandwidth to use
     * @param threadpool pool used for metric training and collection building,
     *                   or {@code null} to work serially
     * @return always {@code true}
     * @throws ArithmeticException if {@code bandwith} is invalid
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, double bandwith, ExecutorService threadpool) {
        // validate before doing any (potentially expensive) work
        this.setBandwith(bandwith);
        this.vecCollection = this.buildVectorCollection(dataSet, threadpool);
        return true;
    }

    /**
     * Trains the model, estimating the bandwidth using {@code k} neighbors and
     * the default standard-deviation multiplier.
     *
     * @param dataSet the training vectors
     * @param k       neighbor count for bandwidth estimation
     * @return always {@code true}
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k) {
        return this.setUsingData(dataSet, k, this.defaultStndDev);
    }

    /**
     * Trains the model, estimating the bandwidth using {@code k} neighbors and
     * the default standard-deviation multiplier.
     *
     * @param dataSet    the training vectors
     * @param k          neighbor count for bandwidth estimation
     * @param threadpool pool used for the work, or {@code null} to work serially
     * @return always {@code true}
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k, ExecutorService threadpool) {
        return this.setUsingData(dataSet, k, this.defaultStndDev, threadpool);
    }

    /**
     * Trains the model, estimating the bandwidth using {@code k} neighbors.
     *
     * @param dataSet  the training vectors
     * @param k        neighbor count for bandwidth estimation
     * @param stndDevs number of standard deviations added to the mean distance
     * @return always {@code true}
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k, double stndDevs) {
        return this.setUsingData(dataSet, k, stndDevs, null);
    }

    /**
     * Trains the model, estimating the bandwidth as
     * {@code mean + stndDevs * standardDeviation} of the distances to the
     * (k+1)-th nearest neighbor of every training vector. The extra neighbor is
     * presumably there because each training vector finds itself first.
     * <p>
     * The new vector collection is only committed once the estimated bandwidth
     * has been accepted. An invalid estimate therefore leaves the previous model
     * state intact instead of pairing a new collection with a stale bandwidth.
     *
     * @param dataSet    the training vectors
     * @param k          neighbor count for bandwidth estimation
     * @param stndDevs   number of standard deviations added to the mean distance
     * @param threadpool pool used for metric training, collection building and
     *                   neighbor statistics, or {@code null} to work serially
     * @return always {@code true}
     * @throws ArithmeticException if the estimated bandwidth is not a positive finite number
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k, double stndDevs, ExecutorService threadpool) {
        VectorCollection<VecPaired<Vec, Integer>> collection = this.buildVectorCollection(dataSet, threadpool);
        OnLineStatistics stats = threadpool == null
                ? VectorCollectionUtils.getKthNeighborStats(collection, dataSet, k + 1)
                : VectorCollectionUtils.getKthNeighborStats(collection, dataSet, k + 1, threadpool);
        this.setBandwith(stats.getMean() + stats.getStandardDeviation() * stndDevs);
        this.vecCollection = collection;
        return true;
    }

    /**
     * Trains the model, estimating the bandwidth from the default k and the
     * default standard-deviation multiplier.
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet) {
        return this.setUsingData(dataSet, this.defaultK);
    }

    /**
     * Trains the model, estimating the bandwidth from the default k and the
     * default standard-deviation multiplier.
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, ExecutorService threadpool) {
        return this.setUsingData(dataSet, this.defaultK, threadpool);
    }

    /**
     * Trains the model on the numerical values of the given data points.
     */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoints) {
        return this.setUsingData(numericalValues(dataPoints));
    }

    /**
     * Trains the model on the numerical values of the given data points.
     */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoints, ExecutorService threadpool) {
        return this.setUsingData(numericalValues(dataPoints), threadpool);
    }

    /**
     * Sampling is not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public KernelFunction getKernelFunction() {
        return this.kf;
    }

    /**
     * @param kf the kernel function to use
     */
    public void setKernelFunction(KernelFunction kf) {
        this.kf = kf;
    }

    /**
     * Multiplies the current bandwidth by {@code scale}.
     *
     * @param scale a finite, strictly positive factor
     * @throws ArithmeticException if {@code scale} is not positive, NaN or
     *         infinite, since such a factor would produce a bandwidth that
     *         {@link #setBandwith(double)} rejects
     */
    @Override
    public void scaleBandwidth(double scale) {
        if (scale <= 0.0 || Double.isNaN(scale) || Double.isInfinite(scale)) {
            throw new ArithmeticException("Invalid scale given, scale must be a positive number, not " + scale);
        }
        this.bandwidth *= scale;
    }

    @Override
    public List<Parameter> getParameters() {
        return this.parameters;
    }

    @Override
    public Parameter getParameter(String paramName) {
        return this.paramMap.get(paramName);
    }

    /**
     * Pairs each vector with its index and trains the distance metric if
     * needed. It then builds (but does not store) a vector collection over the
     * pairs, using the thread pool when one is given.
     */
    private <V extends Vec> VectorCollection<VecPaired<Vec, Integer>> buildVectorCollection(List<V> dataSet, ExecutorService threadpool) {
        List<VecPaired<Vec, Integer>> indexVectorPair = new ArrayList<VecPaired<Vec, Integer>>(dataSet.size());
        for (int i = 0; i < dataSet.size(); ++i) {
            indexVectorPair.add(new VecPaired<Vec, Integer>(dataSet.get(i), i));
        }
        TrainableDistanceMetric.trainIfNeeded(this.distanceMetric, dataSet, threadpool);
        return threadpool == null
                ? this.vcf.getVectorCollection(indexVectorPair, this.distanceMetric)
                : this.vcf.getVectorCollection(indexVectorPair, this.distanceMetric, threadpool);
    }

    /** Extracts the numerical vector of every data point, preserving order. */
    private static List<Vec> numericalValues(List<DataPoint> dataPoints) {
        List<Vec> dataSet = new ArrayList<Vec>(dataPoints.size());
        for (DataPoint dp : dataPoints) {
            dataSet.add(dp.getNumericalValues());
        }
        return dataSet;
    }
}

