/*
 * Decompiled with CFR 0.152.
 */
package adams.gui.tools.wekainvestigator.tab.classifytab.evaluation;

import adams.core.MessageCollection;
import adams.core.Properties;
import adams.core.StoppableWithFeedback;
import adams.core.Utils;
import adams.core.option.OptionUtils;
import adams.data.spreadsheet.MetaData;
import adams.gui.chooser.SelectOptionPanel;
import adams.gui.core.BaseCheckBox;
import adams.gui.core.BaseComboBox;
import adams.gui.core.NumberTextField;
import adams.gui.core.ParameterPanel;
import adams.gui.goe.GenericObjectEditorPanel;
import adams.gui.tools.wekainvestigator.data.DataContainer;
import adams.gui.tools.wekainvestigator.evaluation.DatasetHelper;
import adams.gui.tools.wekainvestigator.tab.AbstractInvestigatorTab;
import adams.gui.tools.wekainvestigator.tab.ClassifyTab;
import adams.gui.tools.wekainvestigator.tab.classifytab.ResultItem;
import adams.gui.tools.wekainvestigator.tab.classifytab.evaluation.AbstractClassifierEvaluation;
import adams.gui.tools.wekainvestigator.tab.classifytab.evaluation.finalmodel.AbstractFinalModelGenerator;
import adams.gui.tools.wekainvestigator.tab.classifytab.evaluation.finalmodel.Simple;
import adams.multiprocess.JobRunner;
import adams.multiprocess.LocalJobRunner;
import adams.multiprocess.WekaCrossValidationExecution;
import java.awt.Component;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JSpinner;
import javax.swing.SpinnerNumberModel;
import javax.swing.event.DocumentEvent;
import javax.swing.event.DocumentListener;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.DefaultCrossValidationFoldGenerator;
import weka.core.Capabilities;
import weka.core.Instances;

/**
 * Classifier evaluation that cross-validates a classifier on one of the datasets
 * loaded in the owning {@link ClassifyTab} and afterwards hands the data to the
 * configured {@link AbstractFinalModelGenerator} to build a final model.
 * <p>
 * Default values for the GUI parameters are read from the evaluation properties
 * (keys prefixed with {@code Classify.}). A running evaluation can be interrupted
 * via {@link #stopExecution()}.
 */
