package weka.classifiers.functions;

import adams.core.Range;
import adams.core.Utils;
import adams.core.Variables;
import adams.core.base.BaseKeyValuePair;
import adams.core.base.BaseRegExp;
import adams.core.base.BaseString;
import adams.core.io.FileUtils;
import adams.core.io.PlaceholderDirectory;
import adams.core.io.PlaceholderFile;
import adams.core.io.TempUtils;
import adams.core.io.lister.LocalDirectoryLister;
import adams.core.logging.LoggingLevel;
import adams.flow.core.RunnableWithLogging;
import adams.ml.cntk.CNTK;
import adams.ml.cntk.CNTKPredictionWrapper;
import adams.ml.cntk.DeviceType;
import adams.ml.cntk.predictionpostprocessor.Normalize;
import com.github.fracpete.processoutput4j.core.StreamingProcessOutputType;
import com.github.fracpete.processoutput4j.core.StreamingProcessOwner;
import com.github.fracpete.processoutput4j.output.StreamingProcessOutput;
import gnu.trove.list.array.TFloatArrayList;
import java.io.File;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaOptionUtils;
import weka.core.converters.CNTKSaver;
import weka.core.converters.ConverterUtils;

/* loaded from: input_file:weka/classifiers/functions/CNTKBrainscriptModel.class */
/**
 * Classifier that trains a CNTK model by running a BrainScript through the
 * external CNTK binary and afterwards loads a model file for making
 * predictions.
 *
 * <p>Training ({@link #buildClassifier(Instances)}) performs these steps:
 * <ol>
 *   <li>the training data is written to the configured train file using
 *       {@link CNTKSaver},</li>
 *   <li>the script is optionally copied to a temp file with the configured
 *       variables expanded,</li>
 *   <li>the CNTK binary is executed with {@code configFile=<script>}; while
 *       it produces output, files in the model directory that look like
 *       temporary models/checkpoints get deleted,</li>
 *   <li>the configured model file is loaded into the prediction wrapper.</li>
 * </ol>
 */
public class CNTKBrainscriptModel extends AbstractClassifier implements StreamingProcessOwner {
    private static final long serialVersionUID = 7732345053235983381L;

    // Command-line option names (kept non-final to leave the protected API untouched).
    protected static String SCRIPT = "script";
    protected static String TRAINFILE = "train-file";
    protected static String MODELDIR = "model-directory";
    protected static String MODELEXT = "model-extensions";
    protected static String MODEL = "model";
    protected static String DEVICETYPE = "device-type";
    protected static String GPUDEVICEID = "gpu-device-id";
    protected static String INPUTS = "inputs";
    protected static String INPUTNAMES = "input-names";
    protected static String CLASSNAME = "class-name";
    protected static String OUTPUTNAME = "output-name";
    protected static String VARIABLES = "variables";

    /** The temporary script with expanded variables; null if the original script is used. */
    protected PlaceholderFile m_TmpScript;

    /** Lists the files in the model directory that get deleted while CNTK is running. */
    protected LocalDirectoryLister m_Lister;

    /** Forwards the output of the CNTK process to {@link #processOutput(String, boolean)}. */
    protected StreamingProcessOutput m_ProcessOutput;

    /** The runnable that launches and monitors the CNTK process; only set during a build. */
    protected RunnableWithLogging m_Monitor;

    /** The failure recorded by the monitor; null if the last execution succeeded. */
    protected IllegalStateException m_ExecutionFailure;

    /** The BrainScript to run. */
    protected PlaceholderFile m_Script = getDefaultScript();

    /** The file the training data gets saved to. */
    protected PlaceholderFile m_TrainFile = getDefaultTrainFile();

    /** The directory with the models, temp models and checkpoint files. */
    protected PlaceholderDirectory m_ModelDirectory = getDefaultModelDirectory();

    /** The model file that gets loaded for making predictions. */
    protected PlaceholderFile m_Model = getDefaultModel();

    /** The file extension of the models (incl dot). */
    protected String m_ModelExtension = getDefaultModelExtension();

