package adams.gui.tools.wekainvestigator.tab.classifytab.evaluation;

import adams.core.MessageCollection;
import adams.core.Properties;
import adams.core.Utils;
import adams.core.option.OptionUtils;
import adams.data.spreadsheet.MetaData;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.gui.chooser.SelectOptionPanel;
import adams.gui.core.BaseCheckBox;
import adams.gui.core.BaseComboBox;
import adams.gui.core.ParameterPanel;
import adams.gui.tools.wekainvestigator.data.DataContainer;
import adams.gui.tools.wekainvestigator.evaluation.DatasetHelper;
import adams.gui.tools.wekainvestigator.tab.AbstractInvestigatorTab;
import adams.gui.tools.wekainvestigator.tab.ClassifyTab;
import adams.gui.tools.wekainvestigator.tab.classifytab.ResultItem;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.swing.DefaultComboBoxModel;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.TestingHelper;
import weka.core.Capabilities;
import weka.core.Instances;

/* loaded from: input_file:adams/gui/tools/wekainvestigator/tab/classifytab/evaluation/TrainValidateTestSet.class */
public class TrainValidateTestSet extends AbstractClassifierEvaluation {
    /** Serialization version identifier. */
    private static final long serialVersionUID = -4460266467650893551L;
    /** Map key for the selected index of the train combobox (same value as used by serialize/deserialize). */
    public static final String KEY_TRAIN = "train";
    /** Map key for the selected index of the validate combobox. */
    public static final String KEY_VALIDATE = "validate";
    /** Map key for the selected index of the test combobox (same value as used by serialize/deserialize). */
    public static final String KEY_TEST = "test";
    /** Map key for the selected additional attributes (same value as used by serialize/deserialize). */
    public static final String KEY_ADDITIONAL = "additional";
    /** Map key for the "discard predictions" flag (same value as used by serialize/deserialize). */
    public static final String KEY_DISCARDPREDICTIONS = "discardpredictions";
    /** Panel that hosts all the parameter widgets below. */
    protected ParameterPanel m_PanelParameters;
    /** Combobox for choosing the training dataset. */
    protected BaseComboBox<String> m_ComboBoxTrain;
    /** Combobox for choosing the validation dataset. */
    protected BaseComboBox<String> m_ComboBoxValidate;
    /** Combobox for choosing the test dataset. */
    protected BaseComboBox<String> m_ComboBoxTest;
    /**
     * Combobox model with the dataset names. update() creates a fresh model
     * per combobox, so after a refresh this references the model that was
     * created last (the one of the test combobox).
     */
    protected DefaultComboBoxModel<String> m_ModelDatasets;
    /** Multi-select panel for additional attributes to make available in plots. */
    protected SelectOptionPanel m_SelectAdditionalAttributes;
    /** Checkbox controlling whether the evaluations discard their predictions. */
    protected BaseCheckBox m_CheckBoxDiscardPredictions;

    /**
     * Returns a short, human-readable description of this evaluation scheme.
     *
     * @return the description text
     */
    public String globalInfo() {
        final String description = "Trains the classifier on the selected training set, "
            + "validates it on the selected validation set "
            + "and tests it against the selected test set.";
        return description;
    }

    /**
     * Builds the option widgets of this evaluation: one dataset combobox each
     * for train, validate and test, the selector for additional attributes and
     * the "discard predictions" checkbox. Every combobox and the checkbox
     * trigger {@link #update()} when changed.
     *
     * (Decompiler note: the access modifier was changed from protected.)
     */
    @Override
    public void initGUI() {
        super.initGUI();

        Properties props = getProperties();

        m_PanelParameters = new ParameterPanel();
        m_PanelOptions.add(m_PanelParameters, "Center");

        // all three comboboxes initially share this (still empty) model
        m_ModelDatasets = new DefaultComboBoxModel<>();

        // training set
        m_ComboBoxTrain = new BaseComboBox<>(m_ModelDatasets);
        m_ComboBoxTrain.addActionListener(e -> update());
        m_PanelParameters.addParameter("Train", m_ComboBoxTrain);

        // validation set
        m_ComboBoxValidate = new BaseComboBox<>(m_ModelDatasets);
        m_ComboBoxValidate.addActionListener(e -> update());
        m_PanelParameters.addParameter("Validate", m_ComboBoxValidate);

        // test set
        m_ComboBoxTest = new BaseComboBox<>(m_ModelDatasets);
        m_ComboBoxTest.addActionListener(e -> update());
        m_PanelParameters.addParameter(WekaTrainTestSetContainer.VALUE_TEST, m_ComboBoxTest);

        // additional attributes
        m_SelectAdditionalAttributes = new SelectOptionPanel();
        m_SelectAdditionalAttributes.setCurrent(new String[0]);
        m_SelectAdditionalAttributes.setMultiSelect(true);
        m_SelectAdditionalAttributes.setLenient(true);
        m_SelectAdditionalAttributes.setDialogTitle("Select additional attributes");
        m_SelectAdditionalAttributes.setToolTipText("Additional attributes to make available in plots");
        m_PanelParameters.addParameter("Additional attributes", m_SelectAdditionalAttributes);

        // discard predictions? initial state comes from the properties
        m_CheckBoxDiscardPredictions = new BaseCheckBox();
        m_CheckBoxDiscardPredictions.setSelected(props.getBoolean("Classify.DiscardPredictions", false).booleanValue());
        m_CheckBoxDiscardPredictions.setToolTipText("Save memory by discarding predictions?");
        m_CheckBoxDiscardPredictions.addActionListener(e -> update());
        m_PanelParameters.addParameter("Discard predictions", m_CheckBoxDiscardPredictions);
    }

