package adams.flow.transformer;

import adams.core.ObjectCopyHelper;
import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.Stoppable;
import adams.data.weka.InstancesViewCreator;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.core.Token;
import java.util.Hashtable;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.DefaultCrossValidationFoldGenerator;
import weka.core.Instances;

/* loaded from: input_file:adams/flow/transformer/WekaCrossValidationSplit.class */
/**
 * Transformer that turns an incoming {@link Instances} dataset into a sequence of
 * train/test pairs, as generated during a cross-validation run (leave-one-out
 * cross-validation is used when the number of folds is -1).
 * <p>
 * For every input token, a copy of the configured {@link CrossValidationFoldGenerator}
 * template is made and configured from this actor's options (folds, seed, relation name,
 * view creation; stratification is always switched on). Each subsequent call to
 * {@link #output()} then forwards the next pair produced by that generator.
 */
public class WekaCrossValidationSplit extends AbstractTransformer implements Randomizable, InstancesViewCreator {
    private static final long serialVersionUID = 4026105903223741240L;

    /** The key under which the active fold generator is stored in the backup state. */
    public static final String BACKUP_GENERATOR = "generator";

    /** The seed handed to the generator; overrides the generator's own setting. */
    protected long m_Seed;

    /** The number of folds (at least 2), or -1 for leave-one-out cross-validation. */
    protected int m_Folds;

    /** The relation name template (placeholders: @, $T, $N). */
    protected String m_RelationName;

    /** Whether views of the dataset are created instead of actual copies. */
    protected boolean m_CreateView;

    /** The user-configured generator template; only ever copied, never iterated directly. */
    protected CrossValidationFoldGenerator m_Generator;

    /** The generator producing pairs for the current input; {@code null} when idle. */
    protected transient CrossValidationFoldGenerator m_ActualGenerator;

    /**
     * Returns a string describing the actor.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "Generates train/test pairs like during a cross-validation run. It is possible to generate pairs for leave-one-out cross-validation (LOOCV) as well.\nIt is essential that a class attribute is set. In case of a nominal class attribute, the data gets stratified automatically.\nEach of the pairs gets forwarded as a container. The training set can be accessed in the container with 'Train' and the test set with 'Test'.";
    }

    /**
     * Registers the actor's options together with their default values.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        m_OptionManager.add("seed", "seed", 1L);
        m_OptionManager.add("folds", "folds", 10);
        m_OptionManager.add("relation", "relationName", "@");
        m_OptionManager.add("create-view", "createView", false);
        m_OptionManager.add("generator", "generator", new DefaultCrossValidationFoldGenerator());
    }

    /**
     * Returns a short summary of the current setup for display in the flow editor.
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        String result;
        String view;

        result = QuickInfoHelper.toString(this, "folds", Integer.valueOf(m_Folds), "folds: ");
        result += QuickInfoHelper.toString(this, "seed", Long.valueOf(m_Seed), ", seed: ");
        result += QuickInfoHelper.toString(this, "relationName", m_RelationName, ", relation: ");
        // the boolean variant may return null, in which case nothing is appended
        view = QuickInfoHelper.toString(this, "createView", m_CreateView, ", view only");
        if (view != null) {
            result += view;
        }
        return result;
    }

    /**
     * Returns the classes of objects this actor accepts as input.
     *
     * @return {@link Instances}
     */
    @Override
    public Class[] accepts() {
        return new Class[]{Instances.class};
    }

    /**
     * Returns the classes of objects this actor generates as output.
     *
     * @return {@link WekaTrainTestSetContainer}
     */
    @Override
    public Class[] generates() {
        return new Class[]{WekaTrainTestSetContainer.class};
    }

    /**
     * Sets the seed value used for randomizing the data.
     *
     * @param j the seed
     */
    @Override
    public void setSeed(long j) {
        m_Seed = j;
        reset();
    }

