/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer.preparefilebaseddataset;

import adams.core.QuickInfoHelper;
import adams.core.option.OptionHandler;
import adams.data.binning.Binnable;
import adams.data.binning.operation.Wrapping;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.PassThrough;
import adams.data.splitgenerator.generic.randomsplit.RandomSplitGenerator;
import adams.data.splitgenerator.generic.randomsplit.SplitPair;
import adams.data.splitgenerator.generic.splitter.DefaultSplitter;
import adams.flow.container.FileBasedDatasetContainer;
import adams.flow.transformer.preparefilebaseddataset.AbstractRandomizableFileBasedDatasetPreparation;
import com.github.fracpete.javautils.struct.Struct2;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Splits an array of file names into a train, a validation and a test set.
 * The training set is split off first. The validation set is then split off
 * from the remainder, and whatever is left becomes the test set.
 * Unless order preservation is enabled, the data is randomized with the
 * configured seed.
 */
public class TrainValidateTestSplit
extends AbstractRandomizableFileBasedDatasetPreparation<String[]> {
    private static final long serialVersionUID = 7027794624748574933L;

    /** Fraction (0-1) of the full dataset to use for training. */
    protected double m_TrainPercentage;

    /** Fraction (0-1) of the full dataset to use for validation. */
    protected double m_ValidatePercentage;

    /** Whether to skip randomization and keep the input order. */
    protected boolean m_PreserveOrder;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    @Override
    public String globalInfo() {
        return "Generates a train/validate/test split.\nAfter training and validation set have been split off, the remainder is used for the test set.";
    }

    /**
     * Adds the train/validate percentage options (each bounded to 0..1) and
     * the preserve-order flag.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("train-percentage", "trainPercentage", 0.8, 0.0, 1.0);
        this.m_OptionManager.add("validate-percentage", "validatePercentage", 0.1, 0.0, 1.0);
        this.m_OptionManager.add("preserve-order", "preserveOrder", false);
    }

    /**
     * Sets the fraction of the data to use for training.
     * The value is only accepted if the option manager validates it against
     * the bounds registered for the {@code trainPercentage} option.
     *
     * @param value the fraction (0-1)
     */
    public void setTrainPercentage(double value) {
        // must match the property name registered in defineOptions()
        if (this.getOptionManager().isValid("trainPercentage", value)) {
            this.m_TrainPercentage = value;
            this.reset();
        }
    }

    /**
     * Returns the fraction of the data to use for training.
     *
     * @return the fraction (0-1)
     */
    public double getTrainPercentage() {
        return this.m_TrainPercentage;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String trainPercentageTipText() {
        return "The percentage of the data to use for the training set.";
    }

    /**
     * Sets the fraction of the full dataset to use for validation.
     * The value is only accepted if the option manager validates it against
     * the bounds registered for the {@code validatePercentage} option.
     *
     * @param value the fraction (0-1)
     */
    public void setValidatePercentage(double value) {
        // must match the property name registered in defineOptions()
        if (this.getOptionManager().isValid("validatePercentage", value)) {
            this.m_ValidatePercentage = value;
            this.reset();
        }
    }

    /**
     * Returns the fraction of the full dataset to use for validation.
     *
     * @return the fraction (0-1)
     */
    public double getValidatePercentage() {
        return this.m_ValidatePercentage;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String validatePercentageTipText() {
        return "The percentage of the data to use for the validation set.";
    }

    /**
     * Sets whether to keep the input order instead of randomizing.
     *
     * @param value true to preserve the order
     */
    public void setPreserveOrder(boolean value) {
        this.m_PreserveOrder = value;
        this.reset();
    }

    /**
     * Returns whether the input order is kept instead of randomizing.
     *
     * @return true if the order is preserved
     */
    public boolean getPreserveOrder() {
        return this.m_PreserveOrder;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String preserveOrderTipText() {
        return "If enabled, the data doesn't get randomized.";
    }

    /**
     * Returns a short summary of the current setup.
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        String result = QuickInfoHelper.toString(this, "trainPercentage", this.m_TrainPercentage, "train: ");
        result += QuickInfoHelper.toString(this, "validatePercentage", this.m_ValidatePercentage, ", val: ");
        result += QuickInfoHelper.toString(this, "preserveOrder", this.m_PreserveOrder, "preserve", ", ");
        return result;
    }

    /**
     * Returns the type of data this preparation accepts.
     *
     * @return the accepted class: an array of file names
     */
    @Override
    public Class accepts() {
        return String[].class;
    }

    /**
     * Checks the data and the setup before preparation.
     * The checks run in order and the first failure wins:
     * the superclass checks, then at least three files, then
     * train + validate &lt; 1.0.
     *
     * @param data the file names to check
     * @return null if OK, otherwise the error message
     */
    @Override
    protected String check(String[] data) {
        String result = super.check(data);
        if (result == null && data.length < 3) {
            result = "At least three files required, provided: " + data.length;
        }
        // only report the percentage problem if nothing more fundamental failed
        if (result == null && this.m_TrainPercentage + this.m_ValidatePercentage >= 1.0) {
            result = "The sum of percentages for train and validate must be < 1.0: " + this.m_TrainPercentage + " (train) + " + this.m_ValidatePercentage + " (validate) = " + (this.m_TrainPercentage + this.m_ValidatePercentage);
        }
        return result;
    }

    /**
     * Splits the file names into two parts using the given fraction for the
     * first part. The data is randomized with the configured seed unless
     * order preservation is enabled.
     *
     * @param data the file names to split
     * @param percentage the fraction (0-1) for the first part
     * @return the first part (value1) and the remainder (value2)
     * @throws IllegalStateException if wrapping the file names fails
     */
    protected Struct2<String[], String[]> split(String[] data, double percentage) {
        RandomSplitGenerator generator = new RandomSplitGenerator();
        DefaultSplitter splitter = new DefaultSplitter();
        splitter.setPercentage(percentage);
        generator.setSplitter(splitter);
        if (this.m_PreserveOrder) {
            generator.setRandomization(new PassThrough());
        } else {
            DefaultRandomization randomization = new DefaultRandomization();
            randomization.setSeed(this.m_Seed);
            randomization.setLoggingLevel(this.m_LoggingLevel);
            generator.setRandomization(randomization);
        }

        List<Binnable<String>> binnable;
        try {
            binnable = Wrapping.wrap(Arrays.asList(data), new Wrapping.IndexedBinValueExtractor<>());
        } catch (Exception e) {
            // keep the cause so the underlying failure is not lost
            throw new IllegalStateException("Failed to wrap file names in Binnable objects!", e);
        }

        SplitPair<Binnable<String>> pair = generator.generate(binnable);
        String[] train = Wrapping.unwrap(pair.getTrain().getData()).toArray(new String[0]);
        String[] test = Wrapping.unwrap(pair.getTest().getData()).toArray(new String[0]);
        return new Struct2<>(train, test);
    }

    /**
     * Performs the train/validate/test split.
     *
     * @param data the file names to split
     * @return a list containing a single container with the train, test and
     *         validation file names
     */
    @Override
    protected List<FileBasedDatasetContainer> doPrepare(String[] data) {
        Struct2<String[], String[]> first = this.split(data, this.m_TrainPercentage);
        String[] train = first.value1;
        String[] remainder = first.value2;

        // The validation percentage refers to the full dataset, so it is
        // rescaled to the remainder. check() guarantees train < 1.0, so the
        // divisor is positive.
        double validateFraction = this.m_ValidatePercentage / (1.0 - this.m_TrainPercentage);
        Struct2<String[], String[]> second = this.split(remainder, validateFraction);
        String[] validate = second.value1;
        String[] test = second.value2;

        List<FileBasedDatasetContainer> result = new ArrayList<>();
        result.add(new FileBasedDatasetContainer(train, test, validate, null));
        return result;
    }
}

