/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta.multisearch;

import java.io.File;
import java.io.Serializable;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import weka.classifiers.Classifier;
import weka.classifiers.meta.multisearch.AbstractMultiThreadedSearch;
import weka.classifiers.meta.multisearch.AbstractSearch;
import weka.classifiers.meta.multisearch.Performance;
import weka.classifiers.meta.multisearch.PerformanceComparator;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.converters.ConverterUtils;
import weka.core.setupgenerator.Point;
import weka.core.setupgenerator.Space;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.Resample;

/**
 * Search strategy that evaluates a randomly chosen subset of the points in the
 * parameter search space and returns the best-performing setup.
 * <p>
 * At most {@link #getNumIterations()} points (fewer if the space is smaller) are
 * taken from a shuffled copy of the space. Shuffling uses {@link #getRandomSeed()}.
 * Each point is either served from the owner's performance cache or submitted
 * as an evaluation task to the executor pool.
 */
public class RandomSearch
extends AbstractMultiThreadedSearch {
    private static final long serialVersionUID = 2542453917013899104L;
    /** Percentage of the training data to sample before searching (100 = use all data). */
    protected double m_SampleSize = 100.0;
    /** Number of CV folds used when evaluating the search space; values below 2 mean training-set evaluation. */
    protected int m_SearchSpaceNumFolds = 2;
    /** Optional test set for the search space; ignored when it does not exist or is a directory. */
    protected File m_SearchSpaceTestSet = new File(".");
    /** Test data loaded from {@link #m_SearchSpaceTestSet}, or {@code null} if none is used. */
    protected Instances m_SearchSpaceTestInst;
    /** Maximum number of points of the search space that get evaluated. */
    protected int m_NumIterations = 100;
    /** Seed used to shuffle the points of the search space. */
    protected int m_RandomSeed = 1;

    /**
     * Returns a description of this search strategy for display in the GUI.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "Performs a search of an arbitrary number of parameters of a classifier and chooses the best combination found for the actual filtering and training.\n";
    }

    /**
     * Lists the command-line options of this search, followed by those of the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Object> result = new Vector<Object>();
        result.addElement(new Option("\tThe size (in percent) of the sample to search the initial space with.\n\t(default: 100)", "sample-size", 1, "-sample-size <num>"));
        result.addElement(new Option("\tThe number of cross-validation folds for the search space.\n\tNumbers smaller than 2 turn off cross-validation and\n\tjust perform evaluation on the training set.\n\t(default: 2)", "num-folds", 1, "-num-folds <num>"));
        // loadTestData() skips the test set when it is a directory, not when it is a file.
        result.addElement(new Option("\tThe (optional) test set to use for the search space.\n\tGets ignored if pointing to a directory. Overrides cross-validation.\n\t(default: .)", "test-set", 1, "-test-set <filename>"));
        result.addElement(new Option("\tThe number of parameter settings that are tried (i.e., the number of points in the search space that are checked).\n\t(default: 100)", "num-iterations", 1, "-num-iterations <num>"));
        // The option name matches the flag parsed in setOptions ("S").
        result.addElement(new Option("\tThe random seed", "S", 1, "-S <num>"));
        Enumeration en = super.listOptions();
        while (en.hasMoreElements()) {
            result.addElement(en.nextElement());
        }
        return result.elements();
    }

    /**
     * Returns the current setup as command-line options, followed by the superclass options.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-sample-size");
        result.add("" + this.getSampleSizePercent());
        result.add("-num-folds");
        result.add("" + this.getSearchSpaceNumFolds());
        result.add("-test-set");
        result.add("" + this.getSearchSpaceTestSet());
        result.add("-num-iterations");
        result.add("" + this.getNumIterations());
        result.add("-S");
        result.add("" + this.getRandomSeed());
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses the given options. Missing options are reset to their defaults;
     * remaining options are passed on to the superclass.
     *
     * @param options the options to parse
     * @throws Exception if an option value cannot be parsed or the superclass rejects the options
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String tmpStr = Utils.getOption("sample-size", options);
        if (tmpStr.length() != 0) {
            this.setSampleSizePercent(Double.parseDouble(tmpStr));
        } else {
            this.setSampleSizePercent(100.0);
        }
        tmpStr = Utils.getOption("num-folds", options);
        if (tmpStr.length() != 0) {
            this.setSearchSpaceNumFolds(Integer.parseInt(tmpStr));
        } else {
            this.setSearchSpaceNumFolds(2);
        }
        tmpStr = Utils.getOption("test-set", options);
        if (tmpStr.length() != 0) {
            this.setSearchSpaceTestSet(new File(tmpStr));
        } else {
            this.setSearchSpaceTestSet(new File(System.getProperty("user.dir")));
        }
        tmpStr = Utils.getOption("num-iterations", options);
        if (tmpStr.length() != 0) {
            this.setNumIterations(Integer.parseInt(tmpStr));
        } else {
            this.setNumIterations(100);
        }
        tmpStr = Utils.getOption("S", options);
        if (tmpStr.length() != 0) {
            this.setRandomSeed(Integer.parseInt(tmpStr));
        } else {
            this.setRandomSeed(1);
        }
        super.setOptions(options);
    }

    /**
     * Returns the tip text for the sample size property.
     *
     * @return the tip text
     */
    public String sampleSizePercentTipText() {
        return "The sample size (in percent) to use in the search.";
    }

    /**
     * @return the sample size in percent
     */
    public double getSampleSizePercent() {
        return this.m_SampleSize;
    }

    /**
     * @param value the sample size in percent (100 means the full data is used)
     */
    public void setSampleSizePercent(double value) {
        this.m_SampleSize = value;
    }

    /**
     * Returns the tip text for the number of folds property.
     *
     * @return the tip text
     */
    public String searchSpaceNumFoldsTipText() {
        return "The number of cross-validation folds when evaluating the search space; values smaller than 2 turn cross-validation off and simple evaluation on the training set is performed.";
    }

    /**
     * @return the number of CV folds used for the search space
     */
    public int getSearchSpaceNumFolds() {
        return this.m_SearchSpaceNumFolds;
    }

    /**
     * @param value the number of CV folds; values below 2 mean evaluation on the training set
     */
    public void setSearchSpaceNumFolds(int value) {
        this.m_SearchSpaceNumFolds = value;
    }

    /**
     * Returns the tip text for the test set property.
     *
     * @return the tip text
     */
    public String searchSpaceTestSetTipText() {
        return "The (optional) test set to use for evaluating the search space; overrides cross-validation; gets ignored if pointing to a directory.";
    }

    /**
     * @return the (optional) test set file for the search space
     */
    public File getSearchSpaceTestSet() {
        return this.m_SearchSpaceTestSet;
    }

    /**
     * @param value the test set file; a directory or non-existent file disables the test set
     */
    public void setSearchSpaceTestSet(File value) {
        this.m_SearchSpaceTestSet = value;
    }

    /**
     * Returns the tip text for the number of iterations property.
     *
     * @return the tip text
     */
    public String numIterationsTipText() {
        return "The number of parameter settings that are tried (i.e., the number of points in the search space that are checked).";
    }

    /**
     * @return the maximum number of points of the search space to evaluate
     */
    public int getNumIterations() {
        return this.m_NumIterations;
    }

    /**
     * @param value the maximum number of points of the search space to evaluate
     */
    public void setNumIterations(int value) {
        this.m_NumIterations = value;
    }

    /**
     * Returns the tip text for the random seed property.
     *
     * @return the tip text
     */
    public String randomSeedTipText() {
        return "The seed used for randomization";
    }

    /**
     * @return the seed used for shuffling the search space
     */
    public int getRandomSeed() {
        return this.m_RandomSeed;
    }

    /**
     * @param value the seed used for shuffling the search space
     */
    public void setRandomSeed(int value) {
        this.m_RandomSeed = value;
    }

    /**
     * Evaluates a random subset of the points in {@code space} and returns the best performance.
     * <p>
     * The points are shuffled with {@code random}. The first
     * {@code min(space.size(), m_NumIterations)} of them are evaluated, either from
     * the cache or through evaluation tasks run on the executor pool.
     *
     * @param space     the search space
     * @param train     the training data
     * @param test      the optional test data (may be {@code null})
     * @param folds     number of CV folds; values below 2 mean evaluation on the training set
     * @param random    the random number generator used for shuffling the points
     * @param postClean whether to clear the collected performances before returning
     * @return the best performance according to the {@link PerformanceComparator}
     * @throws IllegalStateException if an evaluation task fails or no performance was collected
     * @throws Exception if creating an evaluation task fails
     */
    protected Performance determineBestInSpace(Space space, Instances train, Instances test, int folds, Random random, boolean postClean) throws Exception {
        this.m_Performances.clear();
        if (folds >= 2) {
            this.log("Determining best values with " + folds + "-fold CV in space:\n" + space + "\n");
        } else {
            this.log("Determining best values with evaluation on training set in space:\n" + space + "\n");
        }

        // Shuffle all points and evaluate only the leading m_NumSetups of them.
        ArrayList<Point<Object>> points = Collections.list(space.values());
        this.m_NumSetups = Math.min(space.size(), this.m_NumIterations);
        Collections.shuffle(points, random);

        int evalId = this.m_Owner.getEvaluation().getSelectedTag().getID();
        int classLabel = train.classAttribute().isNominal() ? this.m_Owner.getClassLabelIndex(train.classAttribute().numValues()) : -1;

        // tasks.get(i) is the task whose outcome is results.get(i); both lists are
        // filled together so that a failing future can be reported with its task.
        List<Callable<Boolean>> tasks = new ArrayList<Callable<Boolean>>();
        List<Future<Boolean>> results = new ArrayList<Future<Boolean>>();
        for (int i = 0; i < this.m_NumSetups; ++i) {
            Point<Object> values = points.get(i);
            if (this.m_Cache.isCached(folds, values)) {
                Performance performance = this.m_Cache.get(folds, values);
                this.m_Performances.add(performance);
                this.m_Trace.add(new AbstractMap.SimpleEntry<Integer, Performance>(folds, performance));
                this.log(performance + ": cached=true");
                continue;
            }
            Callable<Boolean> task = this.newEvaluationTask(train, test, values, folds, evalId, classLabel);
            tasks.add(task);
            results.add(this.m_ExecutorPool.submit(task));
        }
        awaitTasks(tasks, results);

        // NOTE(review): evaluation tasks are presumably expected to add their results
        // to m_Performances; this guards against an empty result set (e.g. numIterations <= 0).
        if (this.m_Performances.isEmpty()) {
            throw new IllegalStateException("No performances available for search space; number of setups evaluated: " + this.m_NumSetups);
        }

        Collections.sort(this.m_Performances, new PerformanceComparator(evalId, this.m_Owner.getMetrics()));
        Performance result = this.m_Performances.firstElement();

        // All performances are "uniform" if they equal the first one exactly on the selected measure.
        this.m_UniformPerformance = true;
        double first = result.getPerformance(evalId);
        for (int i = 1; i < this.m_Performances.size(); ++i) {
            if (this.m_Performances.get(i).getPerformance(evalId) != first) {
                this.m_UniformPerformance = false;
                break;
            }
        }
        if (this.m_UniformPerformance) {
            this.log("All performances are the same!");
        }
        this.logPerformances(space, this.m_Performances);
        this.log("\nBest performance:\n" + result);
        if (postClean) {
            this.m_Performances.clear();
        }
        return result;
    }

    /**
     * Creates the evaluation task for a single point of the search space.
     * The futures are unboxed as {@code Boolean}, so the factory's tasks are
     * treated as {@code Callable<Boolean>}.
     */
    @SuppressWarnings("unchecked")
    private Callable<Boolean> newEvaluationTask(Instances train, Instances test, Point<Object> values, int folds, int evalId, int classLabel) throws Exception {
        return (Callable<Boolean>) this.m_Owner.getFactory().newTask(this.m_Owner, train, test, this.m_Owner.getGenerator(), values, folds, evalId, classLabel);
    }

    /**
     * Waits for all submitted evaluation tasks and fails fast on the first unsuccessful one.
     *
     * @param tasks   the submitted tasks, index-aligned with {@code results}
     * @param results the futures of the submitted tasks
     * @throws IllegalStateException if a task fails, is interrupted, or reports failure
     */
    private static void awaitTasks(List<Callable<Boolean>> tasks, List<Future<Boolean>> results) {
        for (int i = 0; i < results.size(); ++i) {
            Boolean success;
            try {
                success = results.get(i).get();
            }
            catch (InterruptedException e) {
                // Restore the interrupt status so callers can still observe it.
                Thread.currentThread().interrupt();
                throw threadFailure(e);
            }
            catch (ExecutionException e) {
                throw threadFailure(e);
            }
            catch (RuntimeException e) {
                throw threadFailure(e);
            }
            if (!Boolean.TRUE.equals(success)) {
                System.err.println("Execution of evaluation thread failed:\n" + tasks.get(i));
                throw new IllegalStateException("Execution of evaluation thread failed:\n" + tasks.get(i));
            }
        }
    }

    /**
     * Reports a failure while waiting for evaluation tasks and wraps it.
     *
     * @param e the underlying exception
     * @return the exception to throw
     */
    private static IllegalStateException threadFailure(Exception e) {
        System.err.println("Thread-based execution of evaluation tasks failed!");
        e.printStackTrace();
        return new IllegalStateException("Thread-based execution of evaluation tasks failed!", e);
    }

    /**
     * Returns the performances collected during the last search (empty if they were cleared).
     *
     * @return the performances
     */
    public Vector<Performance> getPerformances() {
        return this.m_Performances;
    }

    /**
     * Runs the random search on the given data. When the sample size is below
     * 100%, the data is first resampled using the owner's seed. The classifier
     * built from the best point is logged.
     *
     * @param inst the training data
     * @return the best performance found
     * @throws Exception if resampling, evaluation, or classifier setup fails
     */
    protected Performance findBest(Instances inst) throws Exception {
        Instances sample;
        Random random = new Random(this.m_RandomSeed);
        this.log("Step 1:\n");
        if (this.getSampleSizePercent() == 100.0) {
            sample = inst;
        } else {
            this.log("Generating sample (" + this.getSampleSizePercent() + "%)");
            Resample resample = new Resample();
            resample.setRandomSeed(this.retrieveOwner().getSeed());
            resample.setSampleSizePercent(this.getSampleSizePercent());
            resample.setInputFormat(inst);
            sample = Filter.useFilter(inst, resample);
        }
        this.m_UniformPerformance = false;
        this.log("\n=== Search space - Start ===");
        Performance result = this.determineBestInSpace(this.m_Space, sample, this.m_SearchSpaceTestInst, this.m_SearchSpaceNumFolds, random, true);
        this.log("\nResult: " + result + "\n");
        this.log("=== Search space - End ===\n");
        Point<Object> evals = this.m_Owner.getGenerator().evaluate(result.getValues());
        Classifier cls = (Classifier) this.m_Owner.getGenerator().setup((Serializable) this.m_Owner.getClassifier(), evals);
        this.log("Classifier: " + this.getCommandline(cls));
        return result;
    }

    /**
     * Loads the optional search-space test set. It is used only when the
     * configured file exists and is not a directory. It receives the class index
     * of {@code data}, must have a compatible header, and has instances with a
     * missing class removed.
     *
     * @param data the training data whose header the test set must match
     * @throws IllegalArgumentException if the test set header is incompatible
     * @throws Exception if reading the test set fails
     */
    protected void loadTestData(Instances data) throws Exception {
        this.m_SearchSpaceTestInst = null;
        if (this.m_SearchSpaceTestSet.exists() && !this.m_SearchSpaceTestSet.isDirectory()) {
            this.m_SearchSpaceTestInst = ConverterUtils.DataSource.read(this.m_SearchSpaceTestSet.getAbsolutePath());
            this.m_SearchSpaceTestInst.setClassIndex(data.classIndex());
            String msg = data.equalHeadersMsg(this.m_SearchSpaceTestInst);
            if (msg != null) {
                throw new IllegalArgumentException("Test set for search space not compatible with training data:\n" + msg);
            }
            this.m_SearchSpaceTestInst.deleteWithMissingClass();
            this.log("Using test set for search space: " + this.m_SearchSpaceTestSet);
        }
    }

    /**
     * Performs the search on a copy of the data and builds the result holding
     * the configured classifier, its performance, and the evaluated parameter values.
     *
     * @param data the training data
     * @return the search result
     * @throws Exception if loading test data, searching, or classifier setup fails
     */
    @Override
    public AbstractSearch.SearchResult doSearch(Instances data) throws Exception {
        this.loadTestData(data);
        Performance performance = this.findBest(new Instances(data));
        Point<Object> evals = this.m_Owner.getGenerator().evaluate(performance.getValues());
        AbstractSearch.SearchResult result = new AbstractSearch.SearchResult();
        result.classifier = (Classifier) this.m_Owner.getGenerator().setup((Serializable) this.m_Owner.getClassifier(), evals);
        result.performance = performance;
        result.values = evals;
        return result;
    }
}