    /**
     * Returns the name under which this evaluation is listed.
     *
     * @return the display name
     */
    @Override
    public String getName() {
        final String name = "Train/validate/test set";
        return name;
    }

    /**
     * Checks whether the classifier can be evaluated with the currently
     * selected train, validate and test datasets.
     *
     * The train set is tested against the classifier's capabilities as-is;
     * for the validate and test sets, missing class values are additionally
     * enabled before testing. Finally the headers of validate and test set
     * are compared with the header of the train set.
     *
     * @param classifier the classifier to check
     * @return null if evaluation is possible, otherwise the reason why not
     */
    @Override
    public String canEvaluate(Classifier classifier) {
        if (!isValidDataIndex(m_ComboBoxTrain)) {
            return "No train data available!";
        }
        if (!isValidDataIndex(m_ComboBoxValidate)) {
            return "No validate data available!";
        }
        if (!isValidDataIndex(m_ComboBoxTest)) {
            return "No test data available!";
        }

        Capabilities trainCaps = classifier.getCapabilities();
        Instances train = getOwner().getData().get(m_ComboBoxTrain.getSelectedIndex()).getData();

        try {
            // train set
            if (!trainCaps.test(train)) {
                if (trainCaps.getFailReason() != null) {
                    return trainCaps.getFailReason().getMessage();
                }
                return "Classifier cannot handle train data!";
            }

            // validate set, class values may be missing
            Capabilities validateCaps = classifier.getCapabilities();
            validateCaps.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
            Instances validate = getOwner().getData().get(m_ComboBoxValidate.getSelectedIndex()).getData();
            if (!validateCaps.test(validate)) {
                if (validateCaps.getFailReason() != null) {
                    return validateCaps.getFailReason().getMessage();
                }
                return "Classifier cannot handle validate data!";
            }

            // test set, class values may be missing
            Capabilities testCaps = classifier.getCapabilities();
            testCaps.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
            Instances test = getOwner().getData().get(m_ComboBoxTest.getSelectedIndex()).getData();
            if (!testCaps.test(test)) {
                if (testCaps.getFailReason() != null) {
                    return testCaps.getFailReason().getMessage();
                }
                return "Classifier cannot handle test data!";
            }

            // all datasets must share the header of the train set
            if (!train.equalHeaders(validate)) {
                return train.equalHeadersMsg(validate);
            }
            if (!train.equalHeaders(test)) {
                return train.equalHeadersMsg(test);
            }
            return null;
        } catch (Exception e) {
            return "Classifier cannot handle data: " + e;
        }
    }

    /**
     * Creates the result item for an evaluation run, using the classifier as
     * template and an empty copy (header only) of the selected train dataset.
     *
     * @param classifier the classifier to evaluate
     * @return the freshly created result item
     * @throws Exception if creating the item fails
     */
    @Override
    public ResultItem init(Classifier classifier) throws Exception {
        Instances train = getOwner().getData().get(m_ComboBoxTrain.getSelectedIndex()).getData();
        // capacity 0: copies the structure only, no rows
        Instances header = new Instances(train, 0);
        return new ResultItem(classifier, header);
    }

