package adams.ml.cntk;

import adams.core.License;
import adams.core.annotation.MixedCopyright;
import com.microsoft.CNTK.FloatVector;
import com.microsoft.CNTK.FloatVectorVector;
import com.microsoft.CNTK.Function;
import com.microsoft.CNTK.UnorderedMapVariableValuePtr;
import com.microsoft.CNTK.Value;
import com.microsoft.CNTK.Variable;
import gnu.trove.set.hash.TIntHashSet;
import java.util.HashMap;
import java.util.Iterator;

/**
 * Encapsulates a CNTK model and returns the output of a configurable layer
 * ({@link #setFilterLayer(String)}) for a flat input vector, i.e., it uses the
 * model as a filter that transforms data.
 */
@MixedCopyright(author = "CNTK", copyright = "Microsoft", license = License.MIT, url = "https://github.com/Microsoft/CNTK/blob/v2.0/Tests/EndToEndTests/EvalClientTests/JavaEvalTest/src/Main.java", note = "Original code based on CNTK example")
public class CNTKFilterWrapper extends AbstractCNTKModelWrapper {
    private static final long serialVersionUID = -1508684329565658944L;

    /** The name of the layer whose output {@link #filter(float[])} returns. */
    protected String m_FilterLayer;

    /** The layer located via {@link #m_FilterLayer}; set in {@link #initModel()}. */
    protected transient Function m_Layer;

    /** The single output variable of {@link #m_Layer}; set in {@link #initModel()}. */
    protected transient Variable m_OutputVar;

