/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.utils.FakeExecutor;
import jsat.utils.IndexTable;
import jsat.utils.IntSet;
import jsat.utils.SystemInfo;
import jsat.utils.random.RandomUtil;

/**
 * ReliefF feature selection for classification data sets.
 * <p>
 * Every numeric feature receives a relevance weight. The weight is estimated by repeatedly
 * picking a random data point and finding its nearest neighbors inside every class:
 * <ul>
 * <li>Range-normalized feature differences to same-class neighbors ("hits") lower a feature's
 * weight.</li>
 * <li>Differences to other-class neighbors ("misses") raise it, scaled by that class's prior.</li>
 * </ul>
 * After fitting, the lowest-ranked numeric features are removed through
 * {@link RemoveAttributeTransform}, so that at most {@link #getFeatureCount()} numeric features
 * remain.
 */
public class ReliefF extends RemoveAttributeTransform {
    private static final long serialVersionUID = -3336500245613075520L;
    /** Learned per-feature weights; {@code null} until {@link #fit(DataSet, ExecutorService)} runs. */
    private double[] w;
    /** Number of numeric features to keep. */
    private int featureCount;
    /** Number of random samples (m) drawn while estimating the weights. */
    private int iterations;
    /** Number of nearest neighbors (n) taken from each class for every sample. */
    private int neighbors;
    /** Metric used by the nearest-neighbor searches. */
    private DistanceMetric dm;
    /** Factory that builds the per-class nearest-neighbor collections. */
    private VectorCollectionFactory<Vec> vcf = new DefaultVectorCollectionFactory<Vec>();

    /**
     * Creates an unfitted selector with default settings.
     * <p>
     * The defaults are 100 iterations, 15 neighbors, Euclidean distance and the default vector
     * collection factory.
     *
     * @param featureCount number of numeric features to keep
     */
    public ReliefF(int featureCount) {
        this(featureCount, 100, 15, new EuclideanDistance(), new DefaultVectorCollectionFactory<Vec>());
    }

    /**
     * Creates an unfitted selector that uses the default vector collection factory.
     *
     * @param featureCount number of numeric features to keep
     * @param m            number of sampling iterations
     * @param n            number of neighbors per class
     * @param dm           distance metric for the neighbor searches
     */
    public ReliefF(int featureCount, int m, int n, DistanceMetric dm) {
        this(featureCount, m, n, dm, new DefaultVectorCollectionFactory<Vec>());
    }

    /**
     * Creates a selector and immediately fits it to {@code cds} on the calling thread.
     *
     * @param cds          training data
     * @param featureCount number of numeric features to keep
     * @param m            number of sampling iterations
     * @param n            number of neighbors per class
     * @param dm           distance metric for the neighbor searches
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int m, int n, DistanceMetric dm) {
        this(cds, featureCount, m, n, dm, new DefaultVectorCollectionFactory<Vec>());
    }

    /**
     * Creates a selector and immediately fits it to {@code cds}, using {@code threadPool} for
     * the work.
     *
     * @param cds          training data
     * @param featureCount number of numeric features to keep
     * @param m            number of sampling iterations
     * @param n            number of neighbors per class
     * @param dm           distance metric for the neighbor searches
     * @param threadPool   pool used for fitting, or {@code null} to fit on the calling thread
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int m, int n, DistanceMetric dm, ExecutorService threadPool) {
        this(cds, featureCount, m, n, dm, new DefaultVectorCollectionFactory<Vec>(), threadPool);
    }

    /**
     * Creates a selector and immediately fits it to {@code cds} on the calling thread.
     *
     * @param cds          training data
     * @param featureCount number of numeric features to keep
     * @param m            number of sampling iterations
     * @param n            number of neighbors per class
     * @param dm           distance metric for the neighbor searches
     * @param vcf          factory for the per-class neighbor collections
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int m, int n, DistanceMetric dm, VectorCollectionFactory<Vec> vcf) {
        this(cds, featureCount, m, n, dm, vcf, null);
    }

    /**
     * Copy constructor.
     * <p>
     * The weights array, distance metric and collection factory are copied or cloned, so the
     * new instance does not share them with {@code toCopy}.
     *
     * @param toCopy the instance to copy
     */
    protected ReliefF(ReliefF toCopy) {
        super(toCopy);
        if (toCopy.w != null) {
            this.w = Arrays.copyOf(toCopy.w, toCopy.w.length);
        }
        this.dm = toCopy.dm.clone();
        this.featureCount = toCopy.featureCount;
        this.iterations = toCopy.iterations;
        this.neighbors = toCopy.neighbors;
        this.vcf = toCopy.vcf.clone();
    }

