/*
 * Decompiled with CFR 0.152.
 */
package weka.knowledgeflow.steps;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.misc.InputMappedClassifier;
import weka.core.Drawable;
import weka.core.EnvironmentHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.OptionMetadata;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.FilePropertyMetadata;
import weka.gui.ProgrammaticProperty;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.steps.KFStep;
import weka.knowledgeflow.steps.PairedDataHelper;
import weka.knowledgeflow.steps.WekaAlgorithmWrapper;

@KFStep(name="Classifier", category="Classifiers", toolTipText="Weka classifier wrapper", iconPath="")
/**
 * Knowledge Flow step that wraps a Weka classifier.
 *
 * <p>The kind of incoming connection selects the mode of operation:
 * <ul>
 * <li>{@code trainingSet} (optionally paired with {@code testSet}): a fresh copy of the
 * configured classifier is built for each training fold/set. It is emitted on a
 * {@code batchClassifier} connection together with the matching training and test sets.</li>
 * <li>{@code testSet} only: a model loaded at init time (see
 * {@link #setLoadClassifierFileName(File)}) is emitted for each incoming test set.</li>
 * <li>{@code instance}: a loaded model is used, or a model is initialised from the stream's
 * structure. Each instance is emitted on an {@code incrementalClassifier} connection. After
 * that, the model is updated with the instance if it is incremental, updating is enabled, and
 * the instance's class value is present.</li>
 * </ul>
 */
