/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import adams.core.Range;
import adams.core.Utils;
import adams.core.base.BaseObject;
import adams.core.base.BaseString;
import adams.core.io.PlaceholderFile;
import adams.core.logging.LoggingLevel;
import adams.ml.cntk.CNTKPredictionWrapper;
import adams.ml.cntk.DeviceType;
import adams.ml.cntk.predictionpostprocessor.Normalize;
import gnu.trove.list.array.TFloatArrayList;
import java.io.File;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaOptionUtils;

public class CNTKPrebuiltModel
extends AbstractClassifier {
    private static final long serialVersionUID = 7732345053235983381L;
    protected static String MODEL = "model";
    protected static String DEVICETYPE = "device-type";
    protected static String GPUDEVICEID = "gpu-device-id";
    protected static String INPUTS = "inputs";
    protected static String INPUTNAMES = "input-names";
    protected static String CLASSNAME = "class-name";
    protected static String OUTPUTNAME = "output-name";
    protected PlaceholderFile m_Model = this.getDefaultModel();
    protected DeviceType m_DeviceType = this.getDefaultDeviceType();
    protected long m_GPUDeviceID = this.getDefaultGPUDeviceID();
    protected CNTKPredictionWrapper m_Wrapper = new CNTKPredictionWrapper();
    protected Normalize m_Normalize = new Normalize();

    public String globalInfo() {
        return "Applies the pre-built CNTK model to the data for making predictions.";
    }

    public Enumeration listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, (String)this.modelTipText(), (String)("" + this.getDefaultModel()), (String)MODEL);
        WekaOptionUtils.addOption(result, (String)this.deviceTypeTipText(), (String)("" + this.getDefaultDeviceType()), (String)DEVICETYPE);
        WekaOptionUtils.addOption(result, (String)this.GPUDeviceIDTipText(), (String)("" + this.getDefaultGPUDeviceID()), (String)GPUDEVICEID);
        WekaOptionUtils.addOption(result, (String)this.inputsTipText(), (String)Utils.arrayToString((Object)this.getDefaultInputs()), (String)INPUTS);
        WekaOptionUtils.addOption(result, (String)this.inputNamesTipText(), (String)Utils.arrayToString((Object)this.getDefaultInputNames()), (String)INPUTNAMES);
        WekaOptionUtils.addOption(result, (String)this.classNameTipText(), (String)this.getDefaultClassName(), (String)CLASSNAME);
        WekaOptionUtils.addOption(result, (String)this.outputNameTipText(), (String)this.getDefaultOutputName(), (String)OUTPUTNAME);
        WekaOptionUtils.add(result, (Enumeration)super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    public void setOptions(String[] options) throws Exception {
        this.setModel(WekaOptionUtils.parse((String[])options, (String)MODEL, (PlaceholderFile)this.getDefaultModel()));
        this.setDeviceType((DeviceType)WekaOptionUtils.parse((String[])options, (String)DEVICETYPE, (Enum)this.getDefaultDeviceType()));
        this.setGPUDeviceID(WekaOptionUtils.parse((String[])options, (String)GPUDEVICEID, (long)this.getDefaultGPUDeviceID()));
        this.setInputs(WekaOptionUtils.parse((String[])options, (String)INPUTS, (Range[])this.getDefaultInputs()));
        this.setInputNames((BaseString[])WekaOptionUtils.parse((String[])options, (String)INPUTNAMES, (BaseObject[])this.getDefaultInputNames()));
        this.setClassName(WekaOptionUtils.parse((String[])options, (String)CLASSNAME, (String)this.getDefaultClassName()));
        this.setOutputName(WekaOptionUtils.parse((String[])options, (String)OUTPUTNAME, (String)this.getDefaultOutputName()));
        super.setOptions(options);
    }

    public String[] getOptions() {
        ArrayList result = new ArrayList();
        WekaOptionUtils.add(result, (String)MODEL, (File)this.getModel());
        WekaOptionUtils.add(result, (String)DEVICETYPE, (Enum)this.getDeviceType());
        if (this.getDeviceType() == DeviceType.GPU) {
            WekaOptionUtils.add(result, (String)GPUDEVICEID, (long)this.getGPUDeviceID());
        }
        WekaOptionUtils.add(result, (String)INPUTS, (Range[])this.getInputs());
        WekaOptionUtils.add(result, (String)INPUTNAMES, (BaseObject[])this.getInputNames());
        WekaOptionUtils.add(result, (String)CLASSNAME, (String)this.getClassName());
        WekaOptionUtils.add(result, (String)OUTPUTNAME, (String)this.getOutputName());
        WekaOptionUtils.add(result, (String[])super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    public void setDebug(boolean debug) {
        super.setDebug(debug);
        this.m_Wrapper.setLoggingLevel(debug ? LoggingLevel.INFO : LoggingLevel.WARNING);
    }

    protected PlaceholderFile getDefaultModel() {
        return new PlaceholderFile();
    }

    public void setModel(PlaceholderFile value) {
        this.m_Model = value;
    }

    public PlaceholderFile getModel() {
        return this.m_Model;
    }

    public String modelTipText() {
        return "The prebuilt CNTK model to use.";
    }

    protected DeviceType getDefaultDeviceType() {
        return DeviceType.DEFAULT;
    }

    public void setDeviceType(DeviceType value) {
        this.m_DeviceType = value;
    }

    public DeviceType getDeviceType() {
        return this.m_DeviceType;
    }

    public String deviceTypeTipText() {
        return "The device type to use.";
    }

    protected long getDefaultGPUDeviceID() {
        return 0L;
    }

    public void setGPUDeviceID(long value) {
        this.m_GPUDeviceID = value;
    }

    public long getGPUDeviceID() {
        return this.m_GPUDeviceID;
    }

    public String GPUDeviceIDTipText() {
        return "The GPU device ID.";
    }

    protected Range[] getDefaultInputs() {
        return new Range[0];
    }

    public void setInputs(Range[] value) {
        this.m_Wrapper.setInputs(value);
    }

    public Range[] getInputs() {
        return this.m_Wrapper.getInputs();
    }

    public String inputsTipText() {
        return this.m_Wrapper.inputsTipText();
    }

    protected BaseString[] getDefaultInputNames() {
        return new BaseString[0];
    }

    public void setInputNames(BaseString[] value) {
        this.m_Wrapper.setInputNames(value);
    }

    public BaseString[] getInputNames() {
        return this.m_Wrapper.getInputNames();
    }

    public String inputNamesTipText() {
        return this.m_Wrapper.inputNamesTipText();
    }

    protected String getDefaultClassName() {
        return "";
    }

    public void setClassName(String value) {
        this.m_Wrapper.setClassName(value);
    }

    public String getClassName() {
        return this.m_Wrapper.getClassName();
    }

    public String classNameTipText() {
        return this.m_Wrapper.classNameTipText();
    }

    protected String getDefaultOutputName() {
        return "";
    }

    public void setOutputName(String value) {
        this.m_Wrapper.setOutputName(value);
    }

    public String getOutputName() {
        return this.m_Wrapper.getOutputName();
    }

    public String outputNameTipText() {
        return this.m_Wrapper.outputNameTipText();
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    protected void initModel(Instances data) throws Exception {
        this.m_Wrapper.initDevice(this.m_DeviceType, this.m_GPUDeviceID);
        this.m_Wrapper.loadModel((File)this.m_Model);
        this.m_Wrapper.initModel(data.numClasses());
    }

    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        this.initModel(data);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        int i;
        if (!this.m_Wrapper.isInitialized()) {
            this.initModel(instance.dataset());
        }
        TFloatArrayList values = new TFloatArrayList();
        for (i = 0; i < instance.numAttributes(); ++i) {
            values.add((float)instance.value(i));
        }
        float[] scores = this.m_Wrapper.predict(values.toArray());
        double[] result = new double[scores.length];
        if (scores.length > 1) {
            scores = this.m_Normalize.postProcessPrediction(scores);
            for (i = 0; i < scores.length; ++i) {
                result[i] = scores[i];
            }
        } else if (scores.length == 1) {
            result[0] = scores[0];
        }
        return result;
    }

    public String toString() {
        if (this.m_Wrapper.getModel() == null) {
            return "No model loaded yet!";
        }
        return this.m_Wrapper.getModel().toString();
    }
}