    /**
     * Trains a copy of the classifier on the selected train set and evaluates
     * the resulting model on the validate set and on the test set.
     *
     * The evaluation on the validate set is stored in the supplied result item,
     * the evaluation on the test set in a nested result item that is attached
     * to it.
     *
     * @param classifier the classifier template to evaluate
     * @param resultItem the item to store the results in
     * @throws IllegalArgumentException if {@link #canEvaluate(Classifier)} reports a problem
     * @throws Exception if training or evaluation fails
     */
    @Override
    protected void doEvaluate(final Classifier classifier, ResultItem resultItem) throws Exception {
        String problem = canEvaluate(classifier);
        if (problem != null) {
            throw new IllegalArgumentException("Cannot evaluate classifier!\n" + problem);
        }

        ClassifyTab owner = (ClassifyTab) getOwner();
        DataContainer trainCont = owner.getData().get(m_ComboBoxTrain.getSelectedIndex());
        DataContainer validateCont = owner.getData().get(m_ComboBoxValidate.getSelectedIndex());
        DataContainer testCont = owner.getData().get(m_ComboBoxTest.getSelectedIndex());
        Instances train = trainCont.getData();
        Instances validate = validateCont.getData();
        Instances test = testCont.getData();
        boolean discard = m_CheckBoxDiscardPredictions.isSelected();
        String[] additional = (String[]) m_SelectAdditionalAttributes.getCurrent();

        // run information shared by both result items
        MetaData runInfo = new MetaData();
        runInfo.add("Classifier", OptionUtils.getCommandLine(classifier));
        runInfo.add("Train ID", Integer.valueOf(trainCont.getID()));
        // fixed: previously stored the ID of the test container here
        runInfo.add("Validate ID", Integer.valueOf(validateCont.getID()));
        runInfo.add("Test ID", Integer.valueOf(testCont.getID()));
        runInfo.add("Relation", train.relationName());
        runInfo.add("# Attributes", Integer.valueOf(train.numAttributes()));
        runInfo.add("# Instances (train)", Integer.valueOf(train.numInstances()));
        runInfo.add("# Instances (validate)", Integer.valueOf(validate.numInstances()));
        runInfo.add("# Instances (test)", Integer.valueOf(test.numInstances()));
        runInfo.add("Class attribute", train.classAttribute().name());
        runInfo.add("Discard predictions", Boolean.valueOf(discard));
        if (additional.length > 0) {
            runInfo.add("Additional attributes: ", Utils.flatten((Object[]) additional, ", "));
        }

        // train a copy, the template itself stays untouched
        Classifier model = (Classifier) OptionUtils.shallowCopy(classifier);
        owner.logMessage("Using '" + trainCont.getID() + "/" + train.relationName() + "' to train " + OptionUtils.getCommandLine(classifier));
        model.buildClassifier(train);
        addObjectSize(runInfo, "Model size", model);

        // validate (fixed: the message used to name the test dataset at this point)
        owner.logMessage("Using '" + validateCont.getID() + "/" + validate.relationName() + "' to validate " + OptionUtils.getCommandLine(classifier));
        Evaluation validateEval = new Evaluation(train);
        validateEval.setDiscardPredictions(discard);
        TestingHelper.evaluateModel(model, validate, validateEval, getTestingUpdateInterval(),
            newProgressLogger(validateCont, validate, "validate", classifier));
        resultItem.update(validateEval, model, runInfo, null, transferAdditionalAttributes(m_SelectAdditionalAttributes, validate));

        // test (fixed: the progress messages used to name the validate dataset)
        owner.logMessage("Using '" + testCont.getID() + "/" + test.relationName() + "' to test " + OptionUtils.getCommandLine(classifier));
        Evaluation testEval = new Evaluation(train);
        testEval.setDiscardPredictions(discard);
        TestingHelper.evaluateModel(model, test, testEval, getTestingUpdateInterval(),
            newProgressLogger(testCont, test, "test", classifier));

        // NOTE(review): the nested item is labelled "Validation" although it
        // holds the test-set evaluation (the validate-set evaluation lives in
        // the main item). Labels are left unchanged since other code may look
        // the nested item up by this name -- confirm the intended mapping.
        ResultItem nested = new ResultItem(resultItem.getTemplate(), resultItem.getHeader());
        nested.setNameSuffix("Validation");
        nested.update(testEval, model, runInfo, null, transferAdditionalAttributes(m_SelectAdditionalAttributes, test));
        resultItem.addNestedItem("Validation", nested);
    }

    /**
     * Creates a listener that logs the evaluation progress for the given dataset.
     *
     * @param cont the container of the dataset being evaluated
     * @param data the dataset being evaluated
     * @param action the verb to use in the message, e.g. "validate" or "test"
     * @param classifier the classifier template whose command-line gets logged
     * @return the listener
     */
    private TestingHelper.TestingUpdateListener newProgressLogger(final DataContainer cont, final Instances data, final String action, final Classifier classifier) {
        return new TestingHelper.TestingUpdateListener() {
            @Override
            public void testingUpdateRequested(Instances instances, int i, int i2) {
                TrainValidateTestSet.this.getOwner().logMessage("Used " + i + "/" + i2 + " of '" + cont.getID() + "/" + data.relationName() + "' to " + action + " " + OptionUtils.getCommandLine(classifier));
            }
        };
    }

