/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.imbalance;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.imbalance.SMOTE;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Borderline variant of {@link SMOTE}. For every class that has fewer samples
 * than the target, synthetic points are generated by interpolating between an
 * existing sample and one of its nearest neighbors. Samples are preferred as
 * interpolation seeds when they are "in danger": at least half, but not all,
 * of their nearest neighbors in the whole training set belong to a different
 * class. When a class has no such samples, every sample of the class is used
 * as a seed instead.
 * <p>
 * The augmented data set (synthetic points plus the original data) is then
 * handed to the base classifier for training.
 */
public class BorderlineSMOTE extends SMOTE {
    /**
     * When {@code true}, a seed that is in danger may also be interpolated
     * toward one of its neighbors from a different class (with the gap halved
     * so the new point stays closer to the seed).
     */
    private boolean majorityInterpolation;

    /**
     * Creates a new Borderline-SMOTE model with other-class interpolation
     * disabled.
     *
     * @param baseClassifier the classifier trained on the over-sampled data
     */
    public BorderlineSMOTE(Classifier baseClassifier) {
        this(baseClassifier, false);
    }

    /**
     * Creates a new Borderline-SMOTE model using the {@link EuclideanDistance}.
     *
     * @param baseClassifier the classifier trained on the over-sampled data
     * @param majorityInterpolation whether seeds may be interpolated toward
     *        neighbors of a different class
     */
    public BorderlineSMOTE(Classifier baseClassifier, boolean majorityInterpolation) {
        this(baseClassifier, new EuclideanDistance(), majorityInterpolation);
    }

    /**
     * Creates a new Borderline-SMOTE model with a target ratio of {@code 1.0}.
     *
     * @param baseClassifier the classifier trained on the over-sampled data
     * @param dm the distance metric used for the nearest neighbor searches
     * @param majorityInterpolation whether seeds may be interpolated toward
     *        neighbors of a different class
     */
    public BorderlineSMOTE(Classifier baseClassifier, DistanceMetric dm, boolean majorityInterpolation) {
        this(baseClassifier, dm, 1.0, majorityInterpolation);
    }

    /**
     * Creates a new Borderline-SMOTE model using {@code 5} neighbors.
     *
     * @param baseClassifier the classifier trained on the over-sampled data
     * @param dm the distance metric used for the nearest neighbor searches
     * @param targetRatio the desired size of each class relative to the
     *        largest class
     * @param majorityInterpolation whether seeds may be interpolated toward
     *        neighbors of a different class
     */
    public BorderlineSMOTE(Classifier baseClassifier, DistanceMetric dm, double targetRatio, boolean majorityInterpolation) {
        this(baseClassifier, dm, 5, targetRatio, majorityInterpolation);
    }

