/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.cntk;

import adams.core.CleanUpHandler;
import adams.core.License;
import adams.core.Range;
import adams.core.Utils;
import adams.core.annotation.MixedCopyright;
import adams.core.base.BaseString;
import adams.core.option.AbstractOptionHandler;
import adams.ml.cntk.DeviceType;
import com.microsoft.CNTK.DeviceDescriptor;
import com.microsoft.CNTK.Function;
import com.microsoft.CNTK.NDShape;
import com.microsoft.CNTK.Variable;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Ancestor for wrappers around a CNTK model ({@link Function}) that bind
 * named model arguments to column ranges of the input data.
 * <p>
 * The model is either set directly via {@link #setModel(Function)} or loaded
 * from disk via {@link #loadModel(File)}. A device must be present (see
 * {@link #initDevice(DeviceType, long)} or {@link #setDevice(DeviceDescriptor)})
 * before {@link #initModel()} can resolve the configured input names against
 * the model's arguments.
 */
@MixedCopyright(author="CNTK", copyright="Microsoft", license=License.MIT, url="https://github.com/Microsoft/CNTK/blob/v2.0/Tests/EndToEndTests/EvalClientTests/JavaEvalTest/src/Main.java", note="Original code based on CNTK example")
public abstract class AbstractCNTKModelWrapper
extends AbstractOptionHandler
implements CleanUpHandler {
    private static final long serialVersionUID = -1508684329565658944L;

    /** the CNTK model in use (may be null until set/loaded). */
    protected transient Function m_Model;

    /** the device that the model gets loaded onto. */
    protected transient DeviceDescriptor m_Device;

    /** the column ranges of the inputs; kept the same length as m_InputNames. */
    protected Range[] m_Inputs;

    /** the names (or UIDs) of the model arguments to use as inputs; kept the same length as m_Inputs. */
    protected BaseString[] m_InputNames;

    /** the matched model argument variables, keyed by configured input name. */
    protected transient Map<String, Variable> m_InputVars;

    /** the shapes of the matched model argument variables, keyed by configured input name. */
    protected transient Map<String, NDShape> m_InputShapes;

    /** the configured input names that matched a model argument, in model-argument order. */
    protected transient List<String> m_Names;

    /** NOTE(review): not populated in this class; presumably filled by subclasses with column indices per input name. */
    protected transient Map<String, TIntHashSet> m_Ranges;

    /** whether initModel() has completed since the last reset/clean-up. */
    protected transient boolean m_Initialized;

    /**
     * Adds the "input" (column ranges) and "input-name" (argument names)
     * options to the option manager.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("input", "inputs", (Object)this.getDefaultInputs());
        this.m_OptionManager.add("input-name", "inputNames", (Object)this.getDefaultInputNames());
    }

    /**
     * Resets the wrapper: releases any cached input information and marks the
     * wrapper as not initialized.
     */
    public void reset() {
        super.reset();
        this.cleanUp();
        this.m_Initialized = false;
    }

    /**
     * Sets the model to use and resets the wrapper.
     *
     * @param value the model
     */
    public void setModel(Function value) {
        this.m_Model = value;
        this.reset();
    }

    /**
     * Returns the model in use.
     *
     * @return the model, null if none set/loaded
     */
    public Function getModel() {
        return this.m_Model;
    }

    /**
     * Sets the device to use and resets the wrapper.
     *
     * @param value the device
     */
    public void setDevice(DeviceDescriptor value) {
        this.m_Device = value;
        this.reset();
    }

    /**
     * Returns the device in use.
     *
     * @return the device, null if none set
     */
    public DeviceDescriptor getDevice() {
        return this.m_Device;
    }

    /**
     * Returns the default column ranges of the inputs.
     *
     * @return the default (empty array)
     */
    protected Range[] getDefaultInputs() {
        return new Range[0];
    }

    /**
     * Sets the column ranges of the inputs. The input names array is resized
     * to the same length, padding with empty names.
     *
     * @param value the ranges
     */
    public void setInputs(Range[] value) {
        this.m_Inputs = value;
        this.m_InputNames = (BaseString[])Utils.adjustArray((Object)this.m_InputNames, (int)this.m_Inputs.length, (Object)new BaseString());
        this.reset();
    }

    /**
     * Returns the column ranges of the inputs.
     *
     * @return the ranges
     */
    public Range[] getInputs() {
        return this.m_Inputs;
    }

    /**
     * Returns the tip text for the inputs property.
     *
     * @return the tip text
     */
    public String inputsTipText() {
        return "The column ranges determining the inputs (eg for 'features' and 'class').";
    }

    /**
     * Returns the default input names.
     *
     * @return the default (empty array)
     */
    protected BaseString[] getDefaultInputNames() {
        return new BaseString[0];
    }

    /**
     * Sets the names of the inputs. The column ranges array is resized to the
     * same length, padding with empty ranges.
     *
     * @param value the names
     */
    public void setInputNames(BaseString[] value) {
        this.m_InputNames = value;
        this.m_Inputs = (Range[])Utils.adjustArray((Object)this.m_Inputs, (int)this.m_InputNames.length, (Object)new Range());
        this.reset();
    }

    /**
     * Returns the names of the inputs.
     *
     * @return the names
     */
    public BaseString[] getInputNames() {
        return this.m_InputNames;
    }

    /**
     * Returns the tip text for the inputNames property.
     *
     * @return the tip text
     */
    public String inputNamesTipText() {
        return "The names of the inputs (eg 'features' and 'class').";
    }

    /**
     * Returns whether initModel() has completed since the last reset/clean-up.
     *
     * @return true if initialized
     */
    public boolean isInitialized() {
        return this.m_Initialized;
    }

    /**
     * Sets the device from the given device type. Unlike setDevice, this does
     * not reset the wrapper.
     *
     * @param device the type of device
     * @param gpu the GPU device ID, only used for {@link DeviceType#GPU}
     * @throws IllegalStateException if the device type is not handled
     */
    public void initDevice(DeviceType device, long gpu) {
        switch (device) {
            case DEFAULT:
                this.m_Device = DeviceDescriptor.useDefaultDevice();
                break;
            case CPU:
                this.m_Device = DeviceDescriptor.getCPUDevice();
                break;
            case GPU:
                this.m_Device = DeviceDescriptor.getGPUDevice(gpu);
                break;
            default:
                throw new IllegalStateException("Unhandled device type: " + device);
        }
    }

    /**
     * Loads the model from the given file onto the current device and resets
     * the wrapper. The reset ensures that input information derived from a
     * previously loaded model is not reused.
     *
     * @param model the model file
     * @throws Exception if the file does not exist, is a directory, or the model
     *                   could not be loaded
     */
    public void loadModel(File model) throws Exception {
        if (!model.exists()) {
            throw new IllegalStateException("Model does not exist: " + model);
        }
        if (model.isDirectory()) {
            throw new IllegalStateException("Model points to directory: " + model);
        }
        this.m_Model = Function.load(model.getAbsolutePath(), this.m_Device);
        // consistent with setModel: cached inputs/initialization belong to the old model
        this.reset();
        if (this.m_Model == null) {
            throw new IllegalStateException("Failed to load model: " + model);
        }
    }

    /**
     * Resolves the configured input names against the model's arguments.
     * An argument matches an input name if its name or UID equals that input
     * name; each argument is bound to at most one (the first matching) input
     * name. For each match, the variable and its shape are cached under the
     * input name, and the name is recorded in m_Names. On success, the wrapper
     * is marked as initialized.
     *
     * @throws Exception if no model or no device is present
     */
    protected void initModel() throws Exception {
        if (this.m_Model == null) {
            throw new IllegalStateException("No model present!");
        }
        if (this.m_Device == null) {
            throw new IllegalStateException("No device present!");
        }
        if (this.isLoggingEnabled()) {
            // iterate each vector once instead of re-fetching it per index
            this.getLogger().info("Arguments:");
            for (Variable argument : this.m_Model.getArguments()) {
                this.getLogger().info("- " + argument);
            }
            this.getLogger().info("Outputs:");
            for (Variable output : this.m_Model.getOutputs()) {
                this.getLogger().info("- " + output);
            }
        }
        this.m_InputVars = new HashMap<String, Variable>();
        this.m_InputShapes = new HashMap<String, NDShape>();
        this.m_Names = new ArrayList<String>();
        for (Variable var : this.m_Model.getArguments()) {
            String name = var.getName();
            String uid = var.getUid();
            for (BaseString inputName : this.m_InputNames) {
                String input = inputName.getValue();
                if (!input.equals(name) && !input.equals(uid)) {
                    continue;
                }
                this.m_Names.add(input);
                this.m_InputVars.put(input, var);
                this.m_InputShapes.put(input, var.getShape());
                if (this.isLoggingEnabled()) {
                    this.getLogger().info("Input var '" + input + "': " + var);
                }
                // bind each argument to the first matching input name only
                break;
            }
        }
        this.m_Ranges = null;
        this.m_Initialized = true;
    }

    /**
     * Releases the native input variables and shapes cached by initModel().
     * It also discards the cached input names and ranges. Afterwards the
     * wrapper reports itself as not initialized, since the cached information
     * is gone and initModel() has to be run again.
     */
    public void cleanUp() {
        if (this.m_InputVars != null) {
            for (Variable variable : this.m_InputVars.values()) {
                variable.delete();
            }
            this.m_InputVars.clear();
            this.m_InputVars = null;
        }
        if (this.m_InputShapes != null) {
            for (NDShape shape : this.m_InputShapes.values()) {
                shape.delete();
            }
            this.m_InputShapes.clear();
            this.m_InputShapes = null;
        }
        this.m_Names = null;
        if (this.m_Ranges != null) {
            this.m_Ranges.clear();
            this.m_Ranges = null;
        }
        // cached input information is gone, so the wrapper must be re-initialized
        this.m_Initialized = false;
    }
}