    /**
     * Returns the seed value used for randomizing the data.
     *
     * @return the seed
     */
    @Override
    public long getSeed() {
        return m_Seed;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    @Override
    public String seedTipText() {
        return "The seed value for the randomization; overrides the value defined by the fold generator scheme.";
    }

    /**
     * Sets the number of folds. Values below 2 other than -1 are rejected:
     * an error is logged and the current value is kept.
     *
     * @param i the number of folds, or -1 for LOOCV
     */
    public void setFolds(int i) {
        if (i < 2 && i != -1) {
            getLogger().severe("Folds must be >=2 or -1 for LOOCV (provided: " + i + ")!");
            return;
        }
        m_Folds = i;
        reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds, -1 for LOOCV
     */
    public int getFolds() {
        return m_Folds;
    }

    /**
     * Returns the tip text for the folds property.
     *
     * @return the tip text
     */
    public String foldsTipText() {
        return "The number of folds to use in the cross-validation; use -1 for leave-one-out cross-validation (LOOCV); overrides the value defined by the fold generator scheme.";
    }

    /**
     * Sets the template for generating the relation names of the train/test sets.
     *
     * @param str the template
     */
    public void setRelationName(String str) {
        m_RelationName = str;
        // consistent with all other option setters
        reset();
    }

    /**
     * Returns the template for generating the relation names of the train/test sets.
     *
     * @return the template
     */
    public String getRelationName() {
        return m_RelationName;
    }

    /**
     * Returns the tip text for the relation name property.
     *
     * @return the tip text
     */
    public String relationNameTipText() {
        return "The placeholders for creating the relation name: @ = original relation name, $T = type (train/test), $N = current fold number; overrides the value defined by the fold generator scheme.";
    }

    /**
     * Sets whether to create views of the dataset instead of copies.
     *
     * @param z true if views are to be created
     */
    @Override // adams.data.weka.InstancesViewCreator
    public void setCreateView(boolean z) {
        m_CreateView = z;
        reset();
    }

    /**
     * Returns whether views of the dataset are created instead of copies.
     *
     * @return true if views are created
     */
    @Override // adams.data.weka.InstancesViewCreator
    public boolean getCreateView() {
        return m_CreateView;
    }

    /**
     * Returns the tip text for the create-view property.
     *
     * @return the tip text
     */
    @Override // adams.data.weka.InstancesViewCreator
    public String createViewTipText() {
        return "If enabled, views of the dataset are created instead of actual copies; overrides the value defined by the fold generator scheme.";
    }

    /**
     * Sets the generator template used for producing the folds.
     *
     * @param crossValidationFoldGenerator the generator template
     */
    public void setGenerator(CrossValidationFoldGenerator crossValidationFoldGenerator) {
        m_Generator = crossValidationFoldGenerator;
        reset();
    }

    /**
     * Returns the generator template used for producing the folds.
     *
     * @return the generator template
     */
    public CrossValidationFoldGenerator getGenerator() {
        return m_Generator;
    }

    /**
     * Returns the tip text for the generator property.
     *
     * @return the tip text
     */
    public String generatorTipText() {
        return "The scheme to use for generating the folds; the actor options take precedence over the scheme's ones.";
    }

    /**
     * Removes entries from the backup, including the stored generator.
     */
    @Override
    protected void pruneBackup() {
        super.pruneBackup();
        pruneBackup(BACKUP_GENERATOR);
    }

    /**
     * Backs up the current state, including the active generator (if any).
     *
     * @return the backup
     */
    @Override
    protected Hashtable<String, Object> backupState() {
        Hashtable<String, Object> result = super.backupState();
        if (m_ActualGenerator != null) {
            result.put(BACKUP_GENERATOR, m_ActualGenerator);
        }
        return result;
    }

    /**
     * Restores the state from the backup. The generator entry is consumed (removed)
     * before the remaining entries are handed to the superclass.
     *
     * @param hashtable the backup to restore from
     */
    @Override
    protected void restoreState(Hashtable<String, Object> hashtable) {
        if (hashtable.containsKey(BACKUP_GENERATOR)) {
            m_ActualGenerator = (CrossValidationFoldGenerator) hashtable.get(BACKUP_GENERATOR);
            hashtable.remove(BACKUP_GENERATOR);
        }
        super.restoreState(hashtable);
    }

    /**
     * Resets the actor and discards any active generator.
     */
    @Override
    protected void reset() {
        super.reset();
        m_ActualGenerator = null;
    }

    /**
     * Copies the generator template and configures the copy for the current input dataset.
     * The configured generator only becomes active once every setting has been applied, so a
     * failure never leaves a half-configured generator behind for {@link #output()}.
     *
     * @return null if successful, otherwise the error message
     */
    @Override
    protected String doExecute() {
        String result = null;
        CrossValidationFoldGenerator generator;

        // drop any generator left over from a previous input
        m_ActualGenerator = null;
        try {
            // work on a copy so the user-configured template is never modified
            generator = (CrossValidationFoldGenerator) ObjectCopyHelper.copyObject(m_Generator);
            generator.setData((Instances) m_InputToken.getPayload());
            generator.setNumFolds(m_Folds);
            generator.setSeed(m_Seed);
            generator.setStratify(true);
            generator.setRandomize(true);
            generator.setRelationName(m_RelationName);
            generator.setUseViews(m_CreateView);
            m_ActualGenerator = generator;
        } catch (Exception e) {
            result = handleException("Failed to initialize fold generator!", e);
        }
        return result;
    }

    /**
     * Checks whether there are more train/test pairs to forward.
     *
     * @return true if a generator is active and has further pairs
     */
    @Override
    public boolean hasPendingOutput() {
        return m_ActualGenerator != null && m_ActualGenerator.hasNext();
    }

    /**
     * Returns the next train/test pair wrapped in a token. Only call this when
     * {@link #hasPendingOutput()} returns true.
     *
     * @return the token holding the next pair
     */
    @Override
    public Token output() {
        return new Token(m_ActualGenerator.next());
    }

    /**
     * Cleans up after execution, discarding the active generator.
     */
    @Override
    public void wrapUp() {
        m_ActualGenerator = null;
        super.wrapUp();
    }

    /**
     * Stops the execution, also stopping the active generator if it supports that.
     */
    @Override
    public void stopExecution() {
        // instanceof is false for null, so no separate null check is required
        if (m_ActualGenerator instanceof Stoppable) {
            ((Stoppable) m_ActualGenerator).stopExecution();
        }
        super.stopExecution();
    }
}
