/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer.indexedsplitsrunspredictions;

import adams.core.MessageCollection;
import adams.core.ObjectCopyHelper;
import adams.core.QuickInfoHelper;
import adams.core.Utils;
import adams.core.option.OptionHandler;
import adams.data.indexedsplits.IndexedSplit;
import adams.data.indexedsplits.IndexedSplits;
import adams.data.indexedsplits.IndexedSplitsRun;
import adams.data.indexedsplits.IndexedSplitsRuns;
import adams.data.indexedsplits.SplitIndices;
import adams.data.spreadsheet.DataRow;
import adams.data.spreadsheet.DefaultSpreadSheet;
import adams.data.spreadsheet.HeaderRow;
import adams.data.spreadsheet.SpreadSheet;
import adams.flow.core.Actor;
import adams.flow.core.CallableActorHelper;
import adams.flow.core.CallableActorReference;
import adams.flow.source.WekaClassifierSetup;
import adams.flow.transformer.indexedsplitsrunspredictions.AbstractIndexedSplitsRunsPredictions;
import java.util.HashMap;
import java.util.Map;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Trains a classifier on the training split of every indexed split and records the
 * predictions it makes for the corresponding test split.
 * <p>
 * The generated spreadsheet has the columns Run, Split, Name, Index (row index within the
 * test split) and Prediction. For a nominal class attribute it additionally contains one
 * "Class-&lt;label&gt;" column per class label holding the predicted class distribution.
 */
