/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.Randomizable;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.core.AbstractActor;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.transformer.AbstractTransformer;
import java.util.Hashtable;
import java.util.Random;
import weka.core.Instances;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class WekaCrossValidationSplit
extends AbstractTransformer
implements Randomizable,
ProvenanceSupporter {
    private static final long serialVersionUID = 4026105903223741240L;
    /** Key under which the current fold index is stored in the backup state. */
    public static final String BACKUP_CURRENTFOLD = "current fold";
    /** Key under which the actual number of folds is stored in the backup state. */
    public static final String BACKUP_ACTUALFOLDS = "actual folds";
    /** Seed for the random number generator used to shuffle the data. */
    protected long m_Seed;
    /** Requested number of folds; -1 requests leave-one-out cross-validation. */
    protected int m_Folds;
    /** Template for the relation names of the generated train/test sets. */
    protected String m_RelationName;
    /** 1-based index of the next fold to emit; 0 when no dataset is being split. */
    protected int m_CurrentFold;
    /** Number of folds actually used for the current dataset (resolved from m_Folds). */
    protected int m_ActualFolds;
    /** Randomized (and, for nominal classes, stratified) copy of the input data. */
    protected Instances m_Data;
    /** Random number generator shared between randomizing and generating training sets. */
    protected Random m_Random;

    /**
     * Returns a short description of this actor.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates train/test pairs like during a cross-validation run. It is possible to generate pairs for leave-one-out cross-validation (LOOCV) as well.\nIt is essential that a class attribute is set. In case of a nominal class attribute, the data gets stratified automatically.\nEach of the pairs gets forwarded as a container. The training set can be accessed in the container with 'Train' and the test set with 'Test'.";
    }

    /**
     * Registers the options "seed" (default 1), "folds" (default 10) and
     * "relation" (default "@") with the option manager.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        this.m_OptionManager.add("folds", "folds", (Object)10);
        this.m_OptionManager.add("relation", "relationName", (Object)"@");
    }

    /**
     * Returns the classes this actor accepts as input.
     *
     * @return {@link Instances}
     */
    public Class[] accepts() {
        return new Class[]{Instances.class};
    }

    /**
     * Returns the classes this actor generates as output.
     *
     * @return {@link WekaTrainTestSetContainer}
     */
    public Class[] generates() {
        return new Class[]{WekaTrainTestSetContainer.class};
    }

    /**
     * Sets the seed for randomizing the data and resets the actor.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed for randomizing the data.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value for the randomization.";
    }

    /**
     * Sets the number of folds. Only values >= 2, or -1 for LOOCV, are
     * accepted; other values are reported on stderr and ignored.
     *
     * @param value the number of folds
     */
    public void setFolds(int value) {
        if (value >= 2 || value == -1) {
            this.m_Folds = value;
            this.reset();
        } else {
            this.getSystemErr().println("Folds must be >=2 or -1 for LOOCV (provided: " + value + ")!");
        }
    }

    /**
     * Returns the number of folds (-1 for LOOCV).
     *
     * @return the number of folds
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Returns the tip text for the folds property.
     *
     * @return the tip text
     */
    public String foldsTipText() {
        return "The folds to use; using '-1' will generate folds for leave-one-out cross-validation (LOOCV).";
    }

    /**
     * Sets the template for the relation names of the generated datasets.
     *
     * @param value the template (see {@link #relationNameTipText()})
     */
    public void setRelationName(String value) {
        this.m_RelationName = value;
    }

    /**
     * Returns the template for the relation names of the generated datasets.
     *
     * @return the template
     */
    public String getRelationName() {
        return this.m_RelationName;
    }

    /**
     * Returns the tip text for the relation name property.
     *
     * @return the tip text
     */
    public String relationNameTipText() {
        return "The placeholders for creating the relation name: @ = original relation name, $T = type (train/test), $N = current fold number.";
    }

    /**
     * Removes this actor's entries from the backup, in addition to the
     * superclass entries.
     */
    protected void pruneBackup() {
        super.pruneBackup();
        this.pruneBackup(BACKUP_CURRENTFOLD);
        this.pruneBackup(BACKUP_ACTUALFOLDS);
    }

    /**
     * Backs up the current fold and actual number of folds on top of the
     * superclass state.
     *
     * @return the backed-up state
     */
    protected Hashtable<String, Object> backupState() {
        Hashtable<String, Object> result = super.backupState();
        result.put(BACKUP_CURRENTFOLD, this.m_CurrentFold);
        result.put(BACKUP_ACTUALFOLDS, this.m_ActualFolds);
        return result;
    }

    /**
     * Restores the current fold and actual number of folds, if present,
     * removes those entries and hands the rest to the superclass.
     *
     * @param state the state to restore from
     */
    protected void restoreState(Hashtable<String, Object> state) {
        if (state.containsKey(BACKUP_CURRENTFOLD)) {
            this.m_CurrentFold = (Integer)state.get(BACKUP_CURRENTFOLD);
            state.remove(BACKUP_CURRENTFOLD);
        }
        if (state.containsKey(BACKUP_ACTUALFOLDS)) {
            this.m_ActualFolds = (Integer)state.get(BACKUP_ACTUALFOLDS);
            state.remove(BACKUP_ACTUALFOLDS);
        }
        super.restoreState(state);
    }

    /**
     * Resets the fold counters.
     */
    protected void reset() {
        super.reset();
        this.m_CurrentFold = 0;
        this.m_ActualFolds = 0;
    }

    /**
     * Prepares the incoming dataset for splitting.
     * <p>
     * The data is copied, randomized with a generator seeded by
     * {@link #getSeed()} and, for a nominal class attribute, stratified (as
     * announced in {@link #globalInfo()}). On success the fold counter is set
     * to 1; the pairs themselves are produced by {@link #output()}.
     *
     * @return null if everything is fine, otherwise an error message
     */
    protected String doExecute() {
        String result = null;
        Instances data = (Instances)this.m_InputToken.getPayload();
        int folds = this.m_Folds;
        // -1 (or any value below 2) means LOOCV: one fold per instance
        if (folds < 2) {
            folds = data.numInstances();
        }
        if (data.classIndex() == -1) {
            result = "No class attribute set!";
        } else if (folds < 2) {
            // Weka's trainCV/testCV would throw an IllegalArgumentException later on
            result = "At least 2 instances are required for cross-validation (provided: " + data.numInstances() + ")!";
        } else if (folds > data.numInstances()) {
            result = "Cannot have more folds than instances (folds: " + folds + ", instances: " + data.numInstances() + ")!";
        }
        if (result != null) {
            // clear any iteration state so hasPendingOutput() cannot report stale folds
            this.m_Data = null;
            this.m_Random = null;
            this.m_CurrentFold = 0;
            this.m_ActualFolds = 0;
            return result;
        }
        // work on a copy so the incoming dataset is not reordered
        this.m_Data = new Instances(data);
        this.m_CurrentFold = 1;
        this.m_ActualFolds = folds;
        this.m_Random = new Random(this.m_Seed);
        this.m_Data.randomize(this.m_Random);
        // same order of operations as Weka's cross-validation: randomize, then stratify
        if (this.m_Data.classAttribute().isNominal()) {
            this.m_Data.stratify(this.m_ActualFolds);
        }
        return null;
    }

    /**
     * Checks whether there are still folds left to output.
     *
     * @return true if a dataset is loaded and not all folds have been output
     */
    public boolean hasPendingOutput() {
        return this.m_Data != null && this.m_CurrentFold > 0 && this.m_CurrentFold <= this.m_ActualFolds;
    }

    /**
     * Expands the relation name template: "@" becomes the original relation
     * name, "$T" becomes "train" or "test", "$N" becomes the current fold
     * number; all other characters are copied verbatim.
     *
     * @param train whether the name is for the training set
     * @return the expanded relation name
     */
    protected String createRelationName(boolean train) {
        StringBuilder result = new StringBuilder();
        String name = this.m_RelationName;
        int pos = 0;
        while (pos < name.length()) {
            if (name.startsWith("@", pos)) {
                result.append(this.m_Data.relationName());
                pos += 1;
            } else if (name.startsWith("$T", pos)) {
                result.append(train ? "train" : "test");
                pos += 2;
            } else if (name.startsWith("$N", pos)) {
                result.append(this.m_CurrentFold);
                pos += 2;
            } else {
                result.append(name.charAt(pos));
                pos += 1;
            }
        }
        return result.toString();
    }

    /**
     * Generates the train/test pair for the current fold and advances to the
     * next fold.
     *
     * @return the token wrapping a {@link WekaTrainTestSetContainer}
     */
    public Token output() {
        // Weka's fold indices are 0-based, m_CurrentFold is 1-based
        Instances train = this.m_Data.trainCV(this.m_ActualFolds, this.m_CurrentFold - 1, this.m_Random);
        Instances test = this.m_Data.testCV(this.m_ActualFolds, this.m_CurrentFold - 1);
        train.setRelationName(this.createRelationName(true));
        test.setRelationName(this.createRelationName(false));
        Token result = new Token((Object)new WekaTrainTestSetContainer(train, test, this.m_Seed, this.m_CurrentFold, this.m_ActualFolds));
        ++this.m_CurrentFold;
        this.updateProvenance((ProvenanceContainer)result);
        return result;
    }

    /**
     * If provenance is enabled, copies the input token's provenance and adds
     * an entry recording this actor as a data generator.
     *
     * @param cont the container to update
     */
    public void updateProvenance(ProvenanceContainer cont) {
        if (Provenance.getSingleton().isEnabled()) {
            cont.setProvenance(this.m_InputToken.getProvenance());
            cont.addProvenance(new ProvenanceInformation(ActorType.DATAGENERATOR, this.m_InputToken.getPayload().getClass(), (AbstractActor)this, ((Token)cont).getPayload().getClass()));
        }
    }

    /**
     * Releases the data and random number generator and clears the fold
     * counters before delegating to the superclass.
     */
    public void wrapUp() {
        this.m_Data = null;
        this.m_Random = null;
        this.m_CurrentFold = 0;
        this.m_ActualFolds = 0;
        super.wrapUp();
    }
}