    /**
     * Synchronizes the dataset comboboxes with the datasets of the owning tab.
     *
     * Does nothing while this evaluation or its owner is not attached yet.
     * If the list of datasets changed, each combobox (train, validate, test --
     * in that order) receives a fresh model and its previous selection is
     * restored where possible. Afterwards the additional-attributes selector
     * is refilled based on the selected test dataset and the owner's buttons
     * are updated.
     */
    @Override
    public void update() {
        if (getOwner() == null || getOwner().getOwner() == null) {
            return;
        }

        List<String> names = DatasetHelper.generateDatasetList(getOwner().getData());
        if (DatasetHelper.hasDataChanged(names, m_ModelDatasets)) {
            refreshDatasets(m_ComboBoxTrain, names);
            refreshDatasets(m_ComboBoxValidate, names);
            refreshDatasets(m_ComboBoxTest, names);
        }

        fillWithAttributeNames(m_SelectAdditionalAttributes, m_ComboBoxTest.getSelectedIndex());
        getOwner().updateButtons();
    }

    /**
     * Gives the combobox a new model with the supplied dataset names and
     * re-selects the dataset that was selected before; falls back to the
     * first entry if that dataset could not be located.
     *
     * @param combo the combobox to refresh
     * @param names the current dataset names
     */
    private void refreshDatasets(BaseComboBox<String> combo, List<String> names) {
        // look up the old selection before the model gets replaced
        int previous = DatasetHelper.indexOfDataset(getOwner().getData(), (String) combo.getSelectedItem());
        m_ModelDatasets = new DefaultComboBoxModel<>(names.toArray(new String[names.size()]));
        combo.setModel(m_ModelDatasets);
        if (previous > -1) {
            combo.setSelectedIndex(previous);
        } else if (previous == -1 && m_ModelDatasets.getSize() > 0) {
            combo.setSelectedIndex(0);
        }
    }

    /**
     * Selects the entry with the given index in the train combobox; the
     * validate and test comboboxes are left untouched.
     *
     * @param i the index to select
     */
    @Override
    public void activate(int i) {
        final BaseComboBox<String> trainCombo = m_ComboBoxTrain;
        trainCombo.setSelectedIndex(i);
    }

    /**
     * Stores the state of this evaluation in a map, on top of whatever the
     * super class stores.
     *
     * With the GUI option, the selected indices of the three dataset
     * comboboxes are stored; with the PARAMETERS option, the additional
     * attributes and the "discard predictions" flag.
     *
     * @param set the options that determine what gets stored
     * @return the map with the stored state
     */
    @Override
    public Map<String, Object> serialize(Set<AbstractInvestigatorTab.SerializationOption> set) {
        Map<String, Object> state = super.serialize(set);

        if (set.contains(AbstractInvestigatorTab.SerializationOption.GUI)) {
            state.put(KEY_TRAIN, m_ComboBoxTrain.getSelectedIndex());
            state.put(KEY_VALIDATE, m_ComboBoxValidate.getSelectedIndex());
            state.put(KEY_TEST, m_ComboBoxTest.getSelectedIndex());
        }

        if (set.contains(AbstractInvestigatorTab.SerializationOption.PARAMETERS)) {
            state.put(KEY_ADDITIONAL, m_SelectAdditionalAttributes.getCurrent());
            state.put(KEY_DISCARDPREDICTIONS, m_CheckBoxDiscardPredictions.isSelected());
        }

        return state;
    }

    /**
     * Restores the state of this evaluation from the map, after the super
     * class restored its part. Only keys present in the map are applied.
     *
     * @param map the map with the stored state
     * @param messageCollection for collecting problems; passed on to the super class
     */
    @Override
    public void deserialize(Map<String, Object> map, MessageCollection messageCollection) {
        super.deserialize(map, messageCollection);

        restoreIndex(map, KEY_TRAIN, m_ComboBoxTrain);
        restoreIndex(map, KEY_VALIDATE, m_ComboBoxValidate);
        restoreIndex(map, KEY_TEST, m_ComboBoxTest);

        if (map.containsKey(KEY_ADDITIONAL)) {
            m_SelectAdditionalAttributes.setCurrent(listOrArray(map.get(KEY_ADDITIONAL)));
        }
        if (map.containsKey(KEY_DISCARDPREDICTIONS)) {
            boolean discard = (Boolean) map.get(KEY_DISCARDPREDICTIONS);
            m_CheckBoxDiscardPredictions.setSelected(discard);
        }
    }

    /**
     * Applies the index stored under the key, if any, as the selected index
     * of the combobox.
     *
     * @param map the map with the stored state
     * @param key the key the index is stored under
     * @param combo the combobox to update
     */
    private static void restoreIndex(Map<String, Object> map, String key, BaseComboBox<String> combo) {
        if (map.containsKey(key)) {
            combo.setSelectedIndex(((Number) map.get(key)).intValue());
        }
    }
}