public class Classifier
extends WekaAlgorithmWrapper
implements PairedDataHelper.PairedProcessor<weka.classifiers.Classifier> {
    private static final long serialVersionUID = 8326706942962123155L;

    /** Copy of the configured classifier taken at init; copied again for every model built. */
    protected weka.classifiers.Classifier m_classifierTemplate;
    /** Model loaded from disk, initialised for streaming, or built on the only training set. */
    protected weka.classifiers.Classifier m_trainedClassifier;
    /** Structure of the data the current model relates to; may be null (e.g. model file without header). */
    protected Instances m_trainedClassifierHeader;
    /** Optional model file loaded at init when there is no training set connection. */
    protected File m_loadModelFileName = new File("");
    /** Whether an incremental model is re-initialised at the start of an instance stream. */
    protected boolean m_resetIncrementalClassifier;
    /** Whether an incremental model is updated with each labelled streamed instance. */
    protected boolean m_updateIncrementalClassifier = true;
    /** True once the current run has started receiving data on an instance connection. */
    protected boolean m_streaming;
    /** Whether the classifier template implements {@link UpdateableClassifier}. */
    protected boolean m_classifierIsIncremental;
    /** Pairs training/test sets by set number; only created when a training set connection exists. */
    protected transient PairedDataHelper<weka.classifiers.Classifier> m_trainTestHelper;
    /** Output container reused for every incremental classifier emission. */
    protected Data m_incrementalData = new Data("incrementalClassifier");
    /** True from init until the first data of the run arrives. */
    protected boolean m_isReset;

    @Override
    public Class getWrappedAlgorithmClass() {
        return weka.classifiers.Classifier.class;
    }

    @Override
    public void setWrappedAlgorithm(Object algo) {
        super.setWrappedAlgorithm(algo);
        this.m_defaultIconPath = "weka/gui/knowledgeflow/icons/DefaultClassifier.gif";
    }

    /**
     * Returns the wrapped (configured, untrained) classifier.
     *
     * @return the wrapped classifier
     */
    public weka.classifiers.Classifier getClassifier() {
        return (weka.classifiers.Classifier) this.getWrappedAlgorithm();
    }

    /**
     * Sets the classifier to wrap.
     *
     * @param classifier the classifier to wrap
     */
    @ProgrammaticProperty
    public void setClassifier(weka.classifiers.Classifier classifier) {
        this.setWrappedAlgorithm(classifier);
    }

    /**
     * Prepares the step for a new run. This method:
     * <ul>
     * <li>clears all per-run state;</li>
     * <li>copies the configured classifier as a template;</li>
     * <li>creates the train/test pairing helper when there is a training set connection;</li>
     * <li>otherwise loads the configured model file, if one is set.</li>
     * </ul>
     *
     * @throws WekaException if the classifier cannot be copied or the model cannot be loaded
     */
    @Override
    public void stepInit() throws WekaException {
        try {
            this.m_trainedClassifier = null;
            // Per-run state must not leak into the next run (e.g. after connections change).
            this.m_trainedClassifierHeader = null;
            this.m_streaming = false;
            this.m_trainTestHelper = null;
            this.m_incrementalData = new Data("incrementalClassifier");
            this.m_classifierTemplate = AbstractClassifier.makeCopy((weka.classifiers.Classifier) this.getWrappedAlgorithm());
            if (this.m_classifierTemplate instanceof EnvironmentHandler) {
                ((EnvironmentHandler) this.m_classifierTemplate).setEnvironment(
                    this.getStepManager().getExecutionEnvironment().getEnvironmentVariables());
            }
        } catch (Exception ex) {
            throw new WekaException(ex);
        }

        int numTraining = this.getStepManager().numIncomingConnectionsOfType("trainingSet");
        if (numTraining > 0) {
            String secondaryConnection = this.getStepManager().numIncomingConnectionsOfType("testSet") > 0 ? "testSet" : null;
            this.m_trainTestHelper = new PairedDataHelper<weka.classifiers.Classifier>(this, this, "trainingSet", secondaryConnection);
        }
        this.m_isReset = true;
        this.m_classifierIsIncremental = this.m_classifierTemplate instanceof UpdateableClassifier;

        // A pre-built model is only used when nothing is being trained in this flow.
        File modelFile = this.getLoadClassifierFileName();
        if (modelFile != null && modelFile.toString().length() > 0 && numTraining == 0) {
            String resolvedFileName = this.getStepManager().environmentSubstitute(modelFile.toString());
            try {
                this.getStepManager().logBasic("Loading classifier: " + resolvedFileName);
                this.loadModel(resolvedFileName);
            } catch (Exception ex) {
                throw new WekaException(ex);
            }
        }
        if (this.m_trainedClassifier != null && this.getStepManager().numIncomingConnectionsOfType("instance") > 0 && !this.m_classifierIsIncremental) {
            this.getStepManager().logWarning("Loaded classifier is not an incremental one - will only be able to evaluate, and not update, on the incoming instance stream.");
        }
    }

    /**
     * Returns the model file to load at execution time.
     *
     * @return the model file (an empty path means "none")
     */
    public File getLoadClassifierFileName() {
        return this.m_loadModelFileName;
    }

    /**
     * Sets the model file to load at execution time. The file is only used when
     * there is no training set connection.
     *
     * @param filename the model file; may contain environment variables
     */
    @OptionMetadata(displayName="Classifier model to load", description="Optional Path to a classifier to load at execution time (only applies when using testSet or instance connections)")
    @FilePropertyMetadata(fileChooserDialogType=0, directoriesOnly=false)
    public void setLoadClassifierFileName(File filename) {
        this.m_loadModelFileName = filename;
    }

    /**
     * Returns whether an incremental model is reset at the start of an instance stream.
     *
     * @return true if the model is reset
     */
    public boolean getResetIncrementalClassifier() {
        return this.m_resetIncrementalClassifier;
    }

    /**
     * Sets whether an incremental model is reset at the start of an instance stream.
     *
     * @param reset true to reset the model
     */
    @OptionMetadata(displayName="Reset incremental classifier", description="Reset classifier (if it is incremental) at the start of the incoming instance stream")
    public void setResetIncrementalClassifier(boolean reset) {
        this.m_resetIncrementalClassifier = reset;
    }

    /**
     * Returns whether an incremental model is updated on the incoming instance stream.
     *
     * @return true if the model is updated
     */
    public boolean getUpdateIncrementalClassifier() {
        return this.m_updateIncrementalClassifier;
    }

    /**
     * Sets whether an incremental model is updated on the incoming instance stream.
     *
     * @param update true to update the model, false to only predict
     */
    @OptionMetadata(displayName="Update incremental classifier", description=" Update an incremental classifier on incoming instance stream")
    public void setUpdateIncrementalClassifier(boolean update) {
        // Previously hard-coded to true, which made the option impossible to disable.
        this.m_updateIncrementalClassifier = update;
    }

    /**
     * Dispatches incoming data to the streaming, train/test, or test-only handler.
     * The first data of a run also initialises the model and header state.
     *
     * @param data the incoming data
     * @throws WekaException if processing fails
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        try {
            if (this.m_isReset) {
                this.m_isReset = false;
                this.initialiseFromFirstData(data);
            }
            if (this.m_streaming) {
                this.processStreaming(data);
            } else if (this.m_trainTestHelper != null) {
                this.m_trainTestHelper.process(data);
            } else {
                this.processOnlyTestSet(data);
            }
        } catch (WekaException ex) {
            // Already meaningful; don't bury the message in another wrapper.
            throw ex;
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Handles the first data of a run. It determines the incoming structure and defaults
     * the class attribute to the last attribute when none is set. For instance
     * connections, it prepares the streaming model. It then checks the structure
     * against the trained model's header.
     */
    private void initialiseFromFirstData(Data data) throws Exception {
        String connectionName = data.getConnectionName();
        boolean instanceConnection = connectionName.equals("instance");
        Instances incomingStructure = instanceConnection
            ? new Instances(((Instance) data.getPayloadElement("instance")).dataset(), 0)
            : (Instances) data.getPayloadElement(connectionName);

        // Instances.classAttribute() throws (rather than returning null) when no class
        // is set, so the class index has to be tested directly.
        if (incomingStructure.classIndex() < 0) {
            this.getStepManager().logWarning("No class index is set in the data - using last attribute as class");
            incomingStructure.setClassIndex(incomingStructure.numAttributes() - 1);
        }

        if (instanceConnection) {
            this.m_streaming = true;
            if (this.m_trainedClassifier == null) {
                this.m_trainedClassifier = AbstractClassifier.makeCopy(this.m_classifierTemplate);
                this.getStepManager().logBasic("Initialising incremental classifier");
                this.m_trainedClassifier.buildClassifier(incomingStructure);
                this.m_trainedClassifierHeader = incomingStructure;
            } else if (this.m_resetIncrementalClassifier && this.m_classifierIsIncremental) {
                this.m_trainedClassifier = AbstractClassifier.makeCopy(this.m_classifierTemplate);
                this.m_trainedClassifierHeader = incomingStructure;
                this.getStepManager().logBasic("Resetting incremental classifier");
                this.m_trainedClassifier.buildClassifier(this.m_trainedClassifierHeader);
            }
            this.getStepManager().logBasic(this.m_updateIncrementalClassifier && this.m_classifierIsIncremental ? "Training incrementally" : "Predicting incrementally");
        } else if (connectionName.equals("trainingSet")) {
            // Keep only the structure, not the whole (possibly large) first training set.
            this.m_trainedClassifierHeader = new Instances(incomingStructure, 0);
        }

        if (this.m_trainedClassifierHeader != null && !incomingStructure.equalHeaders(this.m_trainedClassifierHeader) && !(this.m_trainedClassifier instanceof InputMappedClassifier)) {
            throw new WekaException("Structure of incoming data does not match that of the trained classifier");
        }
    }

    /**
     * Builds a fresh copy of the classifier on one training fold/set. The training
     * split is stored in the helper for later pairing with its test set, and any
     * text/graph output is emitted.
     *
     * @param setNum the number of this fold/set
     * @param maxSetNum the total number of folds/sets
     * @param data the data carrying the training set
     * @param helper the pairing helper
     * @return the built classifier (unbuilt if a stop was requested)
     * @throws WekaException if building fails
     */
    @Override
    public weka.classifiers.Classifier processPrimary(Integer setNum, Integer maxSetNum, Data data, PairedDataHelper<weka.classifiers.Classifier> helper) throws WekaException {
        Instances trainingData = (Instances) data.getPrimaryPayload();
        try {
            weka.classifiers.Classifier classifier = AbstractClassifier.makeCopy(this.m_classifierTemplate);
            String classifierDesc = classifier.getClass().getCanonicalName();
            classifierDesc = classifierDesc.substring(classifierDesc.lastIndexOf('.') + 1);
            if (classifier instanceof OptionHandler) {
                classifierDesc = classifierDesc + " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
            }
            if (classifier instanceof EnvironmentHandler) {
                ((EnvironmentHandler) classifier).setEnvironment(
                    this.getStepManager().getExecutionEnvironment().getEnvironmentVariables());
            }
            helper.addIndexedValueToNamedStore("trainingSplits", setNum, trainingData);
            if (!this.isStopRequested()) {
                this.getStepManager().logBasic("Building " + classifierDesc + " on " + trainingData.relationName() + " for fold/set " + setNum + " out of " + maxSetNum);
                if (maxSetNum == 1) {
                    // With a single training set this model is also "the" trained model of the step.
                    this.m_trainedClassifier = classifier;
                }
                classifier.buildClassifier(trainingData);
                this.getStepManager().logDetailed("Finished building " + classifierDesc + " on " + trainingData.relationName() + " for fold/set " + setNum + " out of " + maxSetNum);
                this.outputTextData(classifier, setNum);
                this.outputGraphData(classifier, setNum);
            }
            return classifier;
        } catch (WekaException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Emits the model built for {@code setNum} on a {@code batchClassifier} connection,
     * together with its training split, test split and set numbering.
     *
     * @param setNum the number of this fold/set
     * @param maxSetNum the total number of folds/sets
     * @param data the data carrying the test set
     * @param helper the pairing helper holding the built model and training split
     * @throws WekaException if output fails
     */
    @Override
    public void processSecondary(Integer setNum, Integer maxSetNum, Data data, PairedDataHelper<weka.classifiers.Classifier> helper) throws WekaException {
        weka.classifiers.Classifier classifier = helper.getIndexedPrimaryResult(setNum);
        Instances testSplit = (Instances) data.getPrimaryPayload();
        Instances trainingSplit = (Instances) helper.getIndexedValueFromNamedStore("trainingSplits", setNum);
        this.getStepManager().logBasic("Dispatching model for set " + setNum + " out of " + maxSetNum + " to output");
        Data batchClassifier = new Data("batchClassifier", classifier);
        batchClassifier.setPayloadElement("aux_trainingSet", trainingSplit);
        batchClassifier.setPayloadElement("aux_testsSet", testSplit);
        batchClassifier.setPayloadElement("aux_set_num", setNum);
        batchClassifier.setPayloadElement("aux_max_set_num", maxSetNum);
        batchClassifier.setPayloadElement("aux_label", this.getName());
        this.getStepManager().outputData(batchClassifier);
    }

    /**
     * Emits a copy of the loaded model paired with an incoming test set. This is used
     * when there is no training set connection.
     *
     * @param data the data carrying the test set
     * @throws WekaException if no model is available or output fails
     */
    protected void processOnlyTestSet(Data data) throws WekaException {
        if (this.m_trainedClassifier == null) {
            throw new WekaException("No trained classifier available to apply to the incoming test set - either connect a training set or specify a classifier model to load");
        }
        try {
            weka.classifiers.Classifier tempToTest = AbstractClassifier.makeCopy(this.m_trainedClassifier);
            Data batchClassifier = new Data("batchClassifier");
            batchClassifier.setPayloadElement("batchClassifier", tempToTest);
            batchClassifier.setPayloadElement("aux_testsSet", data.getPayloadElement("testSet"));
            batchClassifier.setPayloadElement("aux_set_num", data.getPayloadElement("aux_set_num", 1));
            batchClassifier.setPayloadElement("aux_max_set_num", data.getPayloadElement("aux_max_set_num", 1));
            batchClassifier.setPayloadElement("aux_label", this.getName());
            this.getStepManager().outputData(batchClassifier);
            if (this.isStopRequested()) {
                this.getStepManager().interrupted();
            } else {
                this.getStepManager().finished();
            }
        } catch (WekaException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Handles one element of an instance stream.
     *
     * <p>At the end of the stream, it finalises the model (if it is an
     * {@link UpdateableBatchProcessor}), emits text/graph output, and signals
     * completion. Otherwise it emits the instance with the current model, then
     * updates the model if that is enabled and the class value is present.
     *
     * @param data the incoming stream element
     * @throws WekaException if finalising or updating the model fails
     */
    protected void processStreaming(Data data) throws WekaException {
        if (this.isStopRequested()) {
            return;
        }
        Instance inst = (Instance) data.getPayloadElement("instance");
        if (this.getStepManager().isStreamFinished(data)) {
            if (this.m_trainedClassifier instanceof UpdateableBatchProcessor) {
                try {
                    ((UpdateableBatchProcessor) this.m_trainedClassifier).batchFinished();
                } catch (Exception ex) {
                    throw new WekaException(ex);
                }
            }
            this.m_incrementalData.setPayloadElement("incrementalClassifier", this.m_trainedClassifier);
            this.m_incrementalData.setPayloadElement("aux_testInstance", null);
            this.outputTextData(this.m_trainedClassifier, -1);
            this.outputGraphData(this.m_trainedClassifier, 0);
            if (!this.isStopRequested()) {
                this.getStepManager().throughputFinished(this.m_incrementalData);
            }
            return;
        }
        // Downstream steps see the instance before the model learns from it
        // (test-then-train).
        this.m_incrementalData.setPayloadElement("aux_testInstance", inst);
        this.m_incrementalData.setPayloadElement("incrementalClassifier", this.m_trainedClassifier);
        this.getStepManager().outputData(this.m_incrementalData.getConnectionName(), this.m_incrementalData);
        this.getStepManager().throughputUpdateStart();
        if (this.m_classifierIsIncremental && this.m_updateIncrementalClassifier && !inst.classIsMissing()) {
            try {
                ((UpdateableClassifier) this.m_trainedClassifier).updateClassifier(inst);
            } catch (Exception ex) {
                throw new WekaException(ex);
            }
        }
        this.getStepManager().throughputUpdateEnd();
    }

    /**
     * Emits the textual description of a model on {@code text} connections, if there are any.
     *
     * @param classifier the model to describe
     * @param setNum the fold/set number, or -1 to omit it
     * @throws WekaException if output fails
     */
    protected void outputTextData(weka.classifiers.Classifier classifier, int setNum) throws WekaException {
        if (this.getStepManager().numOutgoingConnectionsOfType("text") == 0) {
            return;
        }
        String schemeName = classifier.getClass().getName();
        schemeName = schemeName.substring(schemeName.lastIndexOf('.') + 1);
        StringBuilder modelText = new StringBuilder("=== Classifier model ===\n\nScheme:   ").append(schemeName).append("\n");
        // The header may be unknown, e.g. for a loaded model file without one.
        if (this.m_trainedClassifierHeader != null) {
            modelText.append("Relation: ").append(this.m_trainedClassifierHeader.relationName()).append("\n");
        }
        modelText.append("\n").append(classifier.toString());

        Data textData = new Data("text");
        textData.setPayloadElement("text", modelText.toString());
        textData.setPayloadElement("aux_textTitle", "Model: " + schemeName);
        if (setNum != -1) {
            textData.setPayloadElement("aux_set_num", setNum);
        }
        this.getStepManager().outputData(textData);
    }

    /**
     * Emits the graph of a {@link Drawable} model on {@code graph} connections, if there are any.
     *
     * @param classifier the model to draw
     * @param setNum the fold/set number used in the graph title
     * @throws WekaException if the graph cannot be produced or output fails
     */
    protected void outputGraphData(weka.classifiers.Classifier classifier, int setNum) throws WekaException {
        if (!(classifier instanceof Drawable) || this.getStepManager().numOutgoingConnectionsOfType("graph") == 0) {
            return;
        }
        try {
            Drawable drawable = (Drawable) classifier;
            String graphString = drawable.graph();
            int graphType = drawable.graphType();
            String schemeName = classifier.getClass().getCanonicalName();
            schemeName = schemeName.substring(schemeName.lastIndexOf('.') + 1);
            String graphTitle = this.m_trainedClassifierHeader != null
                ? "Set " + setNum + " (" + this.m_trainedClassifierHeader.relationName() + ") " + schemeName
                : "Set " + setNum + " " + schemeName;
            Data graphData = new Data("graph");
            graphData.setPayloadElement("graph", graphString);
            graphData.setPayloadElement("graph_title", graphTitle);
            graphData.setPayloadElement("graph_type", graphType);
            this.getStepManager().outputData(graphData);
        } catch (WekaException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Returns the incoming connection types still accepted. Instance streams are
     * mutually exclusive with training/test set connections, and each type may be
     * connected only once.
     *
     * @return the accepted incoming connection types
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        List<String> result = new ArrayList<String>();
        int numTraining = this.getStepManager().numIncomingConnectionsOfType("trainingSet");
        int numTesting = this.getStepManager().numIncomingConnectionsOfType("testSet");
        int numInstance = this.getStepManager().numIncomingConnectionsOfType("instance");
        if (numTraining == 0 && numTesting == 0) {
            result.add("instance");
        }
        if (numInstance == 0 && numTraining == 0) {
            result.add("trainingSet");
        }
        if (numInstance == 0 && numTesting == 0) {
            result.add("testSet");
        }
        return result;
    }

    /**
     * Returns the outgoing connection types that can be produced given the current
     * incoming connections.
     *
     * @return the producible outgoing connection types
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        List<String> result = new ArrayList<String>();
        if (this.getStepManager().numIncomingConnections() > 0) {
            int numTraining = this.getStepManager().numIncomingConnectionsOfType("trainingSet");
            int numTesting = this.getStepManager().numIncomingConnectionsOfType("testSet");
            int numInstance = this.getStepManager().numIncomingConnectionsOfType("instance");
            if (numInstance > 0) {
                result.add("incrementalClassifier");
            } else if (numTraining > 0 || numTesting > 0) {
                result.add("batchClassifier");
            }
            result.add("text");
            if (this.getClassifier() instanceof Drawable && numTraining > 0) {
                result.add("graph");
            }
        }
        result.add("info");
        return result;
    }

    /**
     * Loads a serialized model, optionally followed by its training header, from a file.
     * The model must be of the same class as the configured classifier.
     *
     * <p>NOTE: Java deserialization can execute code contained in the file; only load
     * model files from trusted sources.
     *
     * @param filePath the (environment-resolved) path of the model file
     * @throws Exception if the file cannot be read or holds a model of another class
     */
    protected void loadModel(String filePath) throws Exception {
        // Separate resources so the file is closed even if the ObjectInputStream
        // constructor fails (e.g. the file is not a serialized stream).
        try (FileInputStream fileIn = new FileInputStream(new File(filePath));
             ObjectInputStream is = new ObjectInputStream(new BufferedInputStream(fileIn))) {
            weka.classifiers.Classifier loaded = (weka.classifiers.Classifier) is.readObject();
            String loadedName = loaded.getClass().getCanonicalName();
            String expectedName = this.getClassifier().getClass().getCanonicalName();
            if (!loadedName.equals(expectedName)) {
                throw new Exception("The loaded model '" + loadedName + "' is not a '" + expectedName + "'");
            }
            // Only accept the model once it has been validated.
            this.m_trainedClassifier = loaded;
            try {
                this.m_trainedClassifierHeader = (Instances) is.readObject();
            } catch (Exception ex) {
                // Best effort: older model files may contain only the classifier.
                this.getStepManager().logWarning("Model file '" + filePath + "' does not seem to contain an Instances header");
            }
        }
    }
}