    /**
     * Creates an unfitted selector.
     *
     * @param featureCount number of numeric features to keep
     * @param m            number of sampling iterations
     * @param n            number of neighbors per class
     * @param dm           distance metric for the neighbor searches
     * @param vcf          factory for the per-class neighbor collections
     * @throws IllegalArgumentException if {@code featureCount}, {@code m} or {@code n} is not positive
     */
    public ReliefF(int featureCount, int m, int n, DistanceMetric dm, VectorCollectionFactory<Vec> vcf) {
        this.setFeatureCount(featureCount);
        this.setIterations(m);
        this.setNeighbors(n);
        this.setDistanceMetric(dm);
        this.vcf = vcf;
    }

    /**
     * Creates a selector and immediately fits it to {@code cds}.
     *
     * @param cds          training data
     * @param featureCount number of numeric features to keep
     * @param m            number of sampling iterations
     * @param n            number of neighbors per class
     * @param dm           distance metric for the neighbor searches
     * @param vcf          factory for the per-class neighbor collections
     * @param threadPool   pool used for fitting, or {@code null} to fit on the calling thread
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int m, int n, DistanceMetric dm, VectorCollectionFactory<Vec> vcf, ExecutorService threadPool) {
        this(featureCount, m, n, dm, vcf);
        this.fit(cds, threadPool);
    }

    /**
     * Fits on the calling thread.
     * <p>
     * Equivalent to {@code fit(data, null)}.
     */
    @Override
    public void fit(DataSet data) {
        this.fit(data, null);
    }

    /**
     * Estimates feature weights from {@code data} and configures this transform to drop the
     * lowest-ranked numeric features.
     *
     * @param data       training data; must be a {@link ClassificationDataSet}
     * @param threadPool pool used to parallelize sampling, or {@code null} to run on the calling thread
     * @throws FailedToFitException if {@code data} is not a classification data set or has no
     *                              data points, or if the caller is interrupted while waiting for
     *                              the workers
     */
    public void fit(DataSet data, ExecutorService threadPool) {
        if (!(data instanceof ClassificationDataSet)) {
            throw new FailedToFitException("RelifF only works with classification datasets, not " + data.getClass().getSimpleName());
        }
        final ClassificationDataSet cds = (ClassificationDataSet) data;
        final int sampleCount = cds.getSampleSize();
        if (sampleCount == 0) {
            // Without this check Random.nextInt(0) would fail inside a worker thread.
            throw new FailedToFitException("ReliefF requires at least one data point to fit");
        }
        super.fit(data);

        final double[] weights = new double[cds.getNumNumericalVars()];
        this.w = weights;
        final double[] priors = cds.getPriors();
        final List<Vec> allVecs = cds.getDataVectors();
        final double[] normalizer = featureRanges(allVecs, weights.length);

        TrainableDistanceMetric.trainIfNeeded(this.dm, cds, threadPool);
        final List<VectorCollection<Vec>> classVC = buildClassCollections(cds, allVecs, priors.length, threadPool);

        final int m = this.iterations;
        final int n = this.neighbors;
        // Multiply in double so a large m * n cannot overflow int.
        final double denom = (double) m * n;
        final int numTasks = threadPool == null ? 1 : SystemInfo.LogicalCores;
        ExecutorService executor = threadPool;
        if (executor == null) {
            executor = new FakeExecutor();
        }

        final int blockSize = m / numTasks;
        final int remainder = m % numTasks;
        final CountDownLatch latch = new CountDownLatch(numTasks);
        // Collect worker failures so they can be rethrown on the caller's thread.
        final List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        for (int id = 0; id < numTasks; id++) {
            final int samplesForTask = id < remainder ? blockSize + 1 : blockSize;
            executor.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        double[] wLocal = new double[weights.length];
                        Random rand = RandomUtil.getRandom();
                        for (int iter = 0; iter < samplesForTask; iter++) {
                            int k = rand.nextInt(sampleCount);
                            accumulateSample(allVecs.get(k), cds.getDataPointCategory(k), priors,
                                    classVC, n, denom, normalizer, wLocal);
                        }
                        synchronized (weights) {
                            for (int i = 0; i < weights.length; i++) {
                                weights[i] += wLocal[i];
                            }
                        }
                    } catch (Throwable t) {
                        failures.add(t);
                    } finally {
                        // Must always run; a skipped countDown would block latch.await() forever.
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            FailedToFitException interrupted = new FailedToFitException("Interrupted while computing ReliefF weights");
            interrupted.initCause(ex);
            throw interrupted;
        }
        rethrowFirst(failures);

        // NOTE(review): assumes IndexTable orders indices by ascending weight, so the first
        // (numFeatures - featureCount) entries are the lowest-weighted features.
        IndexTable ranking = new IndexTable(weights);
        IntSet numericalToRemove = new IntSet(weights.length * 2);
        for (int i = 0; i < weights.length - this.featureCount; i++) {
            numericalToRemove.add(ranking.index(i));
        }
        this.setUp(cds, Collections.<Integer>emptySet(), numericalToRemove);
    }

