/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import adams.core.Range;
import adams.core.Utils;
import adams.core.Variables;
import adams.core.base.BaseKeyValuePair;
import adams.core.base.BaseObject;
import adams.core.base.BaseRegExp;
import adams.core.base.BaseString;
import adams.core.io.FileUtils;
import adams.core.io.PlaceholderDirectory;
import adams.core.io.PlaceholderFile;
import adams.core.io.TempUtils;
import adams.core.io.lister.LocalDirectoryLister;
import adams.core.logging.LoggingLevel;
import adams.core.logging.LoggingObject;
import adams.flow.core.RunnableWithLogging;
import adams.ml.cntk.CNTK;
import adams.ml.cntk.CNTKPredictionWrapper;
import adams.ml.cntk.DeviceType;
import adams.ml.cntk.predictionpostprocessor.Normalize;
import com.github.fracpete.processoutput4j.core.StreamingProcessOutputType;
import com.github.fracpete.processoutput4j.core.StreamingProcessOwner;
import com.github.fracpete.processoutput4j.output.StreamingProcessOutput;
import gnu.trove.list.array.TFloatArrayList;
import java.io.File;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import java.util.regex.Pattern;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaOptionUtils;
import weka.core.converters.CNTKSaver;
import weka.core.converters.ConverterUtils;
import weka.core.converters.Saver;

/**
 * Weka classifier that trains a CNTK model by executing a user-supplied BrainScript
 * with the CNTK binary, then loads the resulting model file and uses it for predictions.
 * <p>
 * Training proceeds as follows:
 * <ol>
 *   <li>the training data is written (dense format) to the configured train file,</li>
 *   <li>if variables are defined, they are expanded into a temporary copy of the script,</li>
 *   <li>the CNTK binary is run with {@code configFile=<script>}; whenever the process
 *       produces a line of output, intermediate model/checkpoint files in the model
 *       directory are deleted,</li>
 *   <li>the model file configured via {@link #setModel(PlaceholderFile)} is loaded.</li>
 * </ol>
 */