    /**
     * Returns a short description of this wrapper.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Encapsulates a CNTK model for filtering data.";
    }

    /**
     * Adds the filter-layer option to the options of the superclass.
     * Note: the command-line flag is "output-layer" while the property is
     * "filterLayer"; the flag is kept as-is for compatibility with existing setups.
     */
    @Override // adams.ml.cntk.AbstractCNTKModelWrapper
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("output-layer", "filterLayer", getDefaultFilterLayer());
    }

    /**
     * Returns the default filter layer name.
     *
     * @return the default, an empty string (i.e., must be configured)
     */
    protected String getDefaultFilterLayer() {
        return "";
    }

    /**
     * Sets the name of the layer to obtain the filtered data from and resets the wrapper.
     *
     * @param str the layer name
     */
    public void setFilterLayer(String str) {
        this.m_FilterLayer = str;
        reset();
    }

    /**
     * Returns the name of the layer to obtain the filtered data from.
     *
     * @return the layer name
     */
    public String getFilterLayer() {
        return this.m_FilterLayer;
    }

    /**
     * Returns the tip text for the filter layer property.
     *
     * @return the tip text
     */
    public String filterLayerTipText() {
        return "The name of the layer to obtain the filtered data from.";
    }

    /**
     * Initializes the model via the superclass, then locates the filter layer and
     * its (single) output variable.
     *
     * @throws IllegalStateException if no layer name is set, the layer cannot be
     *                               found, or it does not have exactly one output
     * @throws Exception             if the superclass initialization fails
     */
    @Override // adams.ml.cntk.AbstractCNTKModelWrapper
    public void initModel() throws Exception {
        if (this.m_FilterLayer == null || this.m_FilterLayer.isEmpty()) {
            throw new IllegalStateException("No filter layer defined!");
        }
        super.initModel();
        this.m_Layer = this.m_Model.findByName(this.m_FilterLayer);
        // CNTK's FindByName returns nullptr (null in Java) when no function has that name
        if (this.m_Layer == null) {
            throw new IllegalStateException("Filter layer '" + this.m_FilterLayer + "' not found in model!");
        }
        long numOutputs = this.m_Layer.getOutputs().size();
        if (numOutputs != 1) {
            throw new IllegalStateException(
                "Expected exactly one output in layer '" + this.m_FilterLayer + "', found " + numOutputs + "!");
        }
        this.m_OutputVar = (Variable) this.m_Layer.getOutputs().get(0);
    }

    /**
     * Pushes the values through the model and returns the output of the filter layer.
     * The flat input array is distributed over the model's inputs according to the
     * configured input ranges. All native CNTK objects created here are released,
     * even if evaluation fails.
     *
     * @param fArr the flat input values
     * @return the values of the first output sequence of the filter layer
     * @throws IllegalStateException if an input receives no values or the layer
     *                               produces no output
     * @throws Exception             if the model has not been initialized
     */
    public float[] filter(float[] fArr) throws Exception {
        if (!isInitialized()) {
            throw new Exception("Model not initialized!");
        }
        initRanges(fArr.length);

        HashMap<String, FloatVector> values = new HashMap<>();
        UnorderedMapVariableValuePtr inputs = new UnorderedMapVariableValuePtr();
        UnorderedMapVariableValuePtr outputs = new UnorderedMapVariableValuePtr();
        FloatVectorVector buffer = new FloatVectorVector();
        FloatVector first = null;
        try {
            splitInputs(fArr, values);
            for (String name : this.m_Names) {
                inputs.add(this.m_InputVars.get(name), createInputValue(name, values.get(name)));
            }
            // a null Value lets CNTK allocate the output storage itself
            outputs.add(this.m_OutputVar, (Value) null);
            this.m_Model.evaluate(inputs, outputs, this.m_Device);
            outputs.getitem(this.m_OutputVar).copyVariableValueToFloat(this.m_OutputVar, buffer);
            if (buffer.size() == 0) {
                throw new IllegalStateException("No output generated by layer '" + this.m_FilterLayer + "'!");
            }
            first = buffer.get(0);
            float[] result = new float[(int) first.size()];
            for (int i = 0; i < result.length; i++) {
                result[i] = first.get(i);
            }
            return result;
        } finally {
            // release native memory deterministically instead of waiting for finalizers
            if (first != null) {
                first.delete();
            }
            buffer.delete();
            for (FloatVector vector : values.values()) {
                vector.delete();
            }
            inputs.delete();
            outputs.delete();
        }
    }

    /**
     * Lazily builds the cached index sets of the inputs, using the length of the
     * array seen on the first call. The cache is only assigned once fully built, so a
     * failure does not leave a partial cache behind.
     * NOTE(review): presumably reset() clears m_Ranges; arrays of a different length
     * in later calls reuse the cached ranges.
     *
     * @param numValues the number of values in the flat input array
     */
    private void initRanges(int numValues) {
        if (this.m_Ranges != null) {
            return;
        }
        HashMap<String, TIntHashSet> ranges = new HashMap<>();
        for (int i = 0; i < this.m_Inputs.length; i++) {
            this.m_Inputs[i].setMax(numValues + 1);
            ranges.put(this.m_InputNames[i].getValue(), new TIntHashSet(this.m_Inputs[i].getIntIndices()));
        }
        this.m_Ranges = ranges;
    }

    /**
     * Distributes the flat values over the named inputs according to the cached
     * ranges. The target map is filled incrementally so the caller can release
     * any vectors created so far if something fails.
     *
     * @param data   the flat input values
     * @param target receives one vector per input name that received values
     */
    private void splitInputs(float[] data, HashMap<String, FloatVector> target) {
        for (int i = 0; i < data.length; i++) {
            for (String name : this.m_Names) {
                TIntHashSet range = this.m_Ranges.get(name);
                if (range == null || !range.contains(i)) {
                    continue;
                }
                FloatVector vector = target.get(name);
                if (vector == null) {
                    vector = new FloatVector();
                    target.put(name, vector);
                }
                vector.add(data[i]);
            }
        }
    }

    /**
     * Wraps the values of one input as a single-sequence dense CNTK value. The
     * temporary sequence container is released right away; CNTK copies the data
     * into the Value's own storage.
     *
     * @param name the input name
     * @param data the values for this input, may be null if none were assigned
     * @return the dense value
     * @throws IllegalStateException if no values were assigned to the input
     */
    private Value createInputValue(String name, FloatVector data) {
        if (data == null) {
            throw new IllegalStateException("No values found for input '" + name + "'!");
        }
        FloatVectorVector sequences = new FloatVectorVector();
        try {
            sequences.add(data);
            return Value.createDenseFloat(this.m_InputShapes.get(name), sequences, this.m_Device);
        } finally {
            sequences.delete();
        }
    }
}