    /** The device type to use. */
    protected DeviceType m_DeviceType = getDefaultDeviceType();

    /** The GPU device ID. */
    protected long m_GPUDeviceID = getDefaultGPUDeviceID();

    /** The variables (key-value pairs) to expand in the script. */
    protected BaseKeyValuePair[] m_Variables = getDefaultVariables();

    /** The wrapper that loads the model and makes the predictions. */
    protected CNTKPredictionWrapper m_Wrapper = new CNTKPredictionWrapper();

    /** Post-processor applied to predictions with more than one value. */
    protected Normalize m_Normalize = new Normalize();

    /**
     * Returns a short description of the classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Builds a CNTK model using the supplied Brainscript and then applies the model to the data for making predictions.";
    }

    /**
     * Returns an enumeration of all the available options, including the ones
     * of the super class.
     *
     * @return the options
     */
    public Enumeration listOptions() {
        Vector vector = new Vector();
        WekaOptionUtils.addOption(vector, scriptTipText(), "" + getDefaultScript(), SCRIPT);
        WekaOptionUtils.addOption(vector, trainFileTipText(), "" + getDefaultTrainFile(), TRAINFILE);
        WekaOptionUtils.addOption(vector, modelDirectoryTipText(), "" + getDefaultModelDirectory(), MODELDIR);
        WekaOptionUtils.addOption(vector, modelExtensionTipText(), "" + getDefaultModelExtension(), MODELEXT);
        WekaOptionUtils.addOption(vector, modelTipText(), "" + getDefaultModel(), MODEL);
        WekaOptionUtils.addOption(vector, deviceTypeTipText(), "" + getDefaultDeviceType(), DEVICETYPE);
        WekaOptionUtils.addOption(vector, GPUDeviceIDTipText(), "" + getDefaultGPUDeviceID(), GPUDEVICEID);
        WekaOptionUtils.addOption(vector, inputsTipText(), Utils.arrayToString(getDefaultInputs()), INPUTS);
        WekaOptionUtils.addOption(vector, inputNamesTipText(), Utils.arrayToString(getDefaultInputNames()), INPUTNAMES);
        WekaOptionUtils.addOption(vector, classNameTipText(), getDefaultClassName(), CLASSNAME);
        WekaOptionUtils.addOption(vector, outputNameTipText(), getDefaultOutputName(), OUTPUTNAME);
        WekaOptionUtils.addOption(vector, variablesTipText(), Utils.arrayToString(getDefaultVariables()), VARIABLES);
        WekaOptionUtils.add(vector, super.listOptions());
        return WekaOptionUtils.toEnumeration(vector);
    }

    /**
     * Parses the options; options that are not supplied fall back to their
     * defaults. The remaining options are handed to the super class.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails
     */
    public void setOptions(String[] strArr) throws Exception {
        setScript(WekaOptionUtils.parse(strArr, SCRIPT, getDefaultScript()));
        setTrainFile(WekaOptionUtils.parse(strArr, TRAINFILE, getDefaultTrainFile()));
        setModelDirectory(WekaOptionUtils.parse(strArr, MODELDIR, getDefaultModelDirectory()));
        setModelExtension(WekaOptionUtils.parse(strArr, MODELEXT, getDefaultModelExtension()));
        setModel(WekaOptionUtils.parse(strArr, MODEL, getDefaultModel()));
        setDeviceType((DeviceType) WekaOptionUtils.parse(strArr, DEVICETYPE, getDefaultDeviceType()));
        setGPUDeviceID(WekaOptionUtils.parse(strArr, GPUDEVICEID, getDefaultGPUDeviceID()));
        setInputs(WekaOptionUtils.parse(strArr, INPUTS, getDefaultInputs()));
        setInputNames((BaseString[]) WekaOptionUtils.parse(strArr, INPUTNAMES, getDefaultInputNames()));
        setClassName(WekaOptionUtils.parse(strArr, CLASSNAME, getDefaultClassName()));
        setOutputName(WekaOptionUtils.parse(strArr, OUTPUTNAME, getDefaultOutputName()));
        setVariables((BaseKeyValuePair[]) WekaOptionUtils.parse(strArr, VARIABLES, getDefaultVariables()));
        super.setOptions(strArr);
    }

