/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.datatransform.featureselection.SFS;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * Sequential Backward Selection (SBS) feature selector.
 * <p>
 * The search starts with every feature selected and greedily removes one feature
 * per step. Each step removes the feature whose removal gives the lowest score from
 * {@link SFS#getScore} (lower is better), using the configured number of
 * cross-validation folds.
 * <p>
 * Removal continues while more than {@code minFeatures} features remain, and either:
 * <ul>
 * <li>more than {@code maxFeatures} features remain, or</li>
 * <li>the best score after removal is worse than the previous best by less than
 * {@code maxDecrease}.</li>
 * </ul>
 * The surviving features are kept; all others are removed by this transform.
 */
public class SBS extends RemoveAttributeTransform {
    private static final long serialVersionUID = -2516121100148559742L;
    /** Largest tolerated score increase for a removal once at most maxFeatures remain. */
    private double maxDecrease;
    /** Number of cross-validation folds used to score each candidate subset. */
    private int folds;
    /** Removal never goes below this many selected features. */
    private int minFeatures;
    /** Removal is forced while more than this many features are selected. */
    private int maxFeatures;
    /** Learner used to score candidate subsets; a {@link Classifier} or a {@link Regressor}. */
    private Object evaluator;

    /**
     * Copy constructor used by {@link #clone()}.
     * NOTE(review): the evaluator is shared by reference with {@code toClone}, not cloned.
     *
     * @param toClone the instance to copy
     */
    private SBS(SBS toClone) {
        super(toClone);
        this.maxDecrease = toClone.maxDecrease;
        this.folds = toClone.folds;
        this.minFeatures = toClone.minFeatures;
        this.maxFeatures = toClone.maxFeatures;
        this.evaluator = toClone.evaluator;
    }

    /**
     * Creates an unfitted SBS selector for classification problems using 3 CV folds.
     *
     * @param minFeatures minimum number of features to keep
     * @param maxFeatures maximum number of features to keep
     * @param evaluater   classifier used to score feature subsets
     * @param maxDecrease largest tolerated score increase when removing a feature
     */
    public SBS(int minFeatures, int maxFeatures, Classifier evaluater, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluater, 3, maxDecrease);
    }

    /** Shared initialisation for all public constructors; validates through the setters. */
    private SBS(int minFeatures, int maxFeatures, Object evaluater, int folds, double maxDecrease) {
        this.setMaxDecrease(maxDecrease);
        this.setMinFeatures(minFeatures);
        this.setMaxFeatures(maxFeatures);
        this.setEvaluator(evaluater);
        this.setFolds(folds);
    }

    /**
     * Creates an SBS selector for classification and immediately runs the search on {@code cds}.
     *
     * @param minFeatures minimum number of features to keep
     * @param maxFeatures maximum number of features to keep
     * @param cds         data set to select features from
     * @param evaluater   classifier used to score feature subsets
     * @param folds       number of cross-validation folds (must be positive)
     * @param maxDecrease largest tolerated score increase when removing a feature
     */
    public SBS(int minFeatures, int maxFeatures, ClassificationDataSet cds, Classifier evaluater, int folds, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluater, folds, maxDecrease);
        this.search(cds, evaluater, minFeatures, maxFeatures, folds);
    }

    /**
     * Creates an unfitted SBS selector for regression problems using 3 CV folds.
     *
     * @param minFeatures minimum number of features to keep
     * @param maxFeatures maximum number of features to keep
     * @param evaluater   regressor used to score feature subsets
     * @param maxDecrease largest tolerated score increase when removing a feature
     */
    public SBS(int minFeatures, int maxFeatures, Regressor evaluater, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluater, 3, maxDecrease);
    }

    /**
     * Creates an SBS selector for regression and immediately runs the search on {@code rds}.
     *
     * @param minFeatures minimum number of features to keep
     * @param maxFeatures maximum number of features to keep
     * @param rds         data set to select features from
     * @param evaluater   regressor used to score feature subsets
     * @param folds       number of cross-validation folds (must be positive)
     * @param maxDecrease largest tolerated score increase when removing a feature
     */
    public SBS(int minFeatures, int maxFeatures, RegressionDataSet rds, Regressor evaluater, int folds, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluater, folds, maxDecrease);
        this.search(rds, evaluater, minFeatures, maxFeatures, folds);
    }

    /**
     * Runs the backward search on {@code data} using this selector's current settings.
     *
     * @param data the data set to select features from
     */
    @Override
    public void fit(DataSet data) {
        super.fit(data);
        this.search(data, this.evaluator, this.minFeatures, this.maxFeatures, this.folds);
    }

    /**
     * Performs the backward elimination and stores the surviving features.
     * <p>
     * The surviving categorical indices go into {@code catIndexMap} and the surviving
     * numerical indices into {@code numIndexMap}, each sorted ascending.
     * Overall feature indices are {@code 0..nF-1}. Presumably the first {@code nCat}
     * are categorical and the rest numerical, matching how the selected sets are
     * seeded below.
     */
    private void search(DataSet dataSet, Object learner, int minFeatures, int maxFeatures, int folds) {
        Random rand = RandomUtil.getRandom();
        int nF = dataSet.getNumFeatures();
        int nCat = dataSet.getNumCategoricalVars();

        // Candidate features that may still be removed (overall indices).
        IntSet available = new IntSet();
        ListUtils.addRange(available, 0, nF, 1);

        IntSet catSelected = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numSelected = new IntSet(dataSet.getNumNumericalVars());
        IntSet catToRemove = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numToRemove = new IntSet(dataSet.getNumNumericalVars());

        // Start with everything selected and prune features out one at a time.
        ListUtils.addRange(catSelected, 0, nCat, 1);
        ListUtils.addRange(numSelected, 0, nF - nCat, 1);

        // Single-element array so SBSRemoveFeature can update the running best score.
        double[] bestScore = {Double.POSITIVE_INFINITY};

        while (catSelected.size() + numSelected.size() > minFeatures) {
            int removed = SBSRemoveFeature(available, dataSet, catToRemove, numToRemove,
                    catSelected, numSelected, learner, folds, rand, maxFeatures, bestScore, this.maxDecrease);
            if (removed < 0) {
                break; // no further removal is acceptable
            }
        }

        this.catIndexMap = toSortedIndexArray(catSelected);
        this.numIndexMap = toSortedIndexArray(numSelected);
    }

    /** Copies the integers in {@code values} into a new array sorted in ascending order. */
    private static int[] toSortedIndexArray(Set<Integer> values) {
        int[] out = new int[values.size()];
        int pos = 0;
        for (int v : values) {
            out[pos++] = v;
        }
        Arrays.sort(out);
        return out;
    }

    @Override
    public SBS clone() {
        return new SBS(this);
    }

    /**
     * Returns the categorical features kept by the last search.
     *
     * @return a new set built from the selected categorical indices
     */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(IntList.view(this.catIndexMap, this.catIndexMap.length));
    }

    /**
     * Returns the numerical features kept by the last search.
     *
     * @return a new set built from the selected numerical indices
     */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(IntList.view(this.numIndexMap, this.numIndexMap.length));
    }

    /**
     * Attempts one backward-elimination step.
     * <p>
     * Every feature in {@code available} is tentatively added to the removal sets and
     * the remaining data is scored with {@link SFS#getScore}. The candidate with the
     * lowest score is chosen. It is actually removed only if either:
     * <ul>
     * <li>more than {@code maxFeatures} features are currently selected, or</li>
     * <li>its score is worse than {@code PbestScore[0]} by less than {@code maxDecrease}.</li>
     * </ul>
     *
     * @param available    features that may still be removed; the removed one is taken out
     * @param dataSet      the full data set
     * @param catToRemove  categorical features already removed; updated on success
     * @param numToRemove  numerical features already removed; updated on success
     * @param catSelected  categorical features still selected; updated on success
     * @param numSelected  numerical features still selected; updated on success
     * @param evaluater    the Classifier or Regressor used for scoring
     * @param folds        number of cross-validation folds
     * @param rand         source of randomness for cross-validation
     * @param maxFeatures  removal is forced while more features than this are selected
     * @param PbestScore   single-element holder for the running best score; updated on success
     * @param maxDecrease  largest tolerated score increase for an optional removal
     * @return the overall index of the removed feature, or -1 if nothing was removed
     */
    protected static int SBSRemoveFeature(Set<Integer> available, DataSet dataSet, Set<Integer> catToRemove, Set<Integer> numToRemove, Set<Integer> catSelected, Set<Integer> numSelected, Object evaluater, int folds, Random rand, int maxFeatures, double[] PbestScore, double maxDecrease) {
        int curBest = -1;
        int nCat = dataSet.getNumCategoricalVars();
        double curBestScore = Double.POSITIVE_INFINITY;
        for (int feature : available) {
            DataSet workOn = dataSet.shallowClone();
            // Tentatively drop the candidate and score what is left.
            SFS.addFeature(feature, nCat, catToRemove, numToRemove);
            RemoveAttributeTransform remove = new RemoveAttributeTransform(workOn, catToRemove, numToRemove);
            workOn.applyTransform(remove);
            double score = SFS.getScore(workOn, evaluater, folds, rand);
            if (score < curBestScore) {
                curBestScore = score;
                curBest = feature;
            }
            // Undo the tentative removal before trying the next candidate.
            SFS.removeFeature(feature, nCat, catToRemove, numToRemove);
        }
        if (curBest < 0) {
            // No candidate was chosen: either nothing was available, or no score was below
            // +Infinity (e.g. every score was NaN). Return without touching caller state,
            // rather than feeding the -1 sentinel into the selection/removal sets.
            return -1;
        }
        if (catSelected.size() + numSelected.size() > maxFeatures || PbestScore[0] - curBestScore > -maxDecrease) {
            PbestScore[0] = curBestScore;
            SFS.removeFeature(curBest, nCat, catSelected, numSelected);
            SFS.addFeature(curBest, nCat, catToRemove, numToRemove);
            available.remove(curBest);
            return curBest;
        }
        return -1; // no acceptable removal, and we already have few enough features
    }

    /**
     * Sets the largest score increase tolerated when removing a feature that is not
     * forced by {@code maxFeatures}.
     *
     * @param maxDecrease a non-negative value (may be +Infinity)
     * @throws IllegalArgumentException if {@code maxDecrease} is negative or NaN
     */
    public void setMaxDecrease(double maxDecrease) {
        // Written as !(x >= 0) so that NaN is rejected along with negative values.
        if (!(maxDecrease >= 0.0)) {
            throw new IllegalArgumentException("Decrease must be a non-negative value, not " + maxDecrease);
        }
        this.maxDecrease = maxDecrease;
    }

    /** @return the largest tolerated score increase for an optional removal */
    public double getMaxDecrease() {
        return this.maxDecrease;
    }

    /** @param minFeatures the search stops once this many features remain */
    public void setMinFeatures(int minFeatures) {
        this.minFeatures = minFeatures;
    }

    /** @return the minimum number of features to keep */
    public int getMinFeatures() {
        return this.minFeatures;
    }

    /** @param maxFeatures removal is forced while more than this many features remain */
    public void setMaxFeatures(int maxFeatures) {
        this.maxFeatures = maxFeatures;
    }

    /** @return the maximum number of features to keep */
    public int getMaxFeatures() {
        return this.maxFeatures;
    }

    /**
     * Sets the number of cross-validation folds used to score subsets.
     *
     * @param folds a positive fold count
     * @throws IllegalArgumentException if {@code folds} is not positive
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    /** @return the number of cross-validation folds */
    public int getFolds() {
        return this.folds;
    }

    /** Sets the learner (Classifier or Regressor) used to score feature subsets. */
    private void setEvaluator(Object evaluator) {
        this.evaluator = evaluator;
    }
}