public class CNTKBrainscriptModel
extends AbstractClassifier
implements StreamingProcessOwner {
    private static final long serialVersionUID = 7732345053235983381L;

    /* Command-line flag names used by the option handling. */
    protected static String SCRIPT = "script";
    protected static String TRAINFILE = "train-file";
    protected static String MODELDIR = "model-directory";
    protected static String MODELEXT = "model-extensions";
    protected static String MODEL = "model";
    protected static String DEVICETYPE = "device-type";
    protected static String GPUDEVICEID = "gpu-device-id";
    protected static String INPUTS = "inputs";
    protected static String INPUTNAMES = "input-names";
    protected static String CLASSNAME = "class-name";
    protected static String OUTPUTNAME = "output-name";
    protected static String VARIABLES = "variables";

    /** The BrainScript to execute. */
    protected PlaceholderFile m_Script = this.getDefaultScript();
    /** Temporary copy of the script with variables expanded; null if none was created. */
    protected PlaceholderFile m_TmpScript;
    /** The file the training data gets saved to (read by the BrainScript). */
    protected PlaceholderFile m_TrainFile = this.getDefaultTrainFile();
    /** Directory in which CNTK writes models, temp models and checkpoints. */
    protected PlaceholderDirectory m_ModelDirectory = this.getDefaultModelDirectory();
    /** The model file to load after training. */
    protected PlaceholderFile m_Model = this.getDefaultModel();
    /** The extension of the model files (including the dot). */
    protected String m_ModelExtension = this.getDefaultModelExtension();
    /** The device type used for predictions. */
    protected DeviceType m_DeviceType = this.getDefaultDeviceType();
    /** The GPU device ID, only relevant for the GPU device type. */
    protected long m_GPUDeviceID = this.getDefaultGPUDeviceID();
    /** Variables (key/value) to expand in the script. */
    protected BaseKeyValuePair[] m_Variables = this.getDefaultVariables();
    /** Loads the model and performs the predictions. */
    protected CNTKPredictionWrapper m_Wrapper = new CNTKPredictionWrapper();
    /** Post-processes predictions that consist of more than one score. */
    protected Normalize m_Normalize = new Normalize();
    /** Lists intermediate model/checkpoint files to delete while training runs. */
    protected LocalDirectoryLister m_Lister;
    /** Consumes the output of the running CNTK process. */
    protected StreamingProcessOutput m_ProcessOutput;
    /** Background job that runs the CNTK process. */
    protected RunnableWithLogging m_Monitor;
    /** Failure recorded by the background job, if any. */
    protected IllegalStateException m_ExecutionFailure;

    /**
     * Returns a description of the classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Builds a CNTK model using the supplied Brainscript and then applies the model to the data for making predictions.";
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return the option descriptions, including the ones of the superclass
     */
    @Override
    public Enumeration listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, this.scriptTipText(), "" + this.getDefaultScript(), SCRIPT);
        WekaOptionUtils.addOption(result, this.trainFileTipText(), "" + this.getDefaultTrainFile(), TRAINFILE);
        WekaOptionUtils.addOption(result, this.modelDirectoryTipText(), "" + this.getDefaultModelDirectory(), MODELDIR);
        WekaOptionUtils.addOption(result, this.modelExtensionTipText(), "" + this.getDefaultModelExtension(), MODELEXT);
        WekaOptionUtils.addOption(result, this.modelTipText(), "" + this.getDefaultModel(), MODEL);
        WekaOptionUtils.addOption(result, this.deviceTypeTipText(), "" + this.getDefaultDeviceType(), DEVICETYPE);
        WekaOptionUtils.addOption(result, this.GPUDeviceIDTipText(), "" + this.getDefaultGPUDeviceID(), GPUDEVICEID);
        WekaOptionUtils.addOption(result, this.inputsTipText(), Utils.arrayToString((Object)this.getDefaultInputs()), INPUTS);
        WekaOptionUtils.addOption(result, this.inputNamesTipText(), Utils.arrayToString((Object)this.getDefaultInputNames()), INPUTNAMES);
        WekaOptionUtils.addOption(result, this.classNameTipText(), this.getDefaultClassName(), CLASSNAME);
        WekaOptionUtils.addOption(result, this.outputNameTipText(), this.getDefaultOutputName(), OUTPUTNAME);
        WekaOptionUtils.addOption(result, this.variablesTipText(), Utils.arrayToString((Object)this.getDefaultVariables()), VARIABLES);
        WekaOptionUtils.add(result, (Enumeration)super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    /**
     * Parses the options; missing options fall back to their defaults.
     *
     * @param options the options to parse
     * @throws Exception if parsing fails
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setScript(WekaOptionUtils.parse(options, SCRIPT, this.getDefaultScript()));
        this.setTrainFile(WekaOptionUtils.parse(options, TRAINFILE, this.getDefaultTrainFile()));
        this.setModelDirectory(WekaOptionUtils.parse(options, MODELDIR, this.getDefaultModelDirectory()));
        this.setModelExtension(WekaOptionUtils.parse(options, MODELEXT, this.getDefaultModelExtension()));
        this.setModel(WekaOptionUtils.parse(options, MODEL, this.getDefaultModel()));
        this.setDeviceType((DeviceType)WekaOptionUtils.parse(options, DEVICETYPE, (Enum)this.getDefaultDeviceType()));
        this.setGPUDeviceID(WekaOptionUtils.parse(options, GPUDEVICEID, this.getDefaultGPUDeviceID()));
        this.setInputs(WekaOptionUtils.parse(options, INPUTS, this.getDefaultInputs()));
        this.setInputNames((BaseString[])WekaOptionUtils.parse(options, INPUTNAMES, (BaseObject[])this.getDefaultInputNames()));
        this.setClassName(WekaOptionUtils.parse(options, CLASSNAME, this.getDefaultClassName()));
        this.setOutputName(WekaOptionUtils.parse(options, OUTPUTNAME, this.getDefaultOutputName()));
        this.setVariables((BaseKeyValuePair[])WekaOptionUtils.parse(options, VARIABLES, (BaseObject[])this.getDefaultVariables()));
        super.setOptions(options);
    }

    /**
     * Returns the current setup as options; the GPU device ID is only output
     * when the GPU device type is selected.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        ArrayList result = new ArrayList();
        WekaOptionUtils.add(result, SCRIPT, (File)this.getScript());
        WekaOptionUtils.add(result, TRAINFILE, (File)this.getTrainFile());
        WekaOptionUtils.add(result, MODELDIR, (File)this.getModelDirectory());
        WekaOptionUtils.add(result, MODELEXT, this.getModelExtension());
        WekaOptionUtils.add(result, MODEL, (File)this.getModel());
        WekaOptionUtils.add(result, DEVICETYPE, (Enum)this.getDeviceType());
        if (this.getDeviceType() == DeviceType.GPU) {
            WekaOptionUtils.add(result, GPUDEVICEID, this.getGPUDeviceID());
        }
        WekaOptionUtils.add(result, INPUTS, this.getInputs());
        WekaOptionUtils.add(result, INPUTNAMES, (BaseObject[])this.getInputNames());
        WekaOptionUtils.add(result, CLASSNAME, this.getClassName());
        WekaOptionUtils.add(result, OUTPUTNAME, this.getOutputName());
        WekaOptionUtils.add(result, VARIABLES, (BaseObject[])this.getVariables());
        WekaOptionUtils.add(result, super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    /**
     * Sets the debug flag; also switches the wrapper's logging level between
     * INFO (debug) and WARNING.
     *
     * @param debug true to output debugging information
     */
    @Override
    public void setDebug(boolean debug) {
        super.setDebug(debug);
        this.m_Wrapper.setLoggingLevel(debug ? LoggingLevel.INFO : LoggingLevel.WARNING);
    }

    /** @return the default BrainScript (an empty placeholder file) */
    protected PlaceholderFile getDefaultScript() {
        return new PlaceholderFile();
    }

    /** @param value the BrainScript to run */
    public void setScript(PlaceholderFile value) {
        this.m_Script = value;
    }

    /** @return the BrainScript to run */
    public PlaceholderFile getScript() {
        return this.m_Script;
    }

    /** @return the tip text for this property */
    public String scriptTipText() {
        return "The BrainScript to run.";
    }

    /** @return the default training file (an empty placeholder file) */
    protected PlaceholderFile getDefaultTrainFile() {
        return new PlaceholderFile();
    }

    /** @param value the file the training data gets saved to */
    public void setTrainFile(PlaceholderFile value) {
        this.m_TrainFile = value;
    }

    /** @return the file the training data gets saved to */
    public PlaceholderFile getTrainFile() {
        return this.m_TrainFile;
    }

    /** @return the tip text for this property */
    public String trainFileTipText() {
        return "The training file used by the Brainscript; the training instances will get saved to that file.";
    }

    /** @return the default model directory (an empty placeholder directory) */
    protected PlaceholderDirectory getDefaultModelDirectory() {
        return new PlaceholderDirectory();
    }

    /** @param value the directory with models, temp models and checkpoints */
    public void setModelDirectory(PlaceholderDirectory value) {
        this.m_ModelDirectory = value;
    }

    /** @return the directory with models, temp models and checkpoints */
    public PlaceholderDirectory getModelDirectory() {
        return this.m_ModelDirectory;
    }

    /** @return the tip text for this property */
    public String modelDirectoryTipText() {
        return "The directory containing the models, temp models and checkpoint files.";
    }

    /** @return the default model file (an empty placeholder file) */
    protected PlaceholderFile getDefaultModel() {
        return new PlaceholderFile();
    }

    /** @param value the model file to load after training */
    public void setModel(PlaceholderFile value) {
        this.m_Model = value;
    }

    /** @return the model file to load after training */
    public PlaceholderFile getModel() {
        return this.m_Model;
    }

    /** @return the tip text for this property */
    public String modelTipText() {
        return "The prebuilt CNTK model to use.";
    }

    /** @return the default model extension, {@code .cmf} */
    protected String getDefaultModelExtension() {
        return ".cmf";
    }

    /** @param value the model file extension (incl dot) */
    public void setModelExtension(String value) {
        this.m_ModelExtension = value;
    }

    /** @return the model file extension (incl dot) */
    public String getModelExtension() {
        return this.m_ModelExtension;
    }

    /** @return the tip text for this property */
    public String modelExtensionTipText() {
        return "The file extension used by the models (incl dot).";
    }

    /** @return the default device type, {@link DeviceType#DEFAULT} */
    protected DeviceType getDefaultDeviceType() {
        return DeviceType.DEFAULT;
    }

    /** @param value the device type to use */
    public void setDeviceType(DeviceType value) {
        this.m_DeviceType = value;
    }

    /** @return the device type to use */
    public DeviceType getDeviceType() {
        return this.m_DeviceType;
    }

    /** @return the tip text for this property */
    public String deviceTypeTipText() {
        return "The device type to use.";
    }

    /** @return the default GPU device ID, 0 */
    protected long getDefaultGPUDeviceID() {
        return 0L;
    }

    /** @param value the GPU device ID */
    public void setGPUDeviceID(long value) {
        this.m_GPUDeviceID = value;
    }

    /** @return the GPU device ID */
    public long getGPUDeviceID() {
        return this.m_GPUDeviceID;
    }

    /** @return the tip text for this property */
    public String GPUDeviceIDTipText() {
        return "The GPU device ID.";
    }

    /** @return the default inputs (none) */
    protected Range[] getDefaultInputs() {
        return new Range[0];
    }

    /** @param value the input ranges; stored in the prediction wrapper */
    public void setInputs(Range[] value) {
        this.m_Wrapper.setInputs(value);
    }

    /** @return the input ranges, as stored in the prediction wrapper */
    public Range[] getInputs() {
        return this.m_Wrapper.getInputs();
    }

    /** @return the tip text for this property (from the prediction wrapper) */
    public String inputsTipText() {
        return this.m_Wrapper.inputsTipText();
    }

    /** @return the default input names (none) */
    protected BaseString[] getDefaultInputNames() {
        return new BaseString[0];
    }

    /** @param value the input names; stored in the prediction wrapper */
    public void setInputNames(BaseString[] value) {
        this.m_Wrapper.setInputNames(value);
    }

    /** @return the input names, as stored in the prediction wrapper */
    public BaseString[] getInputNames() {
        return this.m_Wrapper.getInputNames();
    }

    /** @return the tip text for this property (from the prediction wrapper) */
    public String inputNamesTipText() {
        return this.m_Wrapper.inputNamesTipText();
    }

    /** @return the default class name (empty) */
    protected String getDefaultClassName() {
        return "";
    }

    /** @param value the class name; stored in the prediction wrapper */
    public void setClassName(String value) {
        this.m_Wrapper.setClassName(value);
    }

    /** @return the class name, as stored in the prediction wrapper */
    public String getClassName() {
        return this.m_Wrapper.getClassName();
    }

    /** @return the tip text for this property (from the prediction wrapper) */
    public String classNameTipText() {
        return this.m_Wrapper.classNameTipText();
    }

    /** @return the default output name (empty) */
    protected String getDefaultOutputName() {
        return "";
    }

    /** @param value the output name; stored in the prediction wrapper */
    public void setOutputName(String value) {
        this.m_Wrapper.setOutputName(value);
    }

    /** @return the output name, as stored in the prediction wrapper */
    public String getOutputName() {
        return this.m_Wrapper.getOutputName();
    }

    /** @return the tip text for this property (from the prediction wrapper) */
    public String outputNameTipText() {
        return this.m_Wrapper.outputNameTipText();
    }

    /** @return the default variables (none) */
    protected BaseKeyValuePair[] getDefaultVariables() {
        return new BaseKeyValuePair[0];
    }

    /** @param value the variables to expand in the script */
    public void setVariables(BaseKeyValuePair[] value) {
        this.m_Variables = value;
    }

    /** @return the variables to expand in the script */
    public BaseKeyValuePair[] getVariables() {
        return this.m_Variables;
    }

    /** @return the tip text for this property */
    public String variablesTipText() {
        return "The key-value pairs representing variables and their associated values to be replaced in the script.";
    }

    /**
     * Returns the capabilities: numeric attributes only; nominal or numeric class,
     * missing class values allowed.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Both stdout and stderr of the CNTK process are of interest.
     *
     * @return {@link StreamingProcessOutputType#BOTH}
     */
    @Override
    public StreamingProcessOutputType getOutputType() {
        return StreamingProcessOutputType.BOTH;
    }

    /**
     * Echoes a line of process output to stdout/stderr. It then deletes any
     * intermediate model/checkpoint files currently present in the model directory.
     *
     * @param line the line of output
     * @param stdout true if the line came from stdout, false for stderr
     */
    @Override
    public void processOutput(String line, boolean stdout) {
        if (stdout) {
            System.out.println(line);
        } else {
            System.err.println(line);
        }
        // read the field only once: buildModel() resets it after training, possibly
        // while output is still being processed
        LocalDirectoryLister lister = this.m_Lister;
        if (lister == null) {
            return;
        }
        for (String file : lister.list()) {
            FileUtils.delete(file);
        }
    }

    /**
     * Expands the configured variables in the BrainScript, if there are any.
     * <p>
     * When variables are defined and the script can be loaded, the expanded content
     * is written to a temporary file, which is also stored in {@link #m_TmpScript}.
     * In all other cases, the original script is used as is.
     *
     * @return the absolute path of the script to execute
     * @throws IllegalStateException if the expanded script cannot be written
     */
    protected String preprocessScript() {
        String result = this.m_Script.getAbsolutePath();
        this.m_TmpScript = null;
        if (this.m_Variables.length == 0) {
            return result;
        }
        Variables vars = new Variables();
        for (BaseKeyValuePair var : this.m_Variables) {
            vars.set(var.getPairKey(), var.getPairValue());
        }
        List lines = FileUtils.loadFromFile((File)this.m_Script);
        // NOTE(review): an unreadable script is silently run without variable expansion
        if (lines == null) {
            return result;
        }
        result = TempUtils.createTempFile("adams-cntk-weka-bs-", ".bs").getAbsolutePath();
        String content = vars.expand(Utils.flatten(lines, "\n"));
        String msg = FileUtils.writeToFileMsg(result, (Object)content, false, null);
        if (msg != null) {
            throw new IllegalStateException("Failed to write expanded script!\n" + msg);
        }
        this.m_TmpScript = new PlaceholderFile(result);
        return result;
    }

    /**
     * Saves the training data and runs the BrainScript through the CNTK binary.
     * It blocks until the process has finished.
     *
     * @param data the training data
     * @throws IllegalStateException if the process could not be executed or
     *         returned a non-zero exit code
     * @throws InterruptedException if the calling thread gets interrupted while
     *         waiting; the CNTK process gets stopped in that case
     * @throws Exception if the capabilities test or saving the data fails
     */
    protected void buildModel(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.saveTrainingData(data);
        String script = this.preprocessScript();
        String binary = CNTK.getBinary().getAbsolutePath();
        String configArg = "configFile=" + script;
        // human-readable command line, only used for logging and error messages
        final String fCmd = binary + " " + configArg;
        // arguments are passed separately so that paths containing spaces stay intact
        // (Runtime.exec(String) would split them on whitespace)
        final String[] fCmdArgs = new String[]{binary, configArg};
        if (this.getDebug()) {
            System.out.println("Command: " + fCmd);
        }
        this.m_Lister = this.createCheckpointLister();
        this.m_ExecutionFailure = null;
        final StreamingProcessOutput output = new StreamingProcessOutput(this);
        this.m_ProcessOutput = output;
        RunnableWithLogging monitor = new RunnableWithLogging(){
            private static final long serialVersionUID = -4475355379511760429L;
            protected Process m_Process;

            @Override
            protected void doRun() {
                try {
                    this.m_Process = new ProcessBuilder(fCmdArgs).start();
                    output.monitor(fCmd, null, this.m_Process);
                    if (output.getExitCode() != 0) {
                        CNTKBrainscriptModel.this.m_ExecutionFailure = new IllegalStateException("Exit code " + output.getExitCode() + " when executing: " + fCmd);
                    }
                }
                catch (Exception e) {
                    CNTKBrainscriptModel.this.m_ExecutionFailure = new IllegalStateException("Failed to execute: " + fCmd, e);
                }
                CNTKBrainscriptModel.this.m_ProcessOutput = null;
                // keep the expanded script on failure
                if (CNTKBrainscriptModel.this.m_TmpScript != null && CNTKBrainscriptModel.this.m_ExecutionFailure == null) {
                    CNTKBrainscriptModel.this.m_TmpScript.delete();
                }
            }

            @Override
            public void stopExecution() {
                if (this.m_Process != null) {
                    this.m_Process.destroy();
                }
                super.stopExecution();
            }
        };
        this.m_Monitor = monitor;
        Thread worker = new Thread(monitor);
        worker.start();
        try {
            // join() waits for completion without the start-up race of polling
            // isRunning(). It also makes the worker's writes (e.g. m_ExecutionFailure)
            // visible to this thread.
            worker.join();
        }
        catch (InterruptedException e) {
            monitor.stopExecution();
            throw e;
        }
        finally {
            this.m_Monitor = null;
            this.m_Lister = null;
        }
        if (this.m_ExecutionFailure != null) {
            throw this.m_ExecutionFailure;
        }
    }

    /**
     * Writes the training data, in dense format, to the train file.
     *
     * @param data the data to save
     * @throws Exception if saving fails
     */
    private void saveTrainingData(Instances data) throws Exception {
        if (this.getDebug()) {
            System.out.println("Saving data to: " + this.m_TrainFile);
        }
        CNTKSaver saver = new CNTKSaver();
        saver.setSuppressMissing(true);
        saver.setUseSparseFormat(false);
        saver.setInputs(this.getInputs());
        saver.setInputNames(this.getInputNames());
        saver.setDestination(this.m_TrainFile.getAbsoluteFile());
        ConverterUtils.DataSink.write((Saver)saver, data);
        if (this.getDebug()) {
            System.out.println("Saved data to: " + this.m_TrainFile);
        }
    }

    /**
     * Creates a lister for intermediate model files ({@code <ext>.<number>}) and
     * checkpoint files ({@code <ext>.ckp}) in the model directory (non-recursive).
     *
     * @return the configured lister
     */
    private LocalDirectoryLister createCheckpointLister() {
        LocalDirectoryLister lister = new LocalDirectoryLister();
        lister.setWatchDir(this.m_ModelDirectory.getAbsolutePath());
        // quote the extension so that any regex metacharacters in it match literally
        lister.setRegExp(new BaseRegExp(".*" + Pattern.quote(this.m_ModelExtension) + "\\.([0-9]+|ckp)$"));
        lister.setRecursive(false);
        lister.setListDirs(false);
        lister.setListFiles(true);
        return lister;
    }

    /**
     * Initializes the device, loads the configured model and initializes it.
     *
     * @param data the data providing the number of classes
     * @throws Exception if initialization or loading fails
     */
    protected void initModel(Instances data) throws Exception {
        this.m_Wrapper.initDevice(this.m_DeviceType, this.m_GPUDeviceID);
        this.m_Wrapper.loadModel((File)this.m_Model);
        this.m_Wrapper.initModel(data.numClasses());
    }

    /**
     * Trains the model via the BrainScript (instances with missing class removed)
     * and loads the resulting model.
     *
     * @param data the training data (not modified)
     * @throws Exception if the data is not supported, training or loading fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        this.buildModel(data);
        this.initModel(data);
    }

    /**
     * Makes a prediction for the instance. It loads the model first if the wrapper
     * is not initialized yet. Predictions with more than one score are normalized;
     * a single score is returned as is.
     *
     * @param instance the instance to predict
     * @return the scores (empty if the model returned none)
     * @throws Exception if initialization or prediction fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!this.m_Wrapper.isInitialized()) {
            this.initModel(instance.dataset());
        }
        // NOTE(review): all attribute values (incl. the class) are passed on; presumably
        // the wrapper's input ranges select the relevant columns - confirm
        float[] values = new float[instance.numAttributes()];
        for (int i = 0; i < values.length; ++i) {
            values[i] = (float)instance.value(i);
        }
        float[] scores = this.m_Wrapper.predict(values);
        double[] result = new double[scores.length];
        if (scores.length > 1) {
            scores = this.m_Normalize.postProcessPrediction(scores);
            for (int i = 0; i < scores.length; ++i) {
                result[i] = scores[i];
            }
        } else if (scores.length == 1) {
            result[0] = scores[0];
        }
        return result;
    }

    /**
     * Returns a textual representation of the loaded model.
     *
     * @return the model description or a notice that no model is loaded
     */
    @Override
    public String toString() {
        if (this.m_Wrapper.getModel() == null) {
            return "No model loaded yet!";
        }
        return this.m_Wrapper.getModel().toString();
    }
}

