/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.Performance;
import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.ThreadLimiter;
import adams.core.option.OptionHandler;
import adams.core.option.OptionUtils;
import adams.data.weka.InstancesViewSupporter;
import adams.flow.container.WekaEvaluationContainer;
import adams.flow.core.Actor;
import adams.flow.core.ActorUtils;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.standalone.JobRunnerSetup;
import adams.flow.transformer.AbstractCallableWekaClassifierEvaluator;
import adams.multiprocess.WekaCrossValidationExecution;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.output.prediction.Null;
import weka.core.Instances;

/**
 * Cross-validates a classifier on an incoming {@link Instances} dataset. The
 * classifier setup being used in the evaluation is obtained from a callable
 * 'Classifier' actor.
 * <p>
 * When the cross-validation is not stopped, the output token carries a
 * {@link WekaEvaluationContainer}. There is one exception: if the evaluation
 * ran single-threaded with a non-{@link Null} prediction output and
 * {@code m_AlwaysUseContainer} is off, the textual prediction output is
 * forwarded as a String instead.
 * <p>
 * Note: this class was recovered with a decompiler. Explicit casts on method
 * arguments (e.g. {@code (Object)}) are deliberately retained wherever they
 * may select a specific overload.
 */
public class WekaCrossValidationEvaluator
    extends AbstractCallableWekaClassifierEvaluator
    implements Randomizable, ProvenanceSupporter, ThreadLimiter, InstancesViewSupporter {

    private static final long serialVersionUID = -3019442578354930841L;

    /** The number of folds; -1 requests leave-one-out cross-validation (LOOCV). */
    protected int m_Folds;

    /** The seed value used for randomizing the cross-validation. */
    protected long m_Seed;

    /** The number of threads handed to the cross-validation execution. */
    protected int m_NumThreads;

    /** Whether dataset views are used instead of actual copies, to conserve memory. */
    protected boolean m_UseViews;

    /**
     * The execution helper of the most recent evaluation. It is kept in a field
     * so that {@link #stopExecution()} can abort a running cross-validation.
     */
    protected WekaCrossValidationExecution m_CrossValidation;

    /**
     * The closest {@link JobRunnerSetup} actor, looked up in {@link #setUp()}.
     * Presumably null when the flow contains none (TODO confirm against ActorUtils).
     */
    protected transient JobRunnerSetup m_JobRunnerSetup;

    /**
     * Returns a string describing the actor.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Cross-validates a classifier on an incoming dataset. The classifier setup being used in the evaluation is a callable 'Classifier' actor.";
    }

    /**
     * Adds the options for seed, folds, number of threads and view usage.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("seed", "seed", (Object) 1L);
        // The bound must be written as "(Number) (-1)". A reference-type cast cannot
        // be applied directly to a unary-minus expression (JLS 15.16), so the
        // decompiler's "(Number)-1" was parsed as "Number - 1" and did not compile.
        // Presumably -1 is the lower bound (LOOCV) and null means "no upper bound".
        this.m_OptionManager.add("folds", "folds", (Object) 10, (Number) (-1), null);
        this.m_OptionManager.add("num-threads", "numThreads", (Object) 1);
        this.m_OptionManager.add("use-views", "useViews", (Object) false);
    }

    /**
     * Returns a short summary of the current setup: folds, seed and threads,
     * plus a note when views are used.
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        String result = super.getQuickInfo();
        result += QuickInfoHelper.toString((OptionHandler) this, "folds", (Object) this.m_Folds, ", folds: ");
        result += QuickInfoHelper.toString((OptionHandler) this, "seed", (Object) this.m_Seed, ", seed: ");
        result += QuickInfoHelper.toString((OptionHandler) this, "numThreads", (Object) Performance.getNumThreadsQuickInfo(this.m_NumThreads), ", ");
        // The boolean variant may return null (e.g. when views are off), hence the check.
        String value = QuickInfoHelper.toString((OptionHandler) this, "useViews", this.m_UseViews, ", using views");
        if (value != null) {
            result += value;
        }
        return result;
    }

    /**
     * Returns the tip text for the classifier property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    @Override
    public String classifierTipText() {
        return "The callable classifier actor to cross-validate on the input data.";
    }

    /**
     * Returns the tip text for the output property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    @Override
    public String outputTipText() {
        return "The class for generating prediction output; if 'Null' is used, then an Evaluation object is forwarded instead of a String; not used when using parallel execution.";
    }

    /**
     * Sets the number of folds. Only -1 (LOOCV) or values of at least 2 are
     * accepted. Any other value is rejected: an error is logged and the current
     * setting is kept.
     *
     * @param value the number of folds
     */
    public void setFolds(int value) {
        if (value == -1 || value >= 2) {
            this.m_Folds = value;
            this.reset();
        } else {
            this.getLogger().severe("Number of folds must be >=2 or -1 for LOOCV, provided: " + value);
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds; -1 for LOOCV
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Returns the tip text for the folds property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String foldsTipText() {
        return "The number of folds to use in the cross-validation; use -1 for leave-one-out cross-validation (LOOCV).";
    }

    /**
     * Sets the seed value and resets the actor.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed value.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String seedTipText() {
        return "The seed value for the cross-validation (used for randomization).";
    }

    /**
     * Sets the number of threads to use and resets the actor.
     *
     * @param value the number of threads
     */
    public void setNumThreads(int value) {
        this.m_NumThreads = value;
        this.reset();
    }

    /**
     * Returns the number of threads to use.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Returns the tip text for the numThreads property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numThreadsTipText() {
        return Performance.getNumThreadsHelp();
    }

    /**
     * Sets whether to use dataset views instead of copies, and resets the actor.
     *
     * @param value true to use views
     */
    @Override
    public void setUseViews(boolean value) {
        this.m_UseViews = value;
        this.reset();
    }

    /**
     * Returns whether dataset views are used instead of copies.
     *
     * @return true if views are used
     */
    @Override
    public boolean getUseViews() {
        return this.m_UseViews;
    }

    /**
     * Returns the tip text for the useViews property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String useViewsTipText() {
        return "If enabled, views of the dataset are being used instead of actual copies, to conserve memory.";
    }

    /**
     * Returns the classes this actor accepts as input.
     *
     * @return only {@link Instances}
     */
    public Class[] accepts() {
        return new Class[]{Instances.class};
    }

    /**
     * Initializes the actor. If the superclass setup succeeds, this also looks up
     * the closest {@link JobRunnerSetup} actor.
     *
     * @return null if everything is fine, otherwise the error message from the superclass
     */
    @Override
    public String setUp() {
        String result = super.setUp();
        if (result == null) {
            this.m_JobRunnerSetup = (JobRunnerSetup) ActorUtils.findClosestType((Actor) this, JobRunnerSetup.class);
        }
        return result;
    }

    /**
     * Cross-validates the callable classifier on the {@link Instances} payload of
     * the current input token.
     * <p>
     * If the execution was not stopped, an output token is generated (see the
     * class documentation). If that token carries a {@link WekaEvaluationContainer}
     * and original indices are available, the indices are stored under
     * "Original indices". Provenance is updated for every generated token.
     *
     * @return the message returned by the cross-validation execution (presumably
     *         null on success), or the error message produced by handleException
     *         if an exception occurred
     */
    protected String doExecute() {
        String result;
        int[] indices = null;
        try {
            Classifier cls = this.getClassifierInstance();
            if (cls == null) {
                throw new IllegalStateException("Classifier '" + this.getClassifier() + "' not found!");
            }
            if (this.isLoggingEnabled()) {
                this.getLogger().info(OptionUtils.getCommandLine((Object) cls));
            }
            Instances data = (Instances) this.m_InputToken.getPayload();
            // Stored in the field (not just a local) so that stopExecution() can reach it.
            this.m_CrossValidation = new WekaCrossValidationExecution();
            this.m_CrossValidation.setJobRunnerSetup(this.m_JobRunnerSetup);
            this.m_CrossValidation.setClassifier(cls);
            this.m_CrossValidation.setData(data);
            this.m_CrossValidation.setFolds(this.m_Folds);
            this.m_CrossValidation.setSeed(this.m_Seed);
            this.m_CrossValidation.setUseViews(this.m_UseViews);
            this.m_CrossValidation.setDiscardPredictions(this.m_DiscardPredictions);
            this.m_CrossValidation.setNumThreads(this.m_NumThreads);
            this.m_CrossValidation.setOutput(this.m_Output);
            result = this.m_CrossValidation.execute();
            if (!this.m_CrossValidation.isStopped()) {
                indices = this.m_CrossValidation.getOriginalIndices();
                if (this.m_CrossValidation.isSingleThreaded() && !(this.m_Output instanceof Null)) {
                    // Single-threaded with a real prediction output: forward the printed predictions.
                    if (this.m_CrossValidation.getOutputBuffer() != null) {
                        this.m_OutputBuffer.append(this.m_CrossValidation.getOutputBuffer().toString());
                    }
                    if (this.m_AlwaysUseContainer) {
                        this.m_OutputToken = new Token((Object) new WekaEvaluationContainer(this.m_CrossValidation.getEvaluation(), null, this.m_Output.getBuffer().toString()));
                    } else {
                        this.m_OutputToken = new Token((Object) this.m_Output.getBuffer().toString());
                    }
                } else {
                    // 'Null' output, or parallel execution (where the output setting is
                    // not used): forward the Evaluation wrapped in a container.
                    this.m_OutputToken = new Token((Object) new WekaEvaluationContainer(this.m_CrossValidation.getEvaluation()));
                }
            }
        } catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to cross-validate classifier: ", e);
        }
        if (this.m_OutputToken != null) {
            // A String payload cannot hold the indices, hence the instanceof guard.
            if (indices != null && this.m_OutputToken.getPayload() instanceof WekaEvaluationContainer) {
                ((WekaEvaluationContainer) this.m_OutputToken.getPayload()).setValue("Original indices", indices);
            }
            this.updateProvenance((ProvenanceContainer) this.m_OutputToken);
        }
        return result;
    }

    /**
     * Updates the provenance information of the given container, but only if
     * provenance is enabled. The input token's provenance (if any) is cloned onto
     * the container, then an EVALUATOR entry is appended. That entry records the
     * input and output payload classes.
     * <p>
     * Reads both {@code m_InputToken} and {@code m_OutputToken}, so it assumes
     * both are set (as they are when called from doExecute).
     *
     * @param cont the container to update
     */
    public void updateProvenance(ProvenanceContainer cont) {
        if (Provenance.getSingleton().isEnabled()) {
            if (this.m_InputToken.hasProvenance()) {
                cont.setProvenance(this.m_InputToken.getProvenance().getClone());
            }
            cont.addProvenance(new ProvenanceInformation(ActorType.EVALUATOR, this.m_InputToken.getPayload().getClass(), (Actor) this, this.m_OutputToken.getPayload().getClass()));
        }
    }

    /**
     * Stops the execution: first any running cross-validation, then the actor itself.
     */
    public void stopExecution() {
        // Read the field once so the null check and the call use the same instance.
        WekaCrossValidationExecution current = this.m_CrossValidation;
        if (current != null) {
            current.stopExecution();
        }
        super.stopExecution();
    }
}