    /**
     * Computes, for each of the first {@code numFeatures} features, the range (max - min) of its
     * values across {@code vecs}.
     */
    private static double[] featureRanges(List<Vec> vecs, int numFeatures) {
        double[] min = new double[numFeatures];
        double[] max = new double[numFeatures];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (Vec v : vecs) {
            for (int i = 0; i < numFeatures; i++) {
                double value = v.get(i);
                min[i] = Math.min(min[i], value);
                max[i] = Math.max(max[i], value);
            }
        }
        double[] range = new double[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            range[i] = max[i] - min[i];
        }
        return range;
    }

    /**
     * Builds one nearest-neighbor collection per class.
     * <p>
     * Vectors are grouped by their recorded category, so the row order of the data set does not
     * need to be sorted by class.
     */
    private List<VectorCollection<Vec>> buildClassCollections(ClassificationDataSet cds, List<Vec> allVecs,
            int numClasses, ExecutorService threadPool) {
        List<List<Vec>> byClass = new ArrayList<List<Vec>>(numClasses);
        for (int c = 0; c < numClasses; c++) {
            byClass.add(new ArrayList<Vec>(cds.classSampleCount(c)));
        }
        for (int k = 0; k < allVecs.size(); k++) {
            byClass.get(cds.getDataPointCategory(k)).add(allVecs.get(k));
        }
        List<VectorCollection<Vec>> collections = new ArrayList<VectorCollection<Vec>>(numClasses);
        for (List<Vec> members : byClass) {
            if (threadPool == null) {
                collections.add(this.vcf.getVectorCollection(members, this.dm));
            } else {
                collections.add(this.vcf.getVectorCollection(members, this.dm, threadPool));
            }
        }
        return collections;
    }

