/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.cntk;

import adams.core.License;
import adams.core.annotation.MixedCopyright;
import adams.core.base.BaseString;
import adams.ml.cntk.AbstractCNTKModelWrapper;
import com.microsoft.CNTK.DeviceDescriptor;
import com.microsoft.CNTK.FloatVector;
import com.microsoft.CNTK.FloatVectorVector;
import com.microsoft.CNTK.NDShape;
import com.microsoft.CNTK.UnorderedMapVariableValuePtr;
import com.microsoft.CNTK.Value;
import com.microsoft.CNTK.Variable;
import gnu.trove.set.hash.TIntHashSet;
import java.util.HashMap;

@MixedCopyright(author="CNTK", copyright="Microsoft", license=License.MIT, url="https://github.com/Microsoft/CNTK/blob/v2.0/Tests/EndToEndTests/EvalClientTests/JavaEvalTest/src/Main.java", note="Original code based on CNTK example")
public class CNTKPredictionWrapper
extends AbstractCNTKModelWrapper {
    private static final long serialVersionUID = -1508684329565658944L;
    protected String m_ClassName;
    protected String m_ActualClassName;
    protected String m_OutputName;
    protected transient Variable m_OutputVar;
    protected int m_NumClasses;

    public String globalInfo() {
        return "Encapsulates a CNTK model for making predictions.";
    }

    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("class-name", "className", (Object)this.getDefaultClassName());
        this.m_OptionManager.add("output-name", "outputName", (Object)this.getDefaultOutputName());
    }

    @Override
    public void reset() {
        super.reset();
        this.m_NumClasses = -1;
    }

    protected String getDefaultClassName() {
        return "";
    }

    public void setClassName(String value) {
        this.m_ClassName = value;
        this.reset();
    }

    public String getClassName() {
        return this.m_ClassName;
    }

    public String classNameTipText() {
        return "The name of the class attribute in the model, in case it cannot be determined automatically.";
    }

    protected String getDefaultOutputName() {
        return "";
    }

    public void setOutputName(String value) {
        this.m_OutputName = value;
        this.reset();
    }

    public String getOutputName() {
        return this.m_OutputName;
    }

    public String outputNameTipText() {
        return "The name of the output variable in the model, in case it cannot be determined automatically based on its dimension.";
    }

    public int getNumClasses() {
        return this.m_NumClasses;
    }

    public void initModel(int numClasses) throws Exception {
        super.initModel();
        this.m_NumClasses = numClasses;
        this.m_OutputVar = null;
        for (Variable var : this.m_Model.getOutputs()) {
            if (this.m_OutputName.isEmpty()) {
                if (var.getShape().getTotalSize() != (long)this.m_NumClasses) continue;
                this.m_OutputVar = var;
                break;
            }
            if (!var.getName().equals(this.m_OutputName) && !var.getUid().equals(this.m_OutputName)) continue;
            this.m_OutputVar = var;
            break;
        }
        if (this.isLoggingEnabled()) {
            this.getLogger().info("Output var: " + this.m_OutputVar);
        }
        if (this.m_OutputVar == null) {
            throw new IllegalStateException("Failed to determine output variable!");
        }
        this.m_ActualClassName = this.m_ClassName;
        block1: for (Variable var : this.m_Model.getArguments()) {
            String name = var.getName();
            String uid = var.getUid();
            for (BaseString inputName : this.m_InputNames) {
                if (!inputName.getValue().equals(name) && !inputName.getValue().equals(uid)) continue;
                if (!this.m_ActualClassName.isEmpty() || var.getShape().getTotalSize() != (long)this.m_NumClasses) continue block1;
                this.m_ActualClassName = inputName.getValue();
                if (!this.isLoggingEnabled()) continue block1;
                this.getLogger().info("Actual classname: " + this.m_ActualClassName);
                continue block1;
            }
        }
    }

    public float[] predict(float[] input) throws Exception {
        int i;
        if (!this.isInitialized()) {
            throw new Exception("Model not initialized!");
        }
        if (this.m_Ranges == null) {
            this.m_Ranges = new HashMap();
            for (i = 0; i < this.m_Inputs.length; ++i) {
                this.m_Inputs[i].setMax(input.length + 1);
                this.m_Ranges.put(this.m_InputNames[i].getValue(), new TIntHashSet(this.m_Inputs[i].getIntIndices()));
            }
        }
        HashMap<String, FloatVector> floatVecs = new HashMap<String, FloatVector>();
        if (this.m_Names.contains(this.m_ActualClassName)) {
            floatVecs.put(this.m_ActualClassName, new FloatVector());
            for (i = 0; i < this.m_NumClasses; ++i) {
                ((FloatVector)floatVecs.get(this.m_ActualClassName)).add(0.0f);
            }
        }
        block2: for (i = 0; i < input.length; ++i) {
            for (String name : this.m_Names) {
                TIntHashSet range = (TIntHashSet)this.m_Ranges.get(name);
                if (range == null || !range.contains(i)) continue;
                if (!floatVecs.containsKey(name)) {
                    floatVecs.put(name, new FloatVector());
                }
                if (name.equals(this.m_ActualClassName)) continue block2;
                ((FloatVector)floatVecs.get(name)).add(input[i]);
                continue block2;
            }
        }
        UnorderedMapVariableValuePtr inputDataMap = new UnorderedMapVariableValuePtr();
        for (String name : this.m_Names) {
            FloatVectorVector floatVecVec = new FloatVectorVector();
            floatVecVec.add((FloatVector)floatVecs.get(name));
            Value inputVal = Value.createDenseFloat((NDShape)((NDShape)this.m_InputShapes.get(name)), (FloatVectorVector)floatVecVec, (DeviceDescriptor)this.m_Device);
            inputDataMap.add((Variable)this.m_InputVars.get(name), inputVal);
        }
        UnorderedMapVariableValuePtr outputDataMap = new UnorderedMapVariableValuePtr();
        outputDataMap.add(this.m_OutputVar, null);
        this.m_Model.evaluate(inputDataMap, outputDataMap, this.m_Device);
        FloatVectorVector outputBuffer = new FloatVectorVector();
        outputDataMap.getitem(this.m_OutputVar).copyVariableValueToFloat(this.m_OutputVar, outputBuffer);
        FloatVector results = outputBuffer.get(0);
        float[] result = new float[(int)results.size()];
        for (i = 0; i < result.length; ++i) {
            result[i] = results.get(i);
        }
        for (FloatVector floatVec : floatVecs.values()) {
            floatVec.delete();
        }
        outputBuffer.delete();
        inputDataMap.delete();
        outputDataMap.delete();
        results.delete();
        return result;
    }

    @Override
    public void cleanUp() {
        if (this.m_OutputVar != null) {
            this.m_OutputVar.delete();
            this.m_OutputVar = null;
        }
        super.cleanUp();
    }
}