public class CrossValidation
extends AbstractClassifierEvaluation
implements StoppableWithFeedback {
    private static final long serialVersionUID = 1175400993991698944L;
    /** Serialization key: index of the selected dataset (GUI state). */
    public static final String KEY_DATASET = "dataset";
    /** Serialization key: number of folds. */
    public static final String KEY_FOLDS = "folds";
    /** Serialization key: whether per-fold evaluations/models are kept. */
    public static final String KEY_PERFOLDOUTPUT = "perfoldoutput";
    /** Serialization key: seed value for randomization. */
    public static final String KEY_SEED = "seed";
    /** Serialization key: command-line of the job runner. */
    public static final String KEY_JOBRUNNER = "jobrunner";
    /** Serialization key: additional attributes made available in plots. */
    public static final String KEY_ADDITIONAL = "additional";
    /** Serialization key: whether views are used instead of dataset copies. */
    public static final String KEY_USEVIEWS = "useviews";
    /** Serialization key: command-line of the fold generator. */
    public static final String KEY_GENERATOR = "generator";
    /** Serialization key: whether predictions are discarded. */
    public static final String KEY_DISCARDPREDICTIONS = "discardpredictions";
    /** Serialization key: command-line of the final model generator. */
    public static final String KEY_FINALMODEL = "finalmodel";
    /** Panel holding all parameter widgets. */
    protected ParameterPanel m_PanelParameters;
    /** Combobox for selecting the dataset to evaluate on. */
    protected BaseComboBox<String> m_ComboBoxDatasets;
    /** Model backing {@link #m_ComboBoxDatasets}. */
    protected DefaultComboBoxModel<String> m_ModelDatasets;
    /** Number of folds (values below 2 mean leave-one-out, per the tooltip). */
    protected JSpinner m_SpinnerFolds;
    /** Whether to keep separate evaluations per fold. */
    protected BaseCheckBox m_CheckBoxPerFoldOutput;
    /** Seed value for randomizing the data. */
    protected NumberTextField m_TextSeed;
    /** Job runner used for executing the folds. */
    protected GenericObjectEditorPanel m_GOEJobRunner;
    /** Additional attributes to make available in plots. */
    protected SelectOptionPanel m_SelectAdditionalAttributes;
    /** Whether to use views instead of dataset copies. */
    protected BaseCheckBox m_CheckBoxUseViews;
    /** Generator for the cross-validation folds. */
    protected GenericObjectEditorPanel m_GOEGenerator;
    /** Whether to discard predictions. */
    protected BaseCheckBox m_CheckBoxDiscardPredictions;
    /** Scheme for producing the final model after cross-validation. */
    protected GenericObjectEditorPanel m_GOEFinalModel;
    /**
     * The cross-validation currently executing, {@code null} when idle.
     * Volatile since stop requests (see {@link #stopExecution()}) are expected to
     * arrive from a thread other than the one running {@link #doEvaluate} —
     * NOTE(review): assumption about the caller's threading model.
     */
    protected volatile WekaCrossValidationExecution m_CrossValidation;

    /**
     * Returns a short description of this evaluation scheme.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Performs cross-validation.";
    }

    /**
     * Initializes the members; no cross-validation is running initially.
     */
    @Override
    protected void initialize() {
        super.initialize();
        this.m_CrossValidation = null;
    }

    /**
     * Builds the parameter widgets, pre-populating them from the properties and
     * falling back to built-in defaults when a configured command-line cannot be parsed.
     */
    @Override
    protected void initGUI() {
        super.initGUI();
        Properties props = CrossValidation.getProperties();
        this.m_PanelParameters = new ParameterPanel();
        this.m_PanelOptions.add((Component)this.m_PanelParameters, "Center");

        // dataset
        this.m_ModelDatasets = new DefaultComboBoxModel<>();
        this.m_ComboBoxDatasets = new BaseComboBox<>(this.m_ModelDatasets);
        this.m_ComboBoxDatasets.addActionListener(e -> this.update());
        this.m_PanelParameters.addParameter("Dataset", this.m_ComboBoxDatasets);

        // folds
        this.m_SpinnerFolds = new JSpinner();
        SpinnerNumberModel foldsModel = (SpinnerNumberModel)this.m_SpinnerFolds.getModel();
        foldsModel.setMinimum(Integer.valueOf(-1));
        foldsModel.setStepSize(1);
        this.m_SpinnerFolds.setValue(props.getInteger("Classify.NumFolds", Integer.valueOf(10)));
        this.m_SpinnerFolds.setToolTipText("The number of folds to use (< 2 for LOO-CV)");
        this.m_SpinnerFolds.addChangeListener(e -> this.update());
        this.m_PanelParameters.addParameter("Folds", (Component)this.m_SpinnerFolds);

        // per-fold output
        this.m_CheckBoxPerFoldOutput = new BaseCheckBox();
        this.m_CheckBoxPerFoldOutput.setSelected(props.getBoolean("Classify.PerFoldOutput", Boolean.FALSE).booleanValue());
        this.m_CheckBoxPerFoldOutput.setToolTipText("Keep separate evaluations per fold to inspect per fold performance");
        this.m_CheckBoxPerFoldOutput.addActionListener(e -> this.update());
        this.m_PanelParameters.addParameter("Per fold output", (Component)this.m_CheckBoxPerFoldOutput);

        // seed
        this.m_TextSeed = new NumberTextField(NumberTextField.Type.INTEGER, "" + props.getInteger("Classify.Seed", Integer.valueOf(1)));
        this.m_TextSeed.setToolTipText("The seed value for randomizing the data");
        this.m_TextSeed.getDocument().addDocumentListener(new DocumentListener(){

            @Override
            public void insertUpdate(DocumentEvent e) {
                CrossValidation.this.update();
            }

            @Override
            public void removeUpdate(DocumentEvent e) {
                CrossValidation.this.update();
            }

            @Override
            public void changedUpdate(DocumentEvent e) {
                CrossValidation.this.update();
            }
        });
        this.m_PanelParameters.addParameter("Seed", (Component)this.m_TextSeed);

        // job runner: declared as the interface type, since the configured
        // command-line may produce any JobRunner implementation
        JobRunner jobrunner;
        try {
            jobrunner = (JobRunner)OptionUtils.forCommandLine(JobRunner.class, props.getProperty("Classify.JobRunner", new LocalJobRunner().toCommandLine()));
        }
        catch (Exception ex) {
            // unparseable configuration: fall back to local execution
            jobrunner = new LocalJobRunner();
        }
        this.m_GOEJobRunner = new GenericObjectEditorPanel(JobRunner.class, (Object)jobrunner, true);
        this.m_GOEJobRunner.setToolTipText("Whether to execute the jobs locally or remotely");
        this.m_GOEJobRunner.addChangeListener(e -> this.update());
        this.m_PanelParameters.addParameter("Job runner", (Component)this.m_GOEJobRunner);

        // views
        this.m_CheckBoxUseViews = new BaseCheckBox();
        this.m_CheckBoxUseViews.setSelected(props.getBoolean("Classify.UseViews", Boolean.FALSE).booleanValue());
        this.m_CheckBoxUseViews.setToolTipText("Save memory by using views instead of creating copies of datasets?");
        this.m_CheckBoxUseViews.addActionListener(e -> this.update());
        this.m_PanelParameters.addParameter("Use views", (Component)this.m_CheckBoxUseViews);

        // fold generator
        CrossValidationFoldGenerator generator;
        try {
            generator = (CrossValidationFoldGenerator)OptionUtils.forCommandLine(CrossValidationFoldGenerator.class, props.getProperty("Classify.CrossValidationFoldGenerator", new DefaultCrossValidationFoldGenerator().toCommandLine()));
        }
        catch (Exception ex) {
            // unparseable configuration: fall back to the default generator
            generator = new DefaultCrossValidationFoldGenerator();
        }
        this.m_GOEGenerator = new GenericObjectEditorPanel(CrossValidationFoldGenerator.class, (Object)generator, true);
        this.m_GOEGenerator.addChangeListener(e -> this.update());
        this.m_PanelParameters.addParameter("Generator", (Component)this.m_GOEGenerator);

        // discard predictions
        this.m_CheckBoxDiscardPredictions = new BaseCheckBox();
        this.m_CheckBoxDiscardPredictions.setSelected(props.getBoolean("Classify.DiscardPredictions", Boolean.FALSE).booleanValue());
        this.m_CheckBoxDiscardPredictions.setToolTipText("Save memory by discarding predictions?");
        this.m_CheckBoxDiscardPredictions.addActionListener(e -> this.update());
        this.m_PanelParameters.addParameter("Discard predictions", (Component)this.m_CheckBoxDiscardPredictions);

        // additional attributes
        this.m_SelectAdditionalAttributes = new SelectOptionPanel();
        this.m_SelectAdditionalAttributes.setCurrent((Object)new String[0]);
        this.m_SelectAdditionalAttributes.setMultiSelect(true);
        this.m_SelectAdditionalAttributes.setLenient(true);
        this.m_SelectAdditionalAttributes.setDialogTitle("Select additional attributes");
        this.m_SelectAdditionalAttributes.setToolTipText("Additional attributes to make available in plots");
        this.m_PanelParameters.addParameter("Additional attributes", (Component)this.m_SelectAdditionalAttributes);

        // final model
        AbstractFinalModelGenerator finalmodel;
        try {
            finalmodel = (AbstractFinalModelGenerator)OptionUtils.forCommandLine(AbstractFinalModelGenerator.class, props.getProperty("Classify.CrossValidationFinalModel", new Simple().toCommandLine()));
        }
        catch (Exception ex) {
            // unparseable configuration: fall back to the simple final model scheme
            finalmodel = new Simple();
        }
        this.m_GOEFinalModel = new GenericObjectEditorPanel(AbstractFinalModelGenerator.class, (Object)finalmodel, true);
        this.m_GOEFinalModel.setToolTipText("How to produce a final model");
        this.m_GOEFinalModel.addChangeListener(e -> this.update());
        this.m_PanelParameters.addParameter("Final model", (Component)this.m_GOEFinalModel);
    }

    /**
     * Returns the display name of this evaluation.
     *
     * @return the name
     */
    @Override
    public String getName() {
        return "Cross-validation";
    }

    /**
     * Checks whether the classifier can be evaluated with the current settings.
     *
     * @param classifier the classifier to check
     * @return null if evaluation is possible, otherwise the reason why not
     */
    @Override
    public String canEvaluate(Classifier classifier) {
        if (!this.isValidDataIndex(this.m_ComboBoxDatasets)) {
            return "No data available!";
        }
        if (!Utils.isInteger(this.m_TextSeed.getText())) {
            return "Seed value is not an integer!";
        }
        Instances data = ((ClassifyTab)this.getOwner()).getData().get(this.m_ComboBoxDatasets.getSelectedIndex()).getData();
        Capabilities caps = classifier.getCapabilities();
        try {
            if (!caps.test(data)) {
                if (caps.getFailReason() != null) {
                    return caps.getFailReason().getMessage();
                }
                return "Classifier cannot handle data!";
            }
        }
        catch (Exception e) {
            return "Classifier cannot handle data: " + e;
        }
        return null;
    }

    /**
     * Creates the result item, using the header (no rows) of the selected dataset.
     *
     * @param classifier the classifier that will be evaluated
     * @return the result item
     * @throws Exception if initialization fails
     */
    @Override
    public ResultItem init(Classifier classifier) throws Exception {
        Instances data = ((ClassifyTab)this.getOwner()).getData().get(this.m_ComboBoxDatasets.getSelectedIndex()).getData();
        return new ResultItem(classifier, new Instances(data, 0));
    }

    /**
     * Cross-validates the classifier on the selected dataset, stores the
     * evaluation(s) in the result item and then generates the final model.
     * The reference to the running cross-validation is always cleared
     * afterwards, even when execution or model generation fails.
     *
     * @param classifier the classifier to evaluate
     * @param item the result item to update
     * @throws IllegalArgumentException if the classifier cannot be evaluated
     * @throws Exception if cross-validation or final model generation fails
     */
    @Override
    protected void doEvaluate(Classifier classifier, ResultItem item) throws Exception {
        String msg = this.canEvaluate(classifier);
        if (msg != null) {
            throw new IllegalArgumentException("Cannot evaluate classifier!\n" + msg);
        }
        DataContainer dataCont = ((ClassifyTab)this.getOwner()).getData().get(this.m_ComboBoxDatasets.getSelectedIndex());
        Instances data = dataCont.getData();
        AbstractFinalModelGenerator finalModel = (AbstractFinalModelGenerator)this.m_GOEFinalModel.getCurrent();
        boolean views = this.m_CheckBoxUseViews.isSelected();
        boolean discard = this.m_CheckBoxDiscardPredictions.isSelected();
        int seed = this.m_TextSeed.getValue().intValue();
        int folds = ((Number)this.m_SpinnerFolds.getValue()).intValue();
        boolean sepFolds = this.m_CheckBoxPerFoldOutput.isSelected();
        JobRunner jobrunner = (JobRunner)this.m_GOEJobRunner.getCurrent();
        CrossValidationFoldGenerator generator = (CrossValidationFoldGenerator)this.m_GOEGenerator.getCurrent();
        String[] additional = (String[])this.m_SelectAdditionalAttributes.getCurrent();

        MetaData runInfo = new MetaData();
        runInfo.add("Classifier", (Object)OptionUtils.getCommandLine(classifier));
        runInfo.add("Seed", (Object)seed);
        runInfo.add("Folds", (Object)folds);
        runInfo.add("Separate folds", (Object)sepFolds);
        runInfo.add("JobRunner", (Object)jobrunner.toCommandLine());
        runInfo.add("Dataset ID", (Object)dataCont.getID());
        runInfo.add("Relation", (Object)data.relationName());
        runInfo.add("# Attributes", (Object)data.numAttributes());
        runInfo.add("# Instances", (Object)data.numInstances());
        runInfo.add("Class attribute", (Object)data.classAttribute().name());
        runInfo.add("Use views", (Object)views);
        runInfo.add("Fold generator", (Object)generator.toCommandLine());
        runInfo.add("Discard predictions", (Object)discard);
        if (additional.length > 0) {
            runInfo.add("Additional attributes: ", (Object)Utils.flatten((Object[])additional, ", "));
        }

        WekaCrossValidationExecution cv = new WekaCrossValidationExecution();
        this.m_CrossValidation = cv;
        try {
            cv.setClassifier(classifier);
            cv.setData(data);
            cv.setFolds(folds);
            cv.setSeparateFolds(sepFolds);
            cv.setSeed(seed);
            cv.setJobRunner(jobrunner);
            cv.setUseViews(views);
            // shallow copy so the GUI's generator instance is not altered by the run
            cv.setGenerator((CrossValidationFoldGenerator)OptionUtils.shallowCopy((Object)generator));
            cv.setDiscardPredictions(discard);
            cv.setStatusMessageHandler(this);
            msg = cv.execute();
            if (msg != null) {
                throw new Exception("Failed to cross-validate:\n" + msg);
            }
            item.update(cv.getEvaluation(), sepFolds ? cv.getEvaluations() : null, null, sepFolds ? cv.getClassifiers() : null, runInfo, cv.getOriginalIndices(), this.transferAdditionalAttributes(this.m_SelectAdditionalAttributes, data));
            ((ClassifyTab)this.getOwner()).logMessage("Building final model on '" + dataCont.getID() + "/" + data.relationName() + "' using " + OptionUtils.getCommandLine(classifier));
            finalModel.generate(this, data, item);
        }
        finally {
            // never retain a finished (or failed) execution
            this.m_CrossValidation = null;
        }
    }

    /**
     * Refreshes the dataset list (keeping the current selection where possible),
     * the attribute names for the additional attributes, and the owner's buttons.
     * Does nothing while the owner hierarchy is not yet set.
     */
    @Override
    public void update() {
        if (this.getOwner() == null) {
            return;
        }
        ClassifyTab owner = (ClassifyTab)this.getOwner();
        if (owner.getOwner() == null) {
            return;
        }
        List<String> datasets = DatasetHelper.generateDatasetList(owner.getData());
        int index = DatasetHelper.indexOfDataset(owner.getData(), (String)this.m_ComboBoxDatasets.getSelectedItem());
        if (DatasetHelper.hasDataChanged(datasets, this.m_ModelDatasets)) {
            this.m_ModelDatasets = new DefaultComboBoxModel<>(datasets.toArray(new String[datasets.size()]));
            this.m_ComboBoxDatasets.setModel(this.m_ModelDatasets);
            if (index == -1 && this.m_ModelDatasets.getSize() > 0) {
                this.m_ComboBoxDatasets.setSelectedIndex(0);
            } else if (index > -1) {
                this.m_ComboBoxDatasets.setSelectedIndex(index);
            }
        }
        this.fillWithAttributeNames(this.m_SelectAdditionalAttributes, this.m_ComboBoxDatasets.getSelectedIndex());
        owner.updateButtons();
    }

    /**
     * Selects the dataset with the given index.
     *
     * @param index the dataset index
     */
    @Override
    public void activate(int index) {
        this.m_ComboBoxDatasets.setSelectedIndex(index);
    }

    /**
     * Stops the currently running cross-validation, if any.
     */
    @Override
    public void stopExecution() {
        // read once: doEvaluate may null the field concurrently
        WekaCrossValidationExecution current = this.m_CrossValidation;
        if (current != null) {
            current.stopExecution();
        }
    }

    /**
     * Returns whether the currently running cross-validation has been stopped.
     *
     * @return true if a cross-validation is running and was stopped
     */
    @Override
    public boolean isStopped() {
        // read once: doEvaluate may null the field concurrently
        WekaCrossValidationExecution current = this.m_CrossValidation;
        return current != null && current.isStopped();
    }

    /**
     * Serializes the GUI state and/or parameters, depending on the options.
     *
     * @param options what to serialize
     * @return the serialized state
     */
    @Override
    public Map<String, Object> serialize(Set<AbstractInvestigatorTab.SerializationOption> options) {
        Map<String, Object> result = super.serialize(options);
        if (options.contains(AbstractInvestigatorTab.SerializationOption.GUI)) {
            result.put(KEY_DATASET, this.m_ComboBoxDatasets.getSelectedIndex());
        }
        if (options.contains(AbstractInvestigatorTab.SerializationOption.PARAMETERS)) {
            result.put(KEY_FOLDS, ((Number)this.m_SpinnerFolds.getValue()).intValue());
            result.put(KEY_PERFOLDOUTPUT, this.m_CheckBoxPerFoldOutput.isSelected());
            result.put(KEY_SEED, this.m_TextSeed.getValue().intValue());
            result.put(KEY_JOBRUNNER, OptionUtils.getCommandLine(this.m_GOEJobRunner.getCurrent()));
            result.put(KEY_ADDITIONAL, this.m_SelectAdditionalAttributes.getCurrent());
            result.put(KEY_USEVIEWS, this.m_CheckBoxUseViews.isSelected());
            result.put(KEY_GENERATOR, OptionUtils.getCommandLine(this.m_GOEGenerator.getCurrent()));
            result.put(KEY_DISCARDPREDICTIONS, this.m_CheckBoxDiscardPredictions.isSelected());
            result.put(KEY_FINALMODEL, OptionUtils.getCommandLine(this.m_GOEFinalModel.getCurrent()));
        }
        return result;
    }

    /**
     * Restores the GUI state and parameters from the map; problems with
     * command-lines or an out-of-range dataset index are recorded in
     * {@code errors} rather than aborting.
     *
     * @param data the serialized state
     * @param errors for collecting errors
     */
    @Override
    public void deserialize(Map<String, Object> data, MessageCollection errors) {
        super.deserialize(data, errors);
        if (data.containsKey(KEY_DATASET)) {
            int index = ((Number)data.get(KEY_DATASET)).intValue();
            // the number of loaded datasets may differ from when the state was saved
            if (index >= -1 && index < this.m_ComboBoxDatasets.getItemCount()) {
                this.m_ComboBoxDatasets.setSelectedIndex(index);
            }
            else {
                errors.add("Dataset index out of range: " + index);
            }
        }
        if (data.containsKey(KEY_FOLDS)) {
            // normalize to int, matching the spinner's Integer-based model
            this.m_SpinnerFolds.setValue(((Number)data.get(KEY_FOLDS)).intValue());
        }
        if (data.containsKey(KEY_PERFOLDOUTPUT)) {
            this.m_CheckBoxPerFoldOutput.setSelected((Boolean)data.get(KEY_PERFOLDOUTPUT));
        }
        if (data.containsKey(KEY_SEED)) {
            this.m_TextSeed.setValue(((Number)data.get(KEY_SEED)).intValue());
        }
        if (data.containsKey(KEY_JOBRUNNER)) {
            try {
                this.m_GOEJobRunner.setCurrent(OptionUtils.forCommandLine(JobRunner.class, (String)data.get(KEY_JOBRUNNER)));
            }
            catch (Exception e) {
                errors.add("Failed to parse jobrunner commandline: " + data.get(KEY_JOBRUNNER), (Throwable)e);
            }
        }
        if (data.containsKey(KEY_ADDITIONAL)) {
            this.m_SelectAdditionalAttributes.setCurrent((Object)this.listOrArray(data.get(KEY_ADDITIONAL)));
        }
        if (data.containsKey(KEY_USEVIEWS)) {
            this.m_CheckBoxUseViews.setSelected((Boolean)data.get(KEY_USEVIEWS));
        }
        if (data.containsKey(KEY_GENERATOR)) {
            try {
                this.m_GOEGenerator.setCurrent(OptionUtils.forCommandLine(CrossValidationFoldGenerator.class, (String)data.get(KEY_GENERATOR)));
            }
            catch (Exception e) {
                errors.add("Failed to parse generator commandline: " + data.get(KEY_GENERATOR), (Throwable)e);
            }
        }
        if (data.containsKey(KEY_DISCARDPREDICTIONS)) {
            this.m_CheckBoxDiscardPredictions.setSelected((Boolean)data.get(KEY_DISCARDPREDICTIONS));
        }
        if (data.containsKey(KEY_FINALMODEL)) {
            try {
                this.m_GOEFinalModel.setCurrent(OptionUtils.forCommandLine(AbstractFinalModelGenerator.class, (String)data.get(KEY_FINALMODEL)));
            }
            catch (Exception e) {
                errors.add("Failed to parse final model generator commandline: " + data.get(KEY_FINALMODEL), (Throwable)e);
            }
        }
    }
}