    /**
     * Returns the current options, including the ones of the super class.
     * The GPU device ID is only included when the device type is GPU.
     *
     * @return the options
     */
    public String[] getOptions() {
        ArrayList arrayList = new ArrayList();
        WekaOptionUtils.add(arrayList, SCRIPT, getScript());
        WekaOptionUtils.add(arrayList, TRAINFILE, getTrainFile());
        WekaOptionUtils.add(arrayList, MODELDIR, getModelDirectory());
        WekaOptionUtils.add(arrayList, MODELEXT, getModelExtension());
        WekaOptionUtils.add(arrayList, MODEL, getModel());
        WekaOptionUtils.add(arrayList, DEVICETYPE, getDeviceType());
        if (getDeviceType() == DeviceType.GPU) {
            WekaOptionUtils.add(arrayList, GPUDEVICEID, getGPUDeviceID());
        }
        WekaOptionUtils.add(arrayList, INPUTS, getInputs());
        WekaOptionUtils.add(arrayList, INPUTNAMES, getInputNames());
        WekaOptionUtils.add(arrayList, CLASSNAME, getClassName());
        WekaOptionUtils.add(arrayList, OUTPUTNAME, getOutputName());
        WekaOptionUtils.add(arrayList, VARIABLES, getVariables());
        WekaOptionUtils.add(arrayList, super.getOptions());
        return WekaOptionUtils.toArray(arrayList);
    }

    /**
     * Sets the debug flag and adjusts the logging level of the prediction
     * wrapper accordingly (INFO when debugging, WARNING otherwise).
     *
     * @param z true to enable debugging output
     */
    public void setDebug(boolean z) {
        super.setDebug(z);
        this.m_Wrapper.setLoggingLevel(z ? LoggingLevel.INFO : LoggingLevel.WARNING);
    }

    /**
     * Returns the default script.
     *
     * @return the default
     */
    protected PlaceholderFile getDefaultScript() {
        return new PlaceholderFile();
    }

    /**
     * Sets the BrainScript to run.
     *
     * @param placeholderFile the script
     */
    public void setScript(PlaceholderFile placeholderFile) {
        this.m_Script = placeholderFile;
    }

    /**
     * Returns the BrainScript to run.
     *
     * @return the script
     */
    public PlaceholderFile getScript() {
        return this.m_Script;
    }

    /**
     * Returns the tip text for the script property.
     *
     * @return the tip text
     */
    public String scriptTipText() {
        return "The BrainScript to run.";
    }

    /**
     * Returns the default training file.
     *
     * @return the default
     */
    protected PlaceholderFile getDefaultTrainFile() {
        return new PlaceholderFile();
    }

    /**
     * Sets the file the training data gets saved to.
     *
     * @param placeholderFile the training file
     */
    public void setTrainFile(PlaceholderFile placeholderFile) {
        this.m_TrainFile = placeholderFile;
    }

    /**
     * Returns the file the training data gets saved to.
     *
     * @return the training file
     */
    public PlaceholderFile getTrainFile() {
        return this.m_TrainFile;
    }

    /**
     * Returns the tip text for the training file property.
     *
     * @return the tip text
     */
    public String trainFileTipText() {
        return "The training file used by the Brainscript; the training instances will get saved to that file.";
    }

    /**
     * Returns the default model directory.
     *
     * @return the default
     */
    protected PlaceholderDirectory getDefaultModelDirectory() {
        return new PlaceholderDirectory();
    }

    /**
     * Sets the directory containing the models.
     *
     * @param placeholderDirectory the directory
     */
    public void setModelDirectory(PlaceholderDirectory placeholderDirectory) {
        this.m_ModelDirectory = placeholderDirectory;
    }

    /**
     * Returns the directory containing the models.
     *
     * @return the directory
     */
    public PlaceholderDirectory getModelDirectory() {
        return this.m_ModelDirectory;
    }

