/*
 * Decompiled with CFR 0.152.
 */
package weka.knowledgeflow.steps;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import weka.core.Instances;
import weka.core.WekaException;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.steps.BaseStep;
import weka.knowledgeflow.steps.KFStep;

/**
 * Knowledge Flow step that randomly splits an incoming set of instances into
 * a training set and a test set.
 *
 * <p>The train percentage and the random seed are held as strings so that
 * they can be passed through environment variable substitution; they are
 * resolved and parsed in {@link #stepInit()}.
 */
@KFStep(name="TrainTestSplitMaker", category="Evaluation", toolTipText="Create a random train/test split", iconPath="weka/gui/knowledgeflow/icons/TrainTestSplitMaker.gif")
public class TrainTestSplitMaker
extends BaseStep {
    private static final long serialVersionUID = 7685026723199727685L;

    /** Train percentage as entered by the user (subject to environment substitution). */
    protected String m_trainPercentageS = "66";

    /** Random seed as entered by the user (subject to environment substitution). */
    protected String m_seedS = "1";

    /** Parsed train percentage; updated by {@link #stepInit()} when parsing succeeds. */
    protected double m_trainPercentage = 66.0;

    /** Parsed random seed; updated by {@link #stepInit()} when parsing succeeds. */
    protected long m_seed = 1L;

    /**
     * Returns a short description of this step.
     *
     * @return the description text
     */
    @Override
    public String globalInfo() {
        return "A step that randomly splits incoming data into a training and test set";
    }

    /**
     * Returns the tip text for the train percentage property.
     *
     * @return the tip text
     */
    public String trainPercentTipText() {
        return "The percentage of data to go into the training set";
    }

    /**
     * Sets the (unparsed) percentage of data to go into the training set.
     *
     * @param percent the percentage as a string
     */
    public void setTrainPercent(String percent) {
        this.m_trainPercentageS = percent;
    }

    /**
     * Gets the (unparsed) percentage of data to go into the training set.
     *
     * @return the percentage as a string
     */
    public String getTrainPercent() {
        return this.m_trainPercentageS;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The randomization seed";
    }

    /**
     * Sets the (unparsed) randomization seed.
     *
     * @param seed the seed as a string
     */
    public void setSeed(String seed) {
        this.m_seedS = seed;
    }

    /**
     * Gets the (unparsed) randomization seed.
     *
     * @return the seed as a string
     */
    public String getSeed() {
        return this.m_seedS;
    }

    /**
     * Initializes the step by resolving environment variables in the seed and
     * train percentage settings and parsing the results. If a value cannot be
     * parsed, a warning is logged and the previously held numeric value is
     * left unchanged.
     *
     * @throws WekaException declared by the overridden method
     */
    @Override
    public void stepInit() throws WekaException {
        String seed = this.getStepManager().environmentSubstitute(this.getSeed());
        try {
            this.m_seed = Long.parseLong(seed);
        }
        catch (NumberFormatException ex) {
            this.getStepManager().logWarning("Unable to parse seed value: " + seed);
        }
        String tP = this.getStepManager().environmentSubstitute(this.getTrainPercent());
        try {
            this.m_trainPercentage = Double.parseDouble(tP);
        }
        catch (NumberFormatException ex) {
            this.getStepManager().logWarning("Unable to parse train percentage value: " + tP);
        }
    }

    /**
     * Processes an incoming set of instances: randomizes it using the
     * configured seed, copies the first {@code trainPercentage}% (rounded) of
     * the instances into a training set and the remainder into a test set,
     * and outputs both unless a stop has been requested.
     *
     * @param data the incoming data; the instances are read from the payload
     *          element keyed by the data's connection name
     * @throws WekaException if the incoming instances are null or if the
     *           configured train percentage yields a split size outside
     *           {@code [0, numInstances]}
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        this.getStepManager().processing();
        String incomingConnName = data.getConnectionName();
        Instances dataSet = (Instances)data.getPayloadElement(incomingConnName);
        if (dataSet == null) {
            throw new WekaException("Incoming instances should not be null!");
        }
        this.getStepManager().logBasic("Creating train/test split");
        this.getStepManager().statusMessage("Creating train/test split");

        int numInstances = dataSet.numInstances();
        int trainSize = (int)Math.round((double)numInstances * this.m_trainPercentage / 100.0);
        // Reject an impossible split before touching the incoming data, and
        // report it as a WekaException rather than letting the Instances
        // copy constructor fail on out-of-range indices.
        if (trainSize < 0 || trainSize > numInstances) {
            throw new WekaException("Train percentage " + this.m_trainPercentage
                + " results in an invalid training set size (" + trainSize
                + ") for " + numInstances + " instances");
        }
        int testSize = numInstances - trainSize;

        // Note: randomize is invoked on the incoming Instances object itself;
        // no copy is made here before shuffling.
        dataSet.randomize(new Random(this.m_seed));
        Instances train = new Instances(dataSet, 0, trainSize);
        Instances test = new Instances(dataSet, trainSize, testSize);

        // A single split is produced, so both set-number payload entries are 1.
        Data trainData = new Data("trainingSet");
        trainData.setPayloadElement("trainingSet", train);
        trainData.setPayloadElement("aux_set_num", 1);
        trainData.setPayloadElement("aux_max_set_num", 1);
        Data testData = new Data("testSet");
        testData.setPayloadElement("testSet", test);
        testData.setPayloadElement("aux_set_num", 1);
        testData.setPayloadElement("aux_max_set_num", 1);
        if (!this.isStopRequested()) {
            this.getStepManager().outputData(trainData, testData);
        }
        this.getStepManager().finished();
    }

    /**
     * Returns the incoming connection types this step can accept right now.
     * Once the step has an incoming connection, no further ones are accepted.
     *
     * @return an empty list if an incoming connection already exists,
     *         otherwise the list "dataSet", "trainingSet", "testSet"
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        if (this.getStepManager().numIncomingConnections() > 0) {
            return new ArrayList<String>();
        }
        return Arrays.asList("dataSet", "trainingSet", "testSet");
    }

    /**
     * Returns the outgoing connection types this step can produce right now.
     *
     * @return "trainingSet" and "testSet" if there is at least one incoming
     *         connection, otherwise an empty list
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        return this.getStepManager().numIncomingConnections() > 0 ? Arrays.asList("trainingSet", "testSet") : new ArrayList<String>();
    }

    /**
     * Returns the structure of the data that would be output for the given
     * outgoing connection type. The structure is taken from the first
     * incoming connection type, checked in the order "dataSet", "testSet",
     * "trainingSet", for which the step manager reports a non-null structure.
     *
     * @param connectionName the outgoing connection type
     * @return the incoming structure, or null if the connection type is
     *         neither "trainingSet" nor "testSet", if there are no incoming
     *         connections, or if no incoming structure is available
     * @throws WekaException if a problem occurs while obtaining the structure
     */
    @Override
    public Instances outputStructureForConnectionType(String connectionName) throws WekaException {
        boolean supportedOutput = connectionName.equals("trainingSet") || connectionName.equals("testSet");
        if (!supportedOutput || this.getStepManager().numIncomingConnections() == 0) {
            return null;
        }
        // Looping over the candidate types (instead of three copy-pasted
        // lookups) guarantees that each null check tests the structure that
        // was just fetched, so a "trainingSet" input is honoured as well.
        String[] incomingTypes = {"dataSet", "testSet", "trainingSet"};
        for (String incomingType : incomingTypes) {
            Instances structure = this.getStepManager().getIncomingStructureForConnectionType(incomingType);
            if (structure != null) {
                return structure;
            }
        }
        return null;
    }
}

