/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.imbalance;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class SMOTE
implements Classifier,
Parameterized {
    @Parameter.ParameterHolder
    protected Classifier baseClassifier;
    protected DistanceMetric dm;
    protected int smoteNeighbors;
    protected double targetRatio;

    public SMOTE(Classifier baseClassifier) {
        this(baseClassifier, new EuclideanDistance());
    }

    public SMOTE(Classifier baseClassifier, DistanceMetric dm) {
        this(baseClassifier, dm, 1.0);
    }

    public SMOTE(Classifier baseClassifier, DistanceMetric dm, double targetRatio) {
        this(baseClassifier, dm, 5, targetRatio);
    }

    public SMOTE(Classifier baseClassifier, DistanceMetric dm, int smoteNeighbors, double targetRatio) {
        this.setBaseClassifier(baseClassifier);
        this.setDistanceMetric(dm);
        this.setSmoteNeighbors(smoteNeighbors);
        this.setTargetRatio(targetRatio);
    }

    public SMOTE(SMOTE toCopy) {
        this.baseClassifier = toCopy.baseClassifier.clone();
        this.dm = toCopy.dm.clone();
        this.smoteNeighbors = toCopy.smoteNeighbors;
        this.targetRatio = toCopy.targetRatio;
    }

    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    public void setSmoteNeighbors(int smoteNeighbors) {
        if (smoteNeighbors < 1) {
            throw new IllegalArgumentException("number of neighbors considered must be a positive value");
        }
        this.smoteNeighbors = smoteNeighbors;
    }

    public int getSmoteNeighbors() {
        return this.smoteNeighbors;
    }

    public void setTargetRatio(double targetRatio) {
        this.targetRatio = targetRatio;
    }

    public double getTargetRatio() {
        return this.targetRatio;
    }

    public void setBaseClassifier(Classifier baseClassifier) {
        this.baseClassifier = baseClassifier;
    }

    public Classifier getBaseClassifier() {
        return this.baseClassifier;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.baseClassifier.classify(data);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        int i;
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("SMOTE only works with numeric-only feature values");
        }
        if (threadPool == null) {
            threadPool = new FakeExecutor();
        }
        List<Vec> vAll = dataSet.getDataVectors();
        IntList[] classIndex = new IntList[dataSet.getClassSize()];
        for (i = 0; i < classIndex.length; ++i) {
            classIndex[i] = new IntList();
        }
        for (i = 0; i < dataSet.getSampleSize(); ++i) {
            classIndex[dataSet.getDataPointCategory(i)].add(i);
        }
        double[] priors = dataSet.getPriors();
        Vec ratios = DenseVector.toDenseVec(priors).clone();
        int majorityNum = (int)((double)dataSet.getSampleSize() * ratios.max());
        ratios.mutableDivide(ratios.max());
        final ArrayList synthetics = new ArrayList();
        Iterator iterator = ListUtils.range(0, dataSet.getClassSize()).iterator();
        while (iterator.hasNext()) {
            final int classID = (Integer)iterator.next();
            final int samplesNeeded = (int)((double)majorityNum * this.targetRatio - (double)classIndex[classID].size());
            if (samplesNeeded <= 0) continue;
            final ArrayList<Vec> V_id = new ArrayList<Vec>();
            Iterator iterator2 = classIndex[classID].iterator();
            while (iterator2.hasNext()) {
                int i2 = (Integer)iterator2.next();
                V_id.add(vAll.get(i2));
            }
            VectorCollection VC_id = new DefaultVectorCollectionFactory<Vec>().getVectorCollection(V_id, this.dm, threadPool);
            final List nns_id = VectorCollectionUtils.allNearestNeighbors(VC_id, V_id, this.smoteNeighbors + 1, threadPool);
            final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
            Iterator iterator3 = ListUtils.range(0, SystemInfo.LogicalCores).iterator();
            while (iterator3.hasNext()) {
                final int threadID = (Integer)iterator3.next();
                threadPool.submit(new Runnable(){

                    /*
                     * WARNING - Removed try catching itself - possible behaviour change.
                     */
                    @Override
                    public void run() {
                        Random rand = RandomUtil.getRandom();
                        ArrayList<DataPoint> local_new = new ArrayList<DataPoint>();
                        for (int i = ParallelUtils.getStartBlock(samplesNeeded, threadID); i < ParallelUtils.getEndBlock(samplesNeeded, threadID); ++i) {
                            int sampleIndex = i % V_id.size();
                            int nn = rand.nextInt(SMOTE.this.smoteNeighbors) + 1;
                            VecPaired vec_nn = (VecPaired)((List)nns_id.get(sampleIndex)).get(nn);
                            double gap = rand.nextDouble();
                            Vec newVal = ((Vec)V_id.get(sampleIndex)).clone();
                            newVal.mutableMultiply(gap + 1.0);
                            newVal.mutableAdd(gap, vec_nn);
                            local_new.add(new DataPoint(newVal));
                        }
                        List list = synthetics;
                        synchronized (list) {
                            for (DataPoint v : local_new) {
                                synthetics.add(new DataPointPair<Integer>(v, classID));
                            }
                        }
                        latch.countDown();
                    }
                });
            }
            try {
                latch.await();
            }
            catch (InterruptedException ex) {
                Logger.getLogger(SMOTE.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        ClassificationDataSet newDataSet = new ClassificationDataSet(ListUtils.mergedView(synthetics, dataSet.getAsDPPList()), dataSet.getPredicting());
        this.baseClassifier.trainC(newDataSet, threadPool);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public SMOTE clone() {
        return new SMOTE(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