    /**
     * Returns the tip text for the model directory property.
     *
     * @return the tip text
     */
    public String modelDirectoryTipText() {
        return "The directory containing the models, temp models and checkpoint files.";
    }

    /**
     * Returns the default model file.
     *
     * @return the default
     */
    protected PlaceholderFile getDefaultModel() {
        return new PlaceholderFile();
    }

    /**
     * Sets the model file to load for making predictions.
     *
     * @param placeholderFile the model file
     */
    public void setModel(PlaceholderFile placeholderFile) {
        this.m_Model = placeholderFile;
    }

    /**
     * Returns the model file to load for making predictions.
     *
     * @return the model file
     */
    public PlaceholderFile getModel() {
        return this.m_Model;
    }

    /**
     * Returns the tip text for the model property.
     *
     * @return the tip text
     */
    public String modelTipText() {
        return "The prebuilt CNTK model to use.";
    }

    /**
     * Returns the default model file extension.
     *
     * @return the default
     */
    protected String getDefaultModelExtension() {
        return ".cmf";
    }

    /**
     * Sets the file extension used by the models (incl dot).
     *
     * @param str the extension
     */
    public void setModelExtension(String str) {
        this.m_ModelExtension = str;
    }

    /**
     * Returns the file extension used by the models (incl dot).
     *
     * @return the extension
     */
    public String getModelExtension() {
        return this.m_ModelExtension;
    }

    /**
     * Returns the tip text for the model extension property.
     *
     * @return the tip text
     */
    public String modelExtensionTipText() {
        return "The file extension used by the models (incl dot).";
    }

    /**
     * Returns the default device type.
     *
     * @return the default
     */
    protected DeviceType getDefaultDeviceType() {
        return DeviceType.DEFAULT;
    }

    /**
     * Sets the device type to use.
     *
     * @param deviceType the device type
     */
    public void setDeviceType(DeviceType deviceType) {
        this.m_DeviceType = deviceType;
    }

    /**
     * Returns the device type to use.
     *
     * @return the device type
     */
    public DeviceType getDeviceType() {
        return this.m_DeviceType;
    }

    /**
     * Returns the tip text for the device type property.
     *
     * @return the tip text
     */
    public String deviceTypeTipText() {
        return "The device type to use.";
    }

    /**
     * Returns the default GPU device ID.
     *
     * @return the default
     */
    protected long getDefaultGPUDeviceID() {
        return 0L;
    }

    /**
     * Sets the GPU device ID.
     *
     * @param j the ID
     */
    public void setGPUDeviceID(long j) {
        this.m_GPUDeviceID = j;
    }

    /**
     * Returns the GPU device ID.
     *
     * @return the ID
     */
    public long getGPUDeviceID() {
        return this.m_GPUDeviceID;
    }

    /**
     * Returns the tip text for the GPU device ID property.
     *
     * @return the tip text
     */
    public String GPUDeviceIDTipText() {
        return "The GPU device ID.";
    }

    /**
     * Returns the default input ranges.
     *
     * @return the default
     */
    protected Range[] getDefaultInputs() {
        return new Range[0];
    }

    /**
     * Sets the input ranges; stored in the prediction wrapper.
     *
     * @param rangeArr the ranges
     */
    public void setInputs(Range[] rangeArr) {
        this.m_Wrapper.setInputs(rangeArr);
    }

    /**
     * Returns the input ranges stored in the prediction wrapper.
     *
     * @return the ranges
     */
    public Range[] getInputs() {
        return this.m_Wrapper.getInputs();
    }

    /**
     * Returns the tip text for the inputs property (provided by the wrapper).
     *
     * @return the tip text
     */
    public String inputsTipText() {
        return this.m_Wrapper.inputsTipText();
    }

    /**
     * Returns the default input names.
     *
     * @return the default
     */
    protected BaseString[] getDefaultInputNames() {
        return new BaseString[0];
    }