    /**
     * Adds one sampled point's hit and miss contributions to {@code wLocal}.
     * <p>
     * For every class {@code y}, the point's {@code n} nearest neighbors in {@code y} are found.
     * For each neighbor, every feature's normalized difference is either subtracted (same class)
     * or added with weight {@code priors[y] / (1 - priors[sampleClass])} (other class), and then
     * divided by {@code denom}.
     */
    private static void accumulateSample(Vec sample, int sampleClass, double[] priors,
            List<VectorCollection<Vec>> classVC, int n, double denom, double[] normalizer, double[] wLocal) {
        for (int y = 0; y < priors.length; y++) {
            boolean sameClass = y == sampleClass;
            int searchFor = sameClass ? n + 1 : n;
            List<? extends VecPaired<?, ?>> nearest = classVC.get(y).search(sample, searchFor);
            if (sameClass) {
                // The closest same-class hit is expected to be the sample itself, so drop it.
                // Clamp the bounds: a small class can return fewer than n + 1 results.
                int end = Math.min(searchFor, nearest.size());
                nearest = nearest.subList(Math.min(1, end), end);
            }
            double missWeight = sameClass ? 0.0 : priors[y] / (1.0 - priors[sampleClass]);
            for (VecPaired<?, ?> neighbor : nearest) {
                Vec other = neighbor.getVector();
                for (int i = 0; i < wLocal.length; i++) {
                    double d = diff(i, sample, other, normalizer);
                    if (sameClass) {
                        wLocal[i] -= d / denom;
                    } else {
                        wLocal[i] += missWeight * d / denom;
                    }
                }
            }
        }
    }

    /** Rethrows the first worker failure, if any, on the calling thread. */
    private static void rethrowFirst(List<Throwable> failures) {
        if (failures.isEmpty()) {
            return;
        }
        Throwable first = failures.get(0);
        if (first instanceof RuntimeException) {
            throw (RuntimeException) first;
        }
        if (first instanceof Error) {
            throw (Error) first;
        }
        throw new RuntimeException(first);
    }

    /**
     * Returns a copy of the weights learned by the last fit, one per numeric feature.
     * <p>
     * Higher values mean the feature differed more, on average, from the sampled points' nearest
     * other-class neighbors than from their nearest same-class neighbors.
     *
     * @return a new vector holding the weights
     * @throws IllegalStateException if this transform has not been fit
     */
    public Vec getWeights() {
        if (this.w == null) {
            throw new IllegalStateException("ReliefF has not been fit yet");
        }
        // Copy so callers cannot change the internal weights through the returned vector.
        return new DenseVector(Arrays.copyOf(this.w, this.w.length));
    }

    /**
     * Returns the absolute difference of feature {@code i} between {@code a} and {@code b},
     * divided by that feature's range.
     * <p>
     * Returns 0 when the range is zero (a constant feature).
     */
    private static double diff(int i, Vec a, Vec b, double[] normalizer) {
        if (normalizer[i] == 0.0) {
            return 0.0;
        }
        return Math.abs(a.get(i) - b.get(i)) / normalizer[i];
    }

    @Override
    public ReliefF clone() {
        return new ReliefF(this);
    }

    /**
     * Sets how many numeric features to keep.
     *
     * @param featureCount a positive count
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public void setFeatureCount(int featureCount) {
        if (featureCount < 1) {
            throw new IllegalArgumentException("Number of features to select must be positive, not " + featureCount);
        }
        this.featureCount = featureCount;
    }

    /** @return the number of numeric features to keep */
    public int getFeatureCount() {
        return this.featureCount;
    }

    /**
     * Sets how many random samples are drawn during fitting.
     *
     * @param iterations a positive count
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + iterations);
        }
        this.iterations = iterations;
    }

    /** @return the number of random samples drawn during fitting */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets how many nearest neighbors are taken from each class per sample.
     *
     * @param neighbors a positive count
     * @throws IllegalArgumentException if {@code neighbors < 1}
     */
    public void setNeighbors(int neighbors) {
        if (neighbors < 1) {
            throw new IllegalArgumentException("Number of neighbors must be positive, not " + neighbors);
        }
        this.neighbors = neighbors;
    }

    /** @return the number of nearest neighbors taken from each class per sample */
    public int getNeighbors() {
        return this.neighbors;
    }

    /**
     * Sets the distance metric used for the neighbor searches.
     *
     * @param dm the metric
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric used for the neighbor searches */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }
}

