package adams.ml.cntk;

import adams.core.License;
import adams.core.annotation.MixedCopyright;
import adams.core.base.BaseString;
import com.microsoft.CNTK.FloatVector;
import com.microsoft.CNTK.FloatVectorVector;
import com.microsoft.CNTK.UnorderedMapVariableValuePtr;
import com.microsoft.CNTK.Value;
import com.microsoft.CNTK.Variable;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

@MixedCopyright(author = "CNTK", copyright = "Microsoft", license = License.MIT, url = "https://github.com/Microsoft/CNTK/blob/v2.0/Tests/EndToEndTests/EvalClientTests/JavaEvalTest/src/Main.java", note = "Original code based on CNTK example")
/* loaded from: input_file:adams/ml/cntk/CNTKPredictionWrapper.class */
public class CNTKPredictionWrapper extends AbstractCNTKModelWrapper {
    private static final long serialVersionUID = -1508684329565658944L;
    protected String m_ClassName;
    protected String m_ActualClassName;
    protected String m_OutputName;
    protected transient Variable m_OutputVar;
    protected int m_NumClasses;

    public String globalInfo() {
        return "Encapsulates a CNTK model for making predictions.";
    }

    @Override // adams.ml.cntk.AbstractCNTKModelWrapper
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("class-name", "className", getDefaultClassName());
        this.m_OptionManager.add("output-name", "outputName", getDefaultOutputName());
    }

    @Override // adams.ml.cntk.AbstractCNTKModelWrapper
    public void reset() {
        super.reset();
        this.m_NumClasses = -1;
    }

    protected String getDefaultClassName() {
        return "";
    }

    public void setClassName(String str) {
        this.m_ClassName = str;
        reset();
    }

    public String getClassName() {
        return this.m_ClassName;
    }

    public String classNameTipText() {
        return "The name of the class attribute in the model, in case it cannot be determined automatically.";
    }

    protected String getDefaultOutputName() {
        return "";
    }

    public void setOutputName(String str) {
        this.m_OutputName = str;
        reset();
    }

    public String getOutputName() {
        return this.m_OutputName;
    }

    public String outputNameTipText() {
        return "The name of the output variable in the model, in case it cannot be determined automatically based on its dimension.";
    }

    public int getNumClasses() {
        return this.m_NumClasses;
    }

    public void initModel(int i) throws Exception {
        Variable variable;
        super.initModel();
        this.m_NumClasses = i;
        this.m_OutputVar = null;
        Iterator it = this.m_Model.getOutputs().iterator();
        while (true) {
            if (!it.hasNext()) {
                break;
            }
            variable = (Variable) it.next();
            if (!this.m_OutputName.isEmpty()) {
                if (variable.getName().equals(this.m_OutputName) || variable.getUid().equals(this.m_OutputName)) {
                    break;
                }
            } else {
                if (variable.getShape().getTotalSize() == this.m_NumClasses) {
                    this.m_OutputVar = variable;
                    break;
                }
            }
        }
        this.m_OutputVar = variable;
        if (isLoggingEnabled()) {
            getLogger().info("Output var: " + this.m_OutputVar);
        }
        if (this.m_OutputVar == null) {
            throw new IllegalStateException("Failed to determine output variable!");
        }
        this.m_ActualClassName = this.m_ClassName;
        for (Variable variable2 : this.m_Model.getArguments()) {
            String name = variable2.getName();
            String uid = variable2.getUid();
            for (BaseString baseString : this.m_InputNames) {
                if (baseString.getValue().equals(name) || baseString.getValue().equals(uid)) {
                    if (this.m_ActualClassName.isEmpty() && variable2.getShape().getTotalSize() == this.m_NumClasses) {
                        this.m_ActualClassName = baseString.getValue();
                        if (isLoggingEnabled()) {
                            getLogger().info("Actual classname: " + this.m_ActualClassName);
                        }
                    }
                }
            }
        }
    }

    public float[] predict(float[] fArr) throws Exception {
        if (!isInitialized()) {
            throw new Exception("Model not initialized!");
        }
        if (this.m_Ranges == null) {
            this.m_Ranges = new HashMap();
            for (int i = 0; i < this.m_Inputs.length; i++) {
                this.m_Inputs[i].setMax(fArr.length + 1);
                this.m_Ranges.put(this.m_InputNames[i].getValue(), new TIntHashSet(this.m_Inputs[i].getIntIndices()));
            }
        }
        HashMap hashMap = new HashMap();
        if (this.m_Names.contains(this.m_ActualClassName)) {
            hashMap.put(this.m_ActualClassName, new FloatVector());
            for (int i2 = 0; i2 < this.m_NumClasses; i2++) {
                ((FloatVector) hashMap.get(this.m_ActualClassName)).add(0.0f);
            }
        }
        for (int i3 = 0; i3 < fArr.length; i3++) {
            Iterator<String> it = this.m_Names.iterator();
            while (true) {
                if (!it.hasNext()) {
                    break;
                }
                String next = it.next();
                TIntHashSet tIntHashSet = this.m_Ranges.get(next);
                if (tIntHashSet != null && tIntHashSet.contains(i3)) {
                    if (!hashMap.containsKey(next)) {
                        hashMap.put(next, new FloatVector());
                    }
                    if (!next.equals(this.m_ActualClassName)) {
                        ((FloatVector) hashMap.get(next)).add(fArr[i3]);
                    }
                }
            }
        }
        UnorderedMapVariableValuePtr unorderedMapVariableValuePtr = new UnorderedMapVariableValuePtr();
        for (String str : this.m_Names) {
            FloatVectorVector floatVectorVector = new FloatVectorVector();
            floatVectorVector.add((FloatVector) hashMap.get(str));
            unorderedMapVariableValuePtr.add(this.m_InputVars.get(str), Value.createDenseFloat(this.m_InputShapes.get(str), floatVectorVector, this.m_Device));
        }
        UnorderedMapVariableValuePtr unorderedMapVariableValuePtr2 = new UnorderedMapVariableValuePtr();
        unorderedMapVariableValuePtr2.add(this.m_OutputVar, (Value) null);
        this.m_Model.evaluate(unorderedMapVariableValuePtr, unorderedMapVariableValuePtr2, this.m_Device);
        FloatVectorVector floatVectorVector2 = new FloatVectorVector();
        unorderedMapVariableValuePtr2.getitem(this.m_OutputVar).copyVariableValueToFloat(this.m_OutputVar, floatVectorVector2);
        FloatVector floatVector = floatVectorVector2.get(0);
        float[] fArr2 = new float[(int) floatVector.size()];
        for (int i4 = 0; i4 < fArr2.length; i4++) {
            fArr2[i4] = floatVector.get(i4);
        }
        Iterator it2 = hashMap.values().iterator();
        while (it2.hasNext()) {
            ((FloatVector) it2.next()).delete();
        }
        floatVectorVector2.delete();
        unorderedMapVariableValuePtr.delete();
        unorderedMapVariableValuePtr2.delete();
        floatVector.delete();
        return fArr2;
    }

    @Override // adams.ml.cntk.AbstractCNTKModelWrapper
    public void cleanUp() {
        if (this.m_OutputVar != null) {
            this.m_OutputVar.delete();
            this.m_OutputVar = null;
        }
        super.cleanUp();
    }
}