    /**
     * Sets the input names; stored in the prediction wrapper.
     *
     * @param baseStringArr the names
     */
    public void setInputNames(BaseString[] baseStringArr) {
        this.m_Wrapper.setInputNames(baseStringArr);
    }

    /**
     * Returns the input names stored in the prediction wrapper.
     *
     * @return the names
     */
    public BaseString[] getInputNames() {
        return this.m_Wrapper.getInputNames();
    }

    /**
     * Returns the tip text for the input names property (provided by the wrapper).
     *
     * @return the tip text
     */
    public String inputNamesTipText() {
        return this.m_Wrapper.inputNamesTipText();
    }

    /**
     * Returns the default class name.
     *
     * @return the default
     */
    protected String getDefaultClassName() {
        return "";
    }

    /**
     * Sets the class name; stored in the prediction wrapper.
     *
     * @param str the name
     */
    public void setClassName(String str) {
        this.m_Wrapper.setClassName(str);
    }

    /**
     * Returns the class name stored in the prediction wrapper.
     *
     * @return the name
     */
    public String getClassName() {
        return this.m_Wrapper.getClassName();
    }

    /**
     * Returns the tip text for the class name property (provided by the wrapper).
     *
     * @return the tip text
     */
    public String classNameTipText() {
        return this.m_Wrapper.classNameTipText();
    }

    /**
     * Returns the default output name.
     *
     * @return the default
     */
    protected String getDefaultOutputName() {
        return "";
    }

    /**
     * Sets the output name; stored in the prediction wrapper.
     *
     * @param str the name
     */
    public void setOutputName(String str) {
        this.m_Wrapper.setOutputName(str);
    }

    /**
     * Returns the output name stored in the prediction wrapper.
     *
     * @return the name
     */
    public String getOutputName() {
        return this.m_Wrapper.getOutputName();
    }

    /**
     * Returns the tip text for the output name property (provided by the wrapper).
     *
     * @return the tip text
     */
    public String outputNameTipText() {
        return this.m_Wrapper.outputNameTipText();
    }

    /**
     * Returns the default variables.
     *
     * @return the default
     */
    protected BaseKeyValuePair[] getDefaultVariables() {
        return new BaseKeyValuePair[0];
    }

    /**
     * Sets the variables (key-value pairs) to expand in the script.
     *
     * @param baseKeyValuePairArr the variables
     */
    public void setVariables(BaseKeyValuePair[] baseKeyValuePairArr) {
        this.m_Variables = baseKeyValuePairArr;
    }

    /**
     * Returns the variables (key-value pairs) to expand in the script.
     *
     * @return the variables
     */
    public BaseKeyValuePair[] getVariables() {
        return this.m_Variables;
    }

    /**
     * Returns the tip text for the variables property.
     *
     * @return the tip text
     */
    public String variablesTipText() {
        return "The key-value pairs representing variables and their associated values to be replaced in the script.";
    }

    /**
     * Returns the capabilities of this classifier: numeric attributes only,
     * nominal or numeric class, missing class values allowed.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return capabilities;
    }

    /**
     * Returns which output streams of the process get forwarded.
     *
     * @return always {@link StreamingProcessOutputType#BOTH}
     */
    public StreamingProcessOutputType getOutputType() {
        return StreamingProcessOutputType.BOTH;
    }

    /**
     * Receives a line of output from the CNTK process and echoes it to
     * stdout or stderr. Every line also triggers a clean-up of the model
     * directory: all files currently matched by the lister get deleted.
     *
     * @param str the line of output
     * @param z   true if the line came from stdout, false for stderr
     */
    public void processOutput(String str, boolean z) {
        if (z) {
            System.out.println(str);
        } else {
            System.err.println(str);
        }
        if (this.m_Lister != null) {
            // best effort: the outcome of the deletion is not checked
            for (String file : this.m_Lister.list()) {
                FileUtils.delete(file);
            }
        }
    }