    /**
     * Creates a new Borderline-SMOTE model.
     *
     * @param baseClassifier the classifier trained on the over-sampled data
     * @param dm the distance metric used for the nearest neighbor searches
     * @param smoteNeighbors the number of nearest neighbors to consider
     * @param targetRatio the desired size of each class relative to the
     *        largest class
     * @param majorityInterpolation whether seeds may be interpolated toward
     *        neighbors of a different class
     */
    public BorderlineSMOTE(Classifier baseClassifier, DistanceMetric dm, int smoteNeighbors, double targetRatio, boolean majorityInterpolation) {
        super(baseClassifier, dm, smoteNeighbors, targetRatio);
        this.setMajorityInterpolation(majorityInterpolation);
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the object to copy
     */
    public BorderlineSMOTE(BorderlineSMOTE toCopy) {
        super(toCopy);
        this.majorityInterpolation = toCopy.majorityInterpolation;
    }

    /**
     * Sets whether seeds that are in danger may be interpolated toward
     * neighbors that belong to a different class.
     *
     * @param majorityInterpolation {@code true} to enable other-class
     *        interpolation
     */
    public void setMajorityInterpolation(boolean majorityInterpolation) {
        this.majorityInterpolation = majorityInterpolation;
    }

    /**
     * @return {@code true} if other-class interpolation is enabled
     */
    public boolean isMajorityInterpolation() {
        return this.majorityInterpolation;
    }

    /**
     * Over-samples every under-represented class of {@code dataSet} and trains
     * the base classifier on the original data plus the synthetic points.
     *
     * @param dataSet the training data; must not contain categorical features
     * @param threadPool the source of threads, or {@code null} to use a
     *        {@link FakeExecutor}
     * @throws FailedToFitException if the data has categorical features or the
     *         calling thread is interrupted while synthetic samples are being
     *         generated
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("SMOTE only works with numeric-only feature values");
        }
        if (threadPool == null) {
            threadPool = new FakeExecutor();
        }

        final List<Vec> vAll = dataSet.getDataVectors();

        // classIndex[c] holds the data set indices of every sample of class c
        final IntList[] classIndex = new IntList[dataSet.getClassSize()];
        for (int i = 0; i < classIndex.length; ++i) {
            classIndex[i] = new IntList();
        }
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            classIndex[dataSet.getDataPointCategory(i)].add(i);
        }

        // size of the largest class, derived from the largest prior
        final double largestPrior = DenseVector.toDenseVec(dataSet.getPriors()).max();
        final int majorityNum = (int) ((double) dataSet.getSampleSize() * largestPrior);

        final List<DataPointPair<Integer>> synthetics = new ArrayList<DataPointPair<Integer>>();

        // every vector paired with its class, so neighbor labels can be read back
        final List<VecPaired<Vec, Integer>> allVecsWithClass = new ArrayList<VecPaired<Vec, Integer>>(vAll.size());
        for (int i = 0; i < vAll.size(); ++i) {
            allVecsWithClass.add(new VecPaired<Vec, Integer>(vAll.get(i), dataSet.getDataPointCategory(i)));
        }
        final VectorCollection<VecPaired<Vec, Integer>> VC_all =
                new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>().getVectorCollection(allVecsWithClass, this.dm, threadPool);

        for (final int classID : ListUtils.range(0, dataSet.getClassSize())) {
            final int samplesNeeded = (int) ((double) majorityNum * this.targetRatio - (double) classIndex[classID].size());
            // nothing to do, or nothing to interpolate from (avoids a modulo by zero below)
            if (samplesNeeded <= 0 || classIndex[classID].isEmpty()) {
                continue;
            }

            // the vectors of this class, which act as interpolation seeds
            final List<Vec> V_id = new ArrayList<Vec>(classIndex[classID].size());
            for (int idx : classIndex[classID]) {
                V_id.add(vAll.get(idx));
            }
            final VectorCollection<Vec> VC_id = new DefaultVectorCollectionFactory<Vec>().getVectorCollection(V_id, this.dm, threadPool);

            // neighbors of every seed within the WHOLE training set; one extra is
            // requested because the first hit is expected to be the seed itself
            final List<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> all_nns_ID =
                    VectorCollectionUtils.allNearestNeighbors(VC_all, V_id, this.smoteNeighbors + 1, threadPool);

            // per seed: its neighbors that belong to a different class
            final List<List<Vec>> otherClassSamples = new ArrayList<List<Vec>>();
            if (this.majorityInterpolation) {
                for (int i = 0; i < all_nns_ID.size(); ++i) {
                    otherClassSamples.add(new ArrayList<Vec>(this.smoteNeighbors));
                }
            }

            // indices (into V_id) of the seeds that are in danger
            final IntList danger_id = new IntList();
            for (int i = 0; i < VC_id.size(); ++i) {
                int same_class = 0;
                final List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nns_id_i = all_nns_ID.get(i);
                // start at 1: index 0 is treated as the seed itself
                for (int j = 1; j < nns_id_i.size(); ++j) {
                    final VecPaired<Vec, Integer> neighbor = nns_id_i.get(j).getVector();
                    if (classID == neighbor.getPair()) {
                        ++same_class;
                    } else if (this.majorityInterpolation) {
                        otherClassSamples.get(i).add(neighbor.getVector());
                    }
                }
                // fraction of neighbors from other classes; in danger when in [0.5, 1)
                final double sOm = 1.0 - (double) same_class / (double) this.smoteNeighbors;
                if (0.5 <= sOm && sOm < 1.0) {
                    danger_id.add(i);
                }
            }

            // neighbors of every seed within its OWN class (again one extra for the seed itself)
            final List<List<? extends VecPaired<Vec, Double>>> nns_id =
                    VectorCollectionUtils.allNearestNeighbors(VC_id, V_id, this.smoteNeighbors + 1, threadPool);

            final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
            // first failure raised by any worker; written under its own lock
            final RuntimeException[] workerFailure = new RuntimeException[1];

            for (final int threadID : ListUtils.range(0, SystemInfo.LogicalCores)) {
                threadPool.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            final Random rand = RandomUtil.getRandom();
                            final List<DataPoint> local_new = new ArrayList<DataPoint>();
                            final int start = ParallelUtils.getStartBlock(samplesNeeded, threadID);
                            final int end = ParallelUtils.getEndBlock(samplesNeeded, threadID);
                            for (int i = start; i < end; ++i) {
                                // cycle through the danger seeds, or all seeds when none are in danger
                                final int sampleIndex = danger_id.isEmpty()
                                        ? i % V_id.size()
                                        : danger_id.getI(i % danger_id.size());
                                final boolean useOtherClass = rand.nextBoolean()
                                        && BorderlineSMOTE.this.majorityInterpolation
                                        && !danger_id.isEmpty();

                                final Vec vec_nn;
                                if (useOtherClass) {
                                    final List<Vec> candidates = otherClassSamples.get(sampleIndex);
                                    vec_nn = candidates.get(rand.nextInt(candidates.size()));
                                } else {
                                    final List<? extends VecPaired<Vec, Double>> neighbors = nns_id.get(sampleIndex);
                                    // never index past the neighbors actually returned; a class too
                                    // small to have any neighbor falls back to the seed itself (index 0)
                                    final int usable = Math.min(BorderlineSMOTE.this.smoteNeighbors, neighbors.size() - 1);
                                    final int nn = usable > 0 ? rand.nextInt(usable) + 1 : 0;
                                    vec_nn = neighbors.get(nn);
                                }

                                double gap = rand.nextDouble();
                                if (useOtherClass) {
                                    // stay in the half of the segment nearest to the seed
                                    gap /= 2.0;
                                }

                                // new = seed + gap * (neighbor - seed)
                                //     = (1 - gap) * seed + gap * neighbor
                                final Vec newVal = V_id.get(sampleIndex).clone();
                                newVal.mutableMultiply(1.0 - gap);
                                newVal.mutableAdd(gap, vec_nn);
                                local_new.add(new DataPoint(newVal));
                            }

                            synchronized (synthetics) {
                                for (DataPoint v : local_new) {
                                    synthetics.add(new DataPointPair<Integer>(v, classID));
                                }
                            }
                        } catch (RuntimeException ex) {
                            synchronized (workerFailure) {
                                if (workerFailure[0] == null) {
                                    workerFailure[0] = ex;
                                }
                            }
                        } finally {
                            // always release the latch, otherwise a failing worker would
                            // leave the training thread blocked in await() forever
                            latch.countDown();
                        }
                    }
                });
            }

            try {
                latch.await();
            } catch (InterruptedException ex) {
                Logger.getLogger(BorderlineSMOTE.class.getName()).log(Level.SEVERE, null, ex);
                // keep the interrupt visible to callers and stop: workers may still be
                // appending to the synthetic list, so it must not be used for training
                Thread.currentThread().interrupt();
                throw new FailedToFitException("Interrupted while generating synthetic samples");
            }

            // all workers have counted down, so any recorded failure is visible here
            if (workerFailure[0] != null) {
                throw workerFailure[0];
            }
        }

        final ClassificationDataSet newDataSet =
                new ClassificationDataSet(ListUtils.mergedView(synthetics, dataSet.getAsDPPList()), dataSet.getPredicting());
        this.baseClassifier.trainC(newDataSet, threadPool);
    }

    /**
     * @return a copy of this object created through the copy constructor
     */
    @Override
    public BorderlineSMOTE clone() {
        return new BorderlineSMOTE(this);
    }
}

