/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.Randomizable;
import adams.flow.core.AbstractActor;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.transformer.AbstractGlobalWekaClassifierEvaluator;
import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.Null;
import weka.core.Instances;

/**
 * Flow transformer that cross-validates a classifier on each incoming
 * {@link Instances} dataset.
 * <p>
 * The classifier is obtained via {@code getClassifierInstance()} (inherited);
 * the number of folds and the randomization seed are configurable options.
 * Depending on the configured prediction output, the generated token carries
 * either the {@link Evaluation} object or the textual content of the output
 * buffer.
 */
public class WekaCrossValidationEvaluator
extends AbstractGlobalWekaClassifierEvaluator
implements Randomizable,
ProvenanceSupporter {
    /** For serialization. */
    private static final long serialVersionUID = -3019442578354930841L;

    /** The number of folds; -1 stands for leave-one-out cross-validation. */
    protected int m_Folds;

    /** The seed used to create the random number generator for the cross-validation. */
    protected long m_Seed;

    /**
     * Returns a short description of this actor.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Cross-validates a classifier on an incoming dataset. The classifier setup being used in the evaluation is a global 'Classifier' actor.";
    }

    /**
     * Registers the "seed" and "folds" options in addition to the ones
     * defined by the superclass.
     */
    public void defineOptions() {
        super.defineOptions();
        // seed: default 1
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        // folds: default 10; the trailing arguments are presumably the
        // lower (-1) and upper (none) bounds -- confirm against the option manager API
        this.m_OptionManager.add("folds", "folds", (Object)10, (Number)-1, null);
    }

    /**
     * Builds a one-line summary of the form
     * {@code <classifier>, <folds> folds}. For each of the two properties,
     * an attached variable (if any) is shown instead of the current value.
     *
     * @return the summary string
     */
    public String getQuickInfo() {
        StringBuilder info = new StringBuilder();

        String classifierVar = this.getOptionManager().getVariableForProperty("classifier");
        if (classifierVar != null) {
            info.append(classifierVar);
        } else {
            info.append(this.m_Classifier.toString());
        }
        info.append(", ");

        String foldsVar = this.getOptionManager().getVariableForProperty("folds");
        if (foldsVar != null) {
            info.append(foldsVar);
        } else {
            info.append(this.m_Folds);
        }
        info.append(" folds");

        return info.toString();
    }

    /**
     * Returns the help text for the "classifier" property.
     *
     * @return the help text
     */
    public String classifierTipText() {
        return "The global classifier actor to cross-validate on the input data.";
    }

    /**
     * Sets the number of folds. Only -1 (LOOCV) and values of at least 2 are
     * accepted; any other value is reported on the actor's error stream and
     * otherwise ignored.
     *
     * @param value the number of folds, or -1 for LOOCV
     */
    public void setFolds(int value) {
        if (value != -1 && value < 2) {
            this.getSystemErr().println("Number of folds must be >=2 or -1 for LOOCV, provided: " + value);
            return;
        }
        this.m_Folds = value;
        this.reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds, -1 for LOOCV
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Returns the help text for the "folds" property.
     *
     * @return the help text
     */
    public String foldsTipText() {
        return "The number of folds to use in the cross-validation; use -1 for leave-one-out cross-validation (LOOCV).";
    }

    /**
     * Sets the seed value and resets the actor.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed value.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the help text for the "seed" property.
     *
     * @return the help text
     */
    public String seedTipText() {
        return "The seed value for the cross-validation (used for randomization).";
    }

    /**
     * Returns the classes this actor accepts as input.
     *
     * @return an array containing only {@link Instances}
     */
    public Class[] accepts() {
        return new Class[]{Instances.class};
    }

    /**
     * Cross-validates the classifier on the dataset carried by the current
     * input token and stores the outcome in the output token.
     * <p>
     * Any exception is caught: the output token is cleared, the stack trace
     * is written to the actor's error stream and the exception's string
     * representation is returned.
     *
     * @return null if successful, otherwise the error description
     */
    protected String doExecute() {
        String error = null;

        try {
            Classifier classifier = this.getClassifierInstance();
            if (classifier == null) {
                throw new IllegalStateException("Classifier '" + this.getClassifier() + "' not found!");
            }

            Instances dataset = (Instances)this.m_InputToken.getPayload();
            this.m_Output.setHeader(dataset);
            Evaluation evaluation = new Evaluation(dataset);

            // -1 means leave-one-out: one fold per instance
            int numFolds = (this.m_Folds == -1) ? dataset.numInstances() : this.m_Folds;

            evaluation.setDiscardPredictions(this.m_DiscardPredictions);
            evaluation.crossValidateModel(classifier, dataset, numFolds, new Random(this.m_Seed), new Object[]{this.m_Output});

            // with the Null prediction output the Evaluation object itself is
            // forwarded, otherwise the text collected in the output buffer
            if (this.m_Output instanceof Null) {
                this.m_OutputToken = new Token((Object)evaluation);
            } else {
                this.m_OutputToken = new Token((Object)this.m_Output.getBuffer().toString());
            }
        }
        catch (Exception e) {
            this.m_OutputToken = null;
            error = e.toString();
            this.getSystemErr().printStackTrace((Throwable)e);
        }

        if (this.m_OutputToken != null) {
            this.updateProvenance((ProvenanceContainer)this.m_OutputToken);
        }

        return error;
    }

    /**
     * Updates the provenance information of the given container, but only if
     * provenance is enabled: the input token's provenance is copied over and
     * an entry for this actor (type {@code EVALUATOR}) is appended, built from
     * the payload classes of the current input and output tokens.
     *
     * @param cont the container to update
     */
    public void updateProvenance(ProvenanceContainer cont) {
        if (!Provenance.getSingleton().isEnabled()) {
            return;
        }
        cont.setProvenance(this.m_InputToken.getProvenance());
        cont.addProvenance(new ProvenanceInformation(ActorType.EVALUATOR, this.m_InputToken.getPayload().getClass(), (AbstractActor)this, this.m_OutputToken.getPayload().getClass()));
    }
}