    /**
     * Determines the script file to hand to CNTK. If variables are
     * configured, the script is loaded, the variables are expanded and the
     * result is written to a temp file (remembered in {@link #m_TmpScript}).
     * Without variables, or if the script cannot be loaded, the original
     * script is used as is.
     *
     * @return the absolute path of the script to execute
     * @throws IllegalStateException if the expanded script cannot be written
     */
    protected String preprocessScript() {
        String scriptPath = this.m_Script.getAbsolutePath();
        this.m_TmpScript = null;
        if (this.m_Variables.length == 0) {
            return scriptPath;
        }

        Variables variables = new Variables();
        for (BaseKeyValuePair pair : this.m_Variables) {
            variables.set(pair.getPairKey(), pair.getPairValue());
        }

        List lines = FileUtils.loadFromFile(this.m_Script);
        if (lines == null) {
            // keep the fallback to the unexpanded script, but do not hide the problem
            System.err.println("Failed to load script for expanding variables, using it as is: " + this.m_Script);
            return scriptPath;
        }

        scriptPath = TempUtils.createTempFile("adams-cntk-weka-bs-", ".bs").getAbsolutePath();
        String error = FileUtils.writeToFileMsg(scriptPath, variables.expand(Utils.flatten(lines, "\n")), false, (String) null);
        if (error != null) {
            throw new IllegalStateException("Failed to write expanded script!\n" + error);
        }
        this.m_TmpScript = new PlaceholderFile(scriptPath);
        return scriptPath;
    }

    /**
     * Saves the training data and runs the CNTK binary with the
     * (preprocessed) script. Blocks until the process has finished.
     *
     * @param instances the training data
     * @throws Exception if the data cannot be handled or saved, if CNTK fails
     *                   to execute or returns a non-zero exit code, or if the
     *                   waiting thread gets interrupted
     */
    protected void buildModel(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        // 1. save the training data where the script expects it
        if (getDebug()) {
            System.out.println("Saving data to: " + this.m_TrainFile);
        }
        CNTKSaver saver = new CNTKSaver();
        saver.setSuppressMissing(true);
        saver.setUseSparseFormat(false);
        saver.setInputs(getInputs());
        saver.setInputNames(getInputNames());
        saver.setDestination(this.m_TrainFile.getAbsoluteFile());
        ConverterUtils.DataSink.write(saver, instances);
        if (getDebug()) {
            System.out.println("Saved data to: " + this.m_TrainFile);
        }

        // 2. assemble the command; the argument array is what gets executed, so
        //    that paths containing whitespace are not split into several arguments
        final String[] cmd = {CNTK.getBinary().getAbsolutePath(), "configFile=" + preprocessScript()};
        final String str = cmd[0] + " " + cmd[1];
        if (getDebug()) {
            System.out.println("Command: " + str);
        }

        // 3. lister for "<name><ext>.<digits>" and "<name><ext>.ckp" files; the
        //    extension is quoted so that it is always matched literally
        String extension = (this.m_ModelExtension == null) ? getDefaultModelExtension() : this.m_ModelExtension;
        this.m_Lister = new LocalDirectoryLister();
        this.m_Lister.setWatchDir(this.m_ModelDirectory.getAbsolutePath());
        this.m_Lister.setRegExp(new BaseRegExp(".*" + java.util.regex.Pattern.quote(extension) + "\\.([0-9]+|ckp)$"));
        this.m_Lister.setRecursive(false);
        this.m_Lister.setListDirs(false);
        this.m_Lister.setListFiles(true);

        // 4. launch and monitor the process in a separate thread
        this.m_ExecutionFailure = null;
        this.m_ProcessOutput = new StreamingProcessOutput(this);
        this.m_Monitor = new RunnableWithLogging() { // from class: weka.classifiers.functions.CNTKBrainscriptModel.1
            private static final long serialVersionUID = -4475355379511760429L;

            /** The running process; read by stopExecution() from another thread. */
            protected volatile Process m_Process;

            protected void doRun() {
                try {
                    this.m_Process = Runtime.getRuntime().exec(cmd);
                    CNTKBrainscriptModel.this.m_ProcessOutput.monitor(str, (String[]) null, this.m_Process);
                    if (CNTKBrainscriptModel.this.m_ProcessOutput.getExitCode() != 0) {
                        CNTKBrainscriptModel.this.m_ExecutionFailure = new IllegalStateException("Exit code " + CNTKBrainscriptModel.this.m_ProcessOutput.getExitCode() + " when executing: " + str);
                    }
                } catch (Exception e) {
                    CNTKBrainscriptModel.this.m_ExecutionFailure = new IllegalStateException("Failed to execute: " + str, e);
                }
                CNTKBrainscriptModel.this.m_ProcessOutput = null;
                // the temp script is only removed after a successful run
                if (CNTKBrainscriptModel.this.m_TmpScript == null || CNTKBrainscriptModel.this.m_ExecutionFailure != null) {
                    return;
                }
                CNTKBrainscriptModel.this.m_TmpScript.delete();
            }

            public void stopExecution() {
                if (this.m_Process != null) {
                    this.m_Process.destroy();
                }
                super.stopExecution();
            }
        };

        // 5. wait on the thread itself rather than polling the monitor's flags:
        //    join() returns exactly when the run is over, regardless of how quickly
        //    the process finishes, and makes m_ExecutionFailure visible to this thread
        Thread worker = new Thread((Runnable) this.m_Monitor);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            // do not leave the external process running when the build gets aborted
            this.m_Monitor.stopExecution();
            Thread.currentThread().interrupt();
            throw e;
        } finally {
            this.m_Monitor = null;
        }

