/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.cntk;

import adams.core.License;
import adams.core.annotation.MixedCopyright;
import adams.ml.cntk.AbstractCNTKModelWrapper;
import com.microsoft.CNTK.DeviceDescriptor;
import com.microsoft.CNTK.FloatVector;
import com.microsoft.CNTK.FloatVectorVector;
import com.microsoft.CNTK.Function;
import com.microsoft.CNTK.NDShape;
import com.microsoft.CNTK.UnorderedMapVariableValuePtr;
import com.microsoft.CNTK.Value;
import com.microsoft.CNTK.Variable;
import gnu.trove.set.hash.TIntHashSet;
import java.util.HashMap;

@MixedCopyright(author="CNTK", copyright="Microsoft", license=License.MIT, url="https://github.com/Microsoft/CNTK/blob/v2.0/Tests/EndToEndTests/EvalClientTests/JavaEvalTest/src/Main.java", note="Original code based on CNTK example")
public class CNTKFilterWrapper
extends AbstractCNTKModelWrapper {
    private static final long serialVersionUID = -1508684329565658944L;

    /** the name of the layer to obtain the filtered data from */
    protected String m_FilterLayer;

    /** the layer looked up by name in {@link #initModel()} */
    protected transient Function m_Layer;

    /** the single output variable of {@link #m_Layer}, set in {@link #initModel()} */
    protected transient Variable m_OutputVar;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Encapsulates a CNTK model for filtering data.";
    }

    /**
     * Adds the options: the inherited ones plus the filter layer.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        // NOTE(review): the flag is "output-layer" while the property is "filterLayer";
        // left unchanged so that existing option strings keep parsing.
        this.m_OptionManager.add("output-layer", "filterLayer", (Object)this.getDefaultFilterLayer());
    }

    /**
     * Returns the default filter layer.
     *
     * @return the default: an empty string, i.e., no layer
     */
    protected String getDefaultFilterLayer() {
        return "";
    }

    /**
     * Sets the name of the layer to obtain the filtered data from, and calls {@link #reset()}.
     *
     * @param value the layer name
     */
    public void setFilterLayer(String value) {
        this.m_FilterLayer = value;
        this.reset();
    }

    /**
     * Returns the name of the layer to obtain the filtered data from.
     *
     * @return the layer name
     */
    public String getFilterLayer() {
        return this.m_FilterLayer;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String filterLayerTipText() {
        return "The name of the layer to obtain the filtered data from.";
    }

    /**
     * Initializes the model via the superclass, then looks up the filter layer by name
     * and stores its single output variable.
     *
     * @throws IllegalStateException if no filter layer is defined, the layer cannot be
     *                               found in the model, or it does not have exactly one output
     * @throws Exception if the superclass fails to initialize the model
     */
    @Override
    public void initModel() throws Exception {
        if ((this.m_FilterLayer == null) || this.m_FilterLayer.isEmpty()) {
            throw new IllegalStateException("No filter layer defined!");
        }
        super.initModel();
        this.m_Layer = this.m_Model.findByName(this.m_FilterLayer);
        // findByName yields no function when the name is unknown; fail with a clear message
        // rather than a NullPointerException on the next line
        if (this.m_Layer == null) {
            throw new IllegalStateException("Filter layer '" + this.m_FilterLayer + "' not found in model!");
        }
        long numOutputs = this.m_Layer.getOutputs().size();
        if (numOutputs == 0) {
            throw new IllegalStateException("No output in layer '" + this.m_FilterLayer + "'!");
        }
        if (numOutputs > 1) {
            throw new IllegalStateException("More than one output in layer '" + this.m_FilterLayer + "'!");
        }
        this.m_OutputVar = (Variable)this.m_Layer.getOutputs().get(0);
    }

    /**
     * Pushes the input through the model and returns the values produced by the
     * filter layer's output variable.
     * <p>
     * Each element of {@code input} is assigned to the first named model input whose
     * configured index range contains that element's position. All native CNTK objects
     * allocated here are released, even when evaluation fails.
     *
     * @param input the flat input values
     * @return the values of the first output sequence of the filter layer
     * @throws Exception if the model is not initialized, a model input receives no values,
     *                   or CNTK evaluation fails
     */
    public float[] filter(float[] input) throws Exception {
        if (!this.isInitialized()) {
            throw new Exception("Model not initialized!");
        }
        this.initRanges(input.length);

        HashMap<String, FloatVector> floatVecs = new HashMap<String, FloatVector>();
        HashMap<String, FloatVectorVector> floatVecVecs = new HashMap<String, FloatVectorVector>();
        UnorderedMapVariableValuePtr inputDataMap = new UnorderedMapVariableValuePtr();
        UnorderedMapVariableValuePtr outputDataMap = new UnorderedMapVariableValuePtr();
        FloatVectorVector outputBuffer = new FloatVectorVector();
        try {
            this.collectInputs(input, floatVecs);

            for (String name : this.m_Names) {
                FloatVector floatVec = floatVecs.get(name);
                // passing null into the native vector would only produce an obscure NPE
                if (floatVec == null) {
                    throw new IllegalStateException("No input data for input '" + name + "'!");
                }
                FloatVectorVector floatVecVec = new FloatVectorVector();
                // register before use so that it is released in the finally block
                floatVecVecs.put(name, floatVecVec);
                floatVecVec.add(floatVec);
                Value inputVal = Value.createDenseFloat((NDShape)((NDShape)this.m_InputShapes.get(name)), (FloatVectorVector)floatVecVec, (DeviceDescriptor)this.m_Device);
                inputDataMap.add((Variable)this.m_InputVars.get(name), inputVal);
            }

            // a null value asks CNTK to allocate the output value during evaluation
            outputDataMap.add(this.m_OutputVar, null);
            this.m_Model.evaluate(inputDataMap, outputDataMap, this.m_Device);
            outputDataMap.getitem(this.m_OutputVar).copyVariableValueToFloat(this.m_OutputVar, outputBuffer);

            FloatVector results = outputBuffer.get(0);
            try {
                float[] result = new float[(int)results.size()];
                for (int i = 0; i < result.length; ++i) {
                    result[i] = results.get(i);
                }
                return result;
            }
            finally {
                // release the element proxy before the container it was obtained from
                results.delete();
            }
        }
        finally {
            for (FloatVector floatVec : floatVecs.values()) {
                floatVec.delete();
            }
            for (FloatVectorVector floatVecVec : floatVecVecs.values()) {
                floatVecVec.delete();
            }
            outputBuffer.delete();
            inputDataMap.delete();
            outputDataMap.delete();
        }
    }

    /**
     * Lazily builds, on the first call only, the map from input name to the set of
     * input-array indices assigned to that input. Later calls reuse the cached map.
     *
     * @param numValues the length of the input array seen on the first call
     */
    private void initRanges(int numValues) {
        if (this.m_Ranges != null) {
            return;
        }
        this.m_Ranges = new HashMap<>();
        for (int i = 0; i < this.m_Inputs.length; ++i) {
            // NOTE(review): the upper bound is the input length + 1 -- confirm this is intended
            this.m_Inputs[i].setMax(numValues + 1);
            this.m_Ranges.put(this.m_InputNames[i].getValue(), new TIntHashSet(this.m_Inputs[i].getIntIndices()));
        }
    }

    /**
     * Distributes the input values into one vector per input name. Each position goes to
     * the first name, in {@code m_Names} order, whose range contains it. Positions not
     * covered by any range are dropped.
     *
     * @param input     the flat input values
     * @param floatVecs receives the newly allocated per-name vectors
     */
    private void collectInputs(float[] input, HashMap<String, FloatVector> floatVecs) {
        for (int i = 0; i < input.length; ++i) {
            for (String name : this.m_Names) {
                TIntHashSet range = (TIntHashSet)this.m_Ranges.get(name);
                if ((range == null) || !range.contains(i)) {
                    continue;
                }
                FloatVector floatVec = floatVecs.get(name);
                if (floatVec == null) {
                    floatVec = new FloatVector();
                    floatVecs.put(name, floatVec);
                }
                floatVec.add(input[i]);
                // only the first matching input receives this value
                break;
            }
        }
    }
}

