/*
 * Decompiled with CFR 0.152.
 */
package weka.knowledgeflow.steps;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.WekaException;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.steps.BaseStep;
import weka.knowledgeflow.steps.KFStep;

/**
 * Knowledge Flow step that turns one incoming set of instances into a series
 * of cross-validation training/test pairs.
 *
 * <p>The number of folds and the random seed are configured as strings so that
 * they can contain environment variables; they are resolved and parsed in
 * {@link #stepInit()}. Unless order preservation is requested, the data is
 * shuffled and, when the class attribute is nominal, stratified before the
 * folds are produced.
 */
@KFStep(name="CrossValidationFoldMaker", category="Evaluation", toolTipText="A Step that creates stratified cross-validation folds from incoming data", iconPath="weka/gui/knowledgeflow/icons/CrossValidationFoldMaker.gif")
public class CrossValidationFoldMaker
extends BaseStep {
    private static final long serialVersionUID = 6090713408437825355L;

    /** Connection type (and payload key) for a plain data set. */
    private static final String CON_DATASET = "dataSet";

    /** Connection type (and payload key) for a training set. */
    private static final String CON_TRAINING_SET = "trainingSet";

    /** Connection type (and payload key) for a test set. */
    private static final String CON_TEST_SET = "testSet";

    /** Payload key holding the 1-based number of the fold. */
    private static final String AUX_SET_NUM = "aux_set_num";

    /** Payload key holding the total number of folds. */
    private static final String AUX_MAX_SET_NUM = "aux_max_set_num";

    /** Smallest fold count for which a train/test split can be produced. */
    private static final int MIN_FOLDS = 2;

    /** True if the incoming order of the instances is to be kept (no shuffling/stratifying). */
    protected boolean m_preserveOrder;

    /** User-supplied number of folds; may contain environment variables. */
    protected String m_numFoldsS = "10";

    /** User-supplied random seed; may contain environment variables. */
    protected String m_seedS = "1";

    /** Number of folds actually used, resolved in {@link #stepInit()}. */
    protected int m_numFolds = 10;

    /** Random seed actually used, resolved in {@link #stepInit()}. */
    protected long m_seed = 1L;

    /**
     * Set the number of folds to create.
     *
     * @param folds the number of folds, as a string (may contain environment variables)
     */
    @OptionMetadata(displayName="Number of folds", description="The number of folds to create", displayOrder=0)
    public void setNumFolds(String folds) {
        this.m_numFoldsS = folds;
    }

    /**
     * Get the number of folds to create.
     *
     * @return the number of folds exactly as supplied to {@link #setNumFolds(String)}
     */
    public String getNumFolds() {
        return this.m_numFoldsS;
    }

    /**
     * Tip text for the preserve order option.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "Preserve the order of the training data (rather than randomly shuffling) before creating folds";
    }

    /**
     * Set whether the order of the incoming instances is to be preserved.
     *
     * @param preserve true to skip shuffling (and stratification) before creating folds
     */
    @OptionMetadata(displayName="Preserve instances order", description="Preserve the order of instances rather than randomly shuffling", displayOrder=1)
    public void setPreserveOrder(boolean preserve) {
        this.m_preserveOrder = preserve;
    }

    /**
     * Get whether the order of the incoming instances is to be preserved.
     *
     * @return true if the data is not shuffled before folds are created
     */
    public boolean getPreserveOrder() {
        return this.m_preserveOrder;
    }

    /**
     * Set the random seed used for shuffling.
     *
     * @param seed the seed, as a string (may contain environment variables)
     */
    @OptionMetadata(displayName="Random seed", description="The random seed to use for shuffling", displayOrder=3)
    public void setSeed(String seed) {
        this.m_seedS = seed;
    }

    /**
     * Get the random seed used for shuffling.
     *
     * @return the seed exactly as supplied to {@link #setSeed(String)}
     */
    public String getSeed() {
        return this.m_seedS;
    }

    /**
     * Resolve environment variables in the seed and fold settings and parse
     * them. A value that cannot be parsed is reported as a warning and the
     * previously held value of the corresponding field is kept.
     *
     * @throws WekaException declared by the step contract
     */
    @Override
    public void stepInit() throws WekaException {
        String seed = this.getStepManager().environmentSubstitute(this.getSeed());
        try {
            this.m_seed = Long.parseLong(seed);
        }
        catch (NumberFormatException ex) {
            // Deliberately non-fatal: keep whatever seed is currently held.
            this.getStepManager().logWarning("Unable to parse seed value: " + seed);
        }
        String folds = this.getStepManager().environmentSubstitute(this.getNumFolds());
        try {
            this.m_numFolds = Integer.parseInt(folds);
        }
        catch (NumberFormatException e) {
            // Deliberately non-fatal: keep whatever fold count is currently held.
            this.getStepManager().logWarning("Unable to parse number of folds value: " + folds);
        }
    }

    /**
     * Split the incoming instances into folds and output one training/test
     * pair per fold. Processing ends early, without emitting further folds,
     * once a stop has been requested.
     *
     * @param data the incoming data; its payload is looked up under its own connection name
     * @throws WekaException if the payload is null, if fewer than two folds are
     *           configured, or if there are more folds than instances
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        this.getStepManager().processing();
        String incomingConnName = data.getConnectionName();
        Instances dataSet = (Instances)data.getPayloadElement(incomingConnName);
        if (dataSet == null) {
            throw new WekaException("Incoming instances should not be null!");
        }
        int numFolds = this.m_numFolds;
        // Reject unusable fold counts up front with the step's own exception
        // type: a count below one would otherwise produce no output at all.
        if (numFolds < MIN_FOLDS) {
            throw new WekaException("Number of folds must be at least " + MIN_FOLDS + ", but was " + numFolds);
        }
        if (numFolds > dataSet.numInstances()) {
            throw new WekaException("Can't create " + numFolds + " folds from only " + dataSet.numInstances() + " instances");
        }

        // Work on a copy so the caller's data set is never reordered.
        dataSet = new Instances(dataSet);
        this.getStepManager().logBasic("Creating cross-validation folds");
        this.getStepManager().statusMessage("Creating cross-validation folds");

        boolean preserveOrder = this.getPreserveOrder();
        Random random = new Random(this.m_seed);
        if (!preserveOrder) {
            dataSet.randomize(random);
            if (dataSet.classIndex() >= 0 && dataSet.attribute(dataSet.classIndex()).isNominal()) {
                this.getStepManager().logBasic("Stratifying data");
                dataSet.stratify(numFolds);
            }
        }

        for (int i = 0; i < numFolds; ++i) {
            if (this.isStopRequested()) {
                break;
            }
            Instances train = preserveOrder ? dataSet.trainCV(numFolds, i) : dataSet.trainCV(numFolds, i, random);
            Instances test = dataSet.testCV(numFolds, i);
            Data trainData = this.makeFoldData(CON_TRAINING_SET, train, i + 1, numFolds);
            Data testData = this.makeFoldData(CON_TEST_SET, test, i + 1, numFolds);
            // Re-check: building the fold may take a while on large data.
            if (this.isStopRequested()) {
                break;
            }
            this.getStepManager().outputData(trainData, testData);
        }
        this.getStepManager().finished();
    }

    /**
     * Wrap one fold in a Data object, tagged with its fold number and the total number of folds.
     *
     * @param connectionName the connection type, also used as the payload key for the instances
     * @param fold the instances making up this part of the fold
     * @param setNum the 1-based number of the fold
     * @param maxSetNum the total number of folds
     * @return the populated Data object
     */
    private Data makeFoldData(String connectionName, Instances fold, int setNum, int maxSetNum) {
        Data result = new Data(connectionName);
        result.setPayloadElement(connectionName, fold);
        result.setPayloadElement(AUX_SET_NUM, setNum);
        result.setPayloadElement(AUX_MAX_SET_NUM, maxSetNum);
        return result;
    }

    /**
     * Get the connection types this step can accept right now.
     *
     * @return an empty list if there is already an incoming connection,
     *         otherwise dataSet, trainingSet and testSet
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        if (this.getStepManager().numIncomingConnections() > 0) {
            return new ArrayList<String>();
        }
        return Arrays.asList(CON_DATASET, CON_TRAINING_SET, CON_TEST_SET);
    }

    /**
     * Get the connection types this step can produce right now.
     *
     * @return trainingSet and testSet if there is an incoming connection,
     *         otherwise an empty list
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        if (this.getStepManager().numIncomingConnections() > 0) {
            return Arrays.asList(CON_TRAINING_SET, CON_TEST_SET);
        }
        return new ArrayList<String>();
    }

    /**
     * Get the structure of the data output for the given connection type,
     * which is the structure of whatever is coming in.
     *
     * @param connectionName the outgoing connection type of interest
     * @return the first non-null incoming structure found when checking the
     *         dataSet, testSet and trainingSet connections in that order, or
     *         null if the connection type is not produced by this step, there
     *         are no incoming connections, or no structure is available
     * @throws WekaException if a problem occurs while obtaining the incoming structure
     */
    @Override
    public Instances outputStructureForConnectionType(String connectionName) throws WekaException {
        boolean producedHere = connectionName.equals(CON_TRAINING_SET) || connectionName.equals(CON_TEST_SET);
        if (!producedHere || this.getStepManager().numIncomingConnections() == 0) {
            return null;
        }
        for (String incomingType : new String[] {CON_DATASET, CON_TEST_SET, CON_TRAINING_SET}) {
            Instances structure = this.getStepManager().getIncomingStructureForConnectionType(incomingType);
            if (structure != null) {
                return structure;
            }
        }
        return null;
    }
}