public class InstancesIndexedSplitsRunsPredictions
extends AbstractIndexedSplitsRunsPredictions<Instances> {
    private static final long serialVersionUID = 2315594121849810804L;

    /** The name of the split used for training the classifier. */
    protected String m_TrainSplitName;

    /** The name of the split used for generating predictions. */
    protected String m_TestSplitName;

    /** Reference to the callable actor that provides the classifier setup. */
    protected CallableActorReference m_Classifier;

    /** Classifier set programmatically; used instead of {@link #m_Classifier} when non-null. */
    protected Classifier m_ManualClassifier;

    /**
     * Returns a short description of this scheme.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Trains the referenced classifier on the training splits and generates predictions for the test splits.";
    }

    /**
     * Registers the options: train split name, test split name and classifier reference.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("train-split-name", "trainSplitName", (Object)"train");
        this.m_OptionManager.add("test-split-name", "testSplitName", (Object)"test");
        this.m_OptionManager.add("classifier", "classifier", (Object)new CallableActorReference(WekaClassifierSetup.class.getSimpleName()));
    }

    /**
     * Sets the name of the split to train on.
     *
     * @param value the split name
     */
    public void setTrainSplitName(String value) {
        this.m_TrainSplitName = value;
        this.reset();
    }

    /**
     * Returns the name of the split to train on.
     *
     * @return the split name
     */
    public String getTrainSplitName() {
        return this.m_TrainSplitName;
    }

    /**
     * Returns the tip text for the train split name option.
     *
     * @return the tip text
     */
    public String trainSplitNameTipText() {
        return "The name of the split to use for training.";
    }

    /**
     * Sets the name of the split to generate predictions for.
     *
     * @param value the split name
     */
    public void setTestSplitName(String value) {
        this.m_TestSplitName = value;
        this.reset();
    }

    /**
     * Returns the name of the split to generate predictions for.
     *
     * @return the split name
     */
    public String getTestSplitName() {
        return this.m_TestSplitName;
    }

    /**
     * Returns the tip text for the test split name option.
     *
     * @return the tip text
     */
    public String testSplitNameTipText() {
        return "The name of the split to use for testing, ie generating predictions.";
    }

    /**
     * Sets the reference to the callable actor that supplies the classifier.
     *
     * @param value the reference
     */
    public void setClassifier(CallableActorReference value) {
        this.m_Classifier = value;
        this.reset();
    }

    /**
     * Returns the reference to the callable actor that supplies the classifier.
     *
     * @return the reference
     */
    public CallableActorReference getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Returns the tip text for the classifier option.
     *
     * @return the tip text
     */
    public String classifierTipText() {
        return "The classifier to use on the splits.";
    }

    /**
     * Sets a classifier programmatically. When non-null it is used instead of the callable
     * actor reference. Unlike the option setters, this does not call {@code reset()}.
     *
     * @param value the classifier, null to use the callable actor reference
     */
    public void setManualClassifier(Classifier value) {
        this.m_ManualClassifier = value;
    }

    /**
     * Returns the programmatically set classifier, if any.
     *
     * @return the classifier, can be null
     */
    public Classifier getManualClassifier() {
        return this.m_ManualClassifier;
    }

    /**
     * Returns whether a flow context is required (needed to resolve the callable actor).
     *
     * @return always true
     */
    public boolean requiresFlowContext() {
        return true;
    }

    /**
     * Returns the type of data this scheme accepts.
     *
     * @return {@link Instances}
     */
    public Class accepts() {
        return Instances.class;
    }

    /**
     * Returns a short summary of the current setup.
     *
     * @return the summary
     */
    public String getQuickInfo() {
        String result = QuickInfoHelper.toString((OptionHandler)this, (String)"trainSplitName", (Object)this.m_TrainSplitName, (String)"train: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"testSplitName", (Object)this.m_TestSplitName, (String)", test: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"classifier", (Object)this.m_Classifier, (String)", classifier: ");
        return result;
    }

    /**
     * Checks the data and runs. In addition to the superclass checks, the data must have a
     * class attribute that is either nominal or numeric.
     *
     * @param data the data to check
     * @param runs the runs to check
     * @return null if OK, otherwise an error message
     */
    public String check(Instances data, IndexedSplitsRuns runs) {
        String result = super.check(data, runs);
        if (result == null) {
            if (data.classIndex() == -1) {
                result = "No class attribute set!";
            }
            else if (!data.classAttribute().isNominal() && !data.classAttribute().isNumeric()) {
                result = "Class attribute can only be nominal or numeric!";
            }
        }
        return result;
    }

    /**
     * Returns the classifier to use as template: the manual classifier if set, otherwise the
     * setup obtained from the referenced callable actor.
     *
     * @param errors for collecting errors
     * @return the classifier, null if it could not be obtained (an error is then recorded)
     */
    protected Classifier getClassifierInstance(MessageCollection errors) {
        if (this.m_ManualClassifier != null) {
            return this.m_ManualClassifier;
        }
        Classifier result = (Classifier)CallableActorHelper.getSetup(Classifier.class, this.m_Classifier, (Actor)this.m_FlowContext, errors);
        if (result == null) {
            // make sure the caller never receives null without an explanation
            if (errors.isEmpty()) {
                errors.add("Failed to obtain classifier setup from callable actor: " + this.m_Classifier);
            }
            this.getLogger().severe(errors.toString());
        }
        return result;
    }

    /**
     * Materializes every named split of the indexed split as a separate dataset, using copies
     * of the selected rows of the data.
     *
     * @param indexedSplit the split definition (name to row indices)
     * @param data the full dataset
     * @return the datasets, keyed by split name
     */
    protected Map<String, Instances> applyIndexedSplit(IndexedSplit indexedSplit, Instances data) {
        HashMap<String, Instances> result = new HashMap<String, Instances>();
        for (String key : indexedSplit.getIndices().keySet()) {
            SplitIndices splitIndices = (SplitIndices)indexedSplit.getIndices().get(key);
            int[] indices = splitIndices.getIndices();
            Instances split = new Instances(data, indices.length);
            for (int index : indices) {
                split.add((Instance)data.instance(index).copy());
            }
            result.put(key, split);
        }
        return result;
    }

    /**
     * Trains a fresh copy of the classifier on each train split and adds the predictions for
     * the corresponding test split to the output spreadsheet.
     *
     * @param data the full dataset
     * @param runs the runs with their indexed splits
     * @param errors for collecting errors
     * @return the predictions, null if stopped or if any errors were recorded
     */
    protected SpreadSheet doGenerate(Instances data, IndexedSplitsRuns runs, MessageCollection errors) {
        Classifier template = this.getClassifierInstance(errors);
        if (template == null) {
            return null;
        }
        SpreadSheet result = new DefaultSpreadSheet();
        boolean nominal = this.addHeader(result, data);
        try {
            for (int run = 0; run < runs.size(); ++run) {
                if (this.m_Stopped) {
                    return null;
                }
                IndexedSplits indexedSplits = ((IndexedSplitsRun)runs.get(run)).getSplits();
                for (int split = 0; split < indexedSplits.size(); ++split) {
                    if (this.m_Stopped) {
                        return null;
                    }
                    Map<String, Instances> namedSplits = this.applyIndexedSplit((IndexedSplit)indexedSplits.get(split), data);
                    Instances train = this.getNamedSplit(namedSplits, this.m_TrainSplitName, "train", run, split);
                    Instances test = this.getNamedSplit(namedSplits, this.m_TestSplitName, "test", run, split);
                    // every split gets its own copy so no model state carries over between splits
                    Classifier classifier = (Classifier)ObjectCopyHelper.copyObject(template);
                    classifier.buildClassifier(train);
                    this.addPredictions(result, classifier, test, run, split, nominal);
                }
            }
        }
        catch (Exception e) {
            errors.add("Failed to generate predictions!", (Throwable)e);
        }
        if (errors.isEmpty()) {
            return result;
        }
        return null;
    }

    /**
     * Creates the header of the output spreadsheet.
     *
     * @param sheet the spreadsheet to initialize
     * @param data the dataset, whose class attribute determines the distribution columns
     * @return true if the class attribute is nominal (distribution columns were added)
     */
    private boolean addHeader(SpreadSheet sheet, Instances data) {
        HeaderRow header = sheet.getHeaderRow();
        header.addCell("r").setContentAsString("Run");
        header.addCell("s").setContentAsString("Split");
        header.addCell("n").setContentAsString("Name");
        header.addCell("i").setContentAsString("Index");
        header.addCell("p").setContentAsString("Prediction");
        if (!data.classAttribute().isNominal()) {
            return false;
        }
        for (int i = 0; i < data.classAttribute().numValues(); ++i) {
            header.addCell("d" + i).setContentAsString("Class-" + data.classAttribute().value(i));
        }
        return true;
    }

    /**
     * Looks up a named split.
     *
     * @param namedSplits the available splits
     * @param name the name of the split to retrieve
     * @param type the role of the split ("train" or "test"), used in the error message
     * @param run the current run index, used in the error message
     * @param split the current split index, used in the error message
     * @return the split
     * @throws IllegalArgumentException if no split with that name exists
     */
    private Instances getNamedSplit(Map<String, Instances> namedSplits, String name, String type, int run, int split) {
        if (!namedSplits.containsKey(name)) {
            throw new IllegalArgumentException("Failed to locate " + type + " split '" + name + "' (run=" + run + ", split=" + split + "), available: " + Utils.flatten(namedSplits.keySet().toArray(), ","));
        }
        return namedSplits.get(name);
    }

    /**
     * Adds one row per test instance with the classifier's prediction (and, for a nominal
     * class, the class distribution).
     *
     * @param sheet the spreadsheet to add the rows to
     * @param classifier the trained classifier
     * @param test the test split
     * @param run the run index
     * @param split the split index
     * @param nominal whether the class attribute is nominal
     * @throws Exception if the classifier fails to make a prediction
     */
    private void addPredictions(SpreadSheet sheet, Classifier classifier, Instances test, int run, int split, boolean nominal) throws Exception {
        for (int index = 0; index < test.numInstances(); ++index) {
            Instance inst = test.instance(index);
            double pred = classifier.classifyInstance(inst);
            double[] dist = nominal ? classifier.distributionForInstance(inst) : new double[0];
            DataRow row = sheet.addRow();
            row.addCell("r").setContent(Integer.valueOf(run));
            row.addCell("s").setContent(Integer.valueOf(split));
            row.addCell("n").setContentAsString(this.m_TestSplitName);
            row.addCell("i").setContent(Integer.valueOf(index));
            if (nominal) {
                // Weka reports "no prediction" as a missing value (NaN); casting NaN to int yields 0,
                // which would silently report the first class label, so leave the cell missing instead
                if (!Double.isNaN(pred)) {
                    row.addCell("p").setContentAsString(test.classAttribute().value((int)pred));
                }
                for (int i = 0; i < dist.length; ++i) {
                    row.addCell("d" + i).setContent(Double.valueOf(dist[i]));
                }
            }
            else {
                row.addCell("p").setContent(Double.valueOf(pred));
            }
        }
    }
}