        if (this.m_ExecutionFailure != null) {
            throw this.m_ExecutionFailure;
        }
    }

    /**
     * Initializes the device, loads the configured model file and
     * initializes the model in the prediction wrapper.
     *
     * @param instances the dataset used for determining the number of classes
     * @throws Exception if the wrapper fails to initialize
     */
    protected void initModel(Instances instances) throws Exception {
        this.m_Wrapper.initDevice(this.m_DeviceType, this.m_GPUDeviceID);
        this.m_Wrapper.loadModel(this.m_Model);
        this.m_Wrapper.initModel(instances.numClasses());
    }

    /**
     * Builds the classifier: runs the BrainScript on a copy of the data
     * (rows with a missing class value removed) and loads the model.
     *
     * @param instances the training data
     * @throws Exception if building or loading the model fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        // work on a copy, the caller's data must not be modified
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();
        buildModel(data);
        initModel(data);
    }

    /**
     * Makes a prediction for the instance. The model gets loaded lazily if
     * the wrapper has not been initialized yet. Predictions with more than
     * one value are passed through the normalize post-processor; a single
     * value is returned unchanged.
     *
     * @param instance the instance to make the prediction for
     * @return the prediction, same length as the raw prediction of the wrapper
     * @throws IllegalStateException if the model has to be initialized but the
     *                               instance has no dataset attached
     * @throws Exception if initializing the model or predicting fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!this.m_Wrapper.isInitialized()) {
            Instances header = instance.dataset();
            if (header == null) {
                throw new IllegalStateException("Instance has no dataset attached, cannot initialize the model!");
            }
            initModel(header);
        }

        // NOTE(review): every attribute value (class attribute included) is handed
        // to the wrapper - presumably it selects the inputs itself; confirm
        int numAtts = instance.numAttributes();
        float[] values = new float[numAtts];
        for (int i = 0; i < numAtts; i++) {
            values[i] = (float) instance.value(i);
        }

        float[] predict = this.m_Wrapper.predict(values);
        double[] dist = new double[predict.length];
        if (predict.length > 1) {
            float[] processed = this.m_Normalize.postProcessPrediction(predict);
            for (int i = 0; i < processed.length; i++) {
                dist[i] = processed[i];
            }
        } else if (predict.length == 1) {
            dist[0] = predict[0];
        }
        return dist;
    }

    /**
     * Returns a string representation of the loaded model.
     *
     * @return the model's string representation, or a note that no model has
     *         been loaded yet
     */
    public String toString() {
        Object model = this.m_Wrapper.getModel();
        return model == null ? "No model loaded yet!" : model.toString();
    }
}
