/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.autodiff.samediff.config.SDValueType;
import org.nd4j.autodiff.samediff.internal.AbstractDependencyTracker;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.autodiff.samediff.internal.memory.HashDependencyTracker;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.imports.VariableUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.custom.Invoke;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.CreateView;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRemove;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySplit;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.wstx.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InferenceSession
extends AbstractSession<INDArray, Pair<SameDiffOp, OpContext>> {
    private static final Logger log = LoggerFactory.getLogger(InferenceSession.class);
    // Guidance text for placeholder workspace-validation errors; the same text is inlined in the messages built by preprocessPlaceholders.
    private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\nAlternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
    // Name suffix identifying a Keras learning-phase placeholder; preprocessPlaceholders fills it from the training flag when it is not supplied.
    protected static final String KERAS_TRAIN_TEST = "keras_learning_phase";
    // Ids of arrays already handed back to mmgr, checked before every release to avoid releasing the same array twice.
    protected Set<Long> freedArrays = new LinkedHashSet<Long>();
    // Memory manager used to allocate (placeholder casts, Keras scalars) and release arrays; an ArrayCacheMemoryMgr is created in the constructor.
    private SessionMemMgr mmgr;
    // Maps each tracked value to the dependencies (op executions, constants, variables, placeholders, exec-done, requested outputs)
    // that must all be satisfied before the value's arrays may be released.
    private AbstractDependencyTracker<SDValue, Dep> arrayUseTracker = new HashDependencyTracker<SDValue, Dep>();
    // NOTE(review): not referenced in the visible part of this file - presumably caches OpContexts keyed by op name; confirm.
    private Map<String, OpContext> opContexts = new HashMap<String, OpContext>();

    /**
     * Creates an inference session for the given graph. Arrays are allocated and released through a freshly
     * created {@link ArrayCacheMemoryMgr}.
     *
     * @param sameDiff the graph to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public InferenceSession(@NonNull SameDiff sameDiff) {
        super(sameDiff);
        java.util.Objects.requireNonNull(sameDiff, "sameDiff is marked non-null but is null");
        this.mmgr = new ArrayCacheMemoryMgr();
    }

    /**
     * Validates and prepares the placeholder arrays supplied for an execution.
     * <p>
     * This method does the following, in order:
     * <ul>
     *   <li>Resets the array-use tracker and registers every constant and variable array under a
     *       {@code ConstantDep}/{@code VariableDep}.</li>
     *   <li>Injects a scalar BOOL array, set from {@code at.operation().isTrainingPhase()}, for every graph input
     *       whose name ends with {@value #KERAS_TRAIN_TEST} and that the caller did not supply. The caller's map is
     *       copied, never mutated.</li>
     *   <li>Rejects workspace-attached placeholders whose (non-circular) workspace is closed or was reopened
     *       since the array was created.</li>
     *   <li>Casts placeholders whose data type differs from the declared variable type into an mmgr-allocated
     *       array.</li>
     * </ul>
     *
     * @param placeholders placeholder arrays keyed by variable name; may be null
     * @param at           current operation info; its training-phase flag feeds the Keras learning-phase workaround
     * @return the placeholders to use for execution, or {@code placeholders} itself when it is null/empty and
     *         nothing had to be injected
     * @throws IllegalStateException     if a key does not name a variable in the graph
     * @throws ND4JIllegalStateException if a placeholder array uses a leaked or outdated workspace pointer
     */
    @Override
    protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> placeholders, At at) {
        this.arrayUseTracker.clear();
        // Constants/variables are registered under their own dependency types (presumably never satisfied during
        // execution, so their arrays are not released by this session).
        for (SDVariable v : this.sameDiff.variables()) {
            if (v.getVariableType() == VariableType.CONSTANT) {
                this.arrayUseTracker.addDependency(SDValue.create(v.getArr()), new ConstantDep(v.name()));
            } else if (v.getVariableType() == VariableType.VARIABLE) {
                this.arrayUseTracker.addDependency(SDValue.create(v.getArr()), new VariableDep(v.name()));
            }
        }

        // Names of Keras learning-phase placeholders WE allocated; only these may be released at exec end.
        Set<String> injectedKerasPhs = new LinkedHashSet<String>();
        List<String> phs = this.sameDiff.inputs();
        if (phs != null) {
            for (String s : phs) {
                if (!s.endsWith(KERAS_TRAIN_TEST) || (placeholders != null && placeholders.containsKey(s))) {
                    continue;
                }
                INDArray scalar = this.mmgr.allocate(false, DataType.BOOL, new long[0]).assign(at.operation().isTrainingPhase());
                if (injectedKerasPhs.isEmpty()) {
                    // Copy once before the first insertion: the caller's map may be immutable or null.
                    placeholders = placeholders == null
                            ? new HashMap<String, INDArray>()
                            : new HashMap<String, INDArray>(placeholders);
                }
                placeholders.put(s, scalar);
                injectedKerasPhs.add(s);
            }
        }

        if (placeholders == null || placeholders.isEmpty()) {
            return placeholders;
        }

        Map<String, INDArray> out = new HashMap<String, INDArray>();
        for (Map.Entry<String, INDArray> e : placeholders.entrySet()) {
            String name = e.getKey();
            Preconditions.checkState(this.sameDiff.hasVariable(name), "Invalid placeholder passed for execution: No variable/placeholder with name %s exists", name);
            INDArray arr = e.getValue();

            // Workspace arrays are only usable while their (non-circular) workspace is open and in the same generation
            if (arr.isAttached()) {
                MemoryWorkspace ws = arr.data() == null ? null : arr.data().getParentWorkspace();
                if (ws != null && ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {
                    if (!ws.isScopeActive()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses leaked workspace pointer from workspace [" + ws.getId() + "]: Workspace the array was defined in is no longer open.\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                    if (ws.getGenerationId() != arr.data().getGenerationId()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses outdated workspace pointer from workspace [" + ws.getId() + "]: Workspace array was defined in has been closed and reopened at least once since array creation. Array WS iteration: " + arr.data().getGenerationId() + ". Workspace current iteration: " + ws.getGenerationId() + "\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                }
            }

            DataType dt = this.sameDiff.getVariable(name).dataType();
            if (injectedKerasPhs.contains(name)) {
                // Allocated above through mmgr: release once execution is done.
                this.arrayUseTracker.addDependency(SDValue.create(arr), new ExecDoneDep());
            } else if (arr.dataType() == dt) {
                // Caller-owned array: tracked under a PlaceholderDep rather than an exec-done dependency.
                this.arrayUseTracker.addDependency(SDValue.create(arr), new PlaceholderDep(name));
            } else {
                // Cast into an mmgr-allocated array. Track the CAST for release at exec end - never the
                // caller's original array, which mmgr did not allocate.
                INDArray cast = this.mmgr.allocate(false, dt, arr.shape());
                cast.assign(arr);
                this.arrayUseTracker.addDependency(SDValue.create(cast), new ExecDoneDep());
                arr = cast;
            }
            out.put(name, arr);
        }
        return out;
    }

    /**
     * Performs end-of-execution cleanup of tracked arrays.
     * <p>
     * Every op step newly completed in the execution dependency tracker is marked as a satisfied
     * {@code OpDep}, followed by the global {@code ExecDoneDep}. Every tracked value whose dependencies are
     * now all satisfied is released to the memory manager when caching is enabled. Each array id is
     * released at most once, and null arrays are skipped.
     *
     * @param output the output values of the execution
     * @return {@code output}, unchanged
     */
    @Override
    protected Map<String, SDValue> postProcessOutputValues(Map<String, SDValue> output) {
        if (this.dt.hasNewAllSatisfied()) {
            List<AbstractSession.ExecStep> execSteps = this.dt.getNewAllSatisfiedList();
            for (AbstractSession.ExecStep es : execSteps) {
                if (es.getType() != AbstractSession.ExecType.OP) {
                    continue;
                }
                FrameIter fi = es.getFrameIter();
                this.arrayUseTracker.markSatisfied(new OpDep(es.getName(), fi.getFrame(), fi.getIteration(), fi.getParentFrame()), true);
            }
        }
        this.arrayUseTracker.markSatisfied(new ExecDoneDep(), true);
        if (!this.arrayUseTracker.hasNewAllSatisfied()) {
            return output;
        }

        // Always drain the "newly satisfied" list, even if nothing will be released.
        List<SDValue> releasable = this.arrayUseTracker.getNewAllSatisfiedList();
        if (!this.sameDiff.isEnableCache()) {
            return output;
        }
        for (SDValue value : releasable) {
            List<INDArray> arrays;
            switch (value.getSdValueType()) {
                case TENSOR:
                    arrays = Collections.singletonList(value.getTensorValue());
                    break;
                case LIST:
                    arrays = value.getListValue();
                    break;
                default:
                    arrays = Collections.emptyList();
                    break;
            }
            for (INDArray arr : arrays) {
                if (arr == null || this.freedArrays.contains(arr.getId())) {
                    continue;
                }
                this.mmgr.release(arr);
                this.freedArrays.add(arr.getId());
            }
        }
        return output;
    }

    /**
     * Hook for post-processing the final {@link INDArray} outputs. This session applies no
     * post-processing here; release of tracked arrays happens in {@code postProcessOutputValues}.
     *
     * @param output the output arrays
     * @return {@code output}, unchanged
     */
    @Override
    protected Map<String, INDArray> postProcessOutput(Map<String, INDArray> output) {
        return output;
    }

    /**
     * Executes one op and performs the bookkeeping around it.
     * <p>
     * The steps are:
     * <ol>
     *   <li>Notify active listeners ({@code preOpExecution}).</li>
     *   <li>Execute via {@link #doExec}.</li>
     *   <li>Optionally trace the output array ids.</li>
     *   <li>Notify active listeners ({@code opExecution}, {@code activationAvailable}).</li>
     *   <li>Clear the op's arrays and purge its {@link OpContext}.</li>
     *   <li>Register each output with the array-use tracker for every downstream consumer.</li>
     *   <li>Release tracked arrays whose dependencies are now all satisfied.</li>
     * </ol>
     *
     * @param opPair            the op to execute and its (possibly null) op context
     * @param outputFrameIter   frame/iteration the outputs belong to; also set on {@code at}
     * @param opInputs          op-output inputs of this op
     * @param allIterInputs     inputs available to all iterations of the current frame
     * @param constAndPhInputs  names of constant/placeholder inputs
     * @param listeners         listeners to notify; may be null or empty
     * @param at                listener position information (mutated: its frame/iteration is set)
     * @param batch             current batch, forwarded to listeners
     * @param allReqVariables   requested output variables; outputs in the outer frame are tracked with a {@code ReqOutputDep}
     * @param otherPlaceHolders additional placeholder values forwarded to {@link #doExec}
     * @return the op's execution result
     */
    @Override
    public ExecutionResult getOutputs(Pair<SameDiffOp, OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String, SDValue> otherPlaceHolders) {
        SameDiffOp op = opPair.getFirst();
        OpContext opContext = opPair.getSecond();
        at.setFrameIter(outputFrameIter);
        boolean hasListeners = listeners != null && !listeners.isEmpty();
        if (hasListeners) {
            SameDiffOp sdOp = this.sameDiff.getOps().get(op.getOp().getOwnName());
            for (Listener l : listeners) {
                if (l.isActive(at.operation())) {
                    l.preOpExecution(this.sameDiff, at, sdOp, opContext);
                }
            }
        }
        if (this.sameDiff.isDebugMode()) {
            log.info("Executing samediff op: " + op.getName());
        }
        ExecutionResult out = this.doExec(op.getOp(), opContext, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, otherPlaceHolders);
        List<String> opOutNames = op.getOutputsOfOp();
        if (log.isTraceEnabled()) {
            log.trace(describeOutputsForTrace(op, outputFrameIter, out, opOutNames));
        }
        if (hasListeners) {
            notifyListenersAfterExec(op, opContext, out, opOutNames, listeners, at, batch);
        }

        op.getOp().clearArrays();
        if (opContext != null) {
            opContext.purge();
        }

        registerOutputDependencies(op, out, outputFrameIter, allReqVariables);

        // This op has now executed for this frame/iteration: satisfy dependencies of its inputs on it
        this.arrayUseTracker.markSatisfied(new OpDep(op.getName(), outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame()), true);
        if (this.arrayUseTracker.hasNewAllSatisfied()) {
            List<SDValue> canClose = this.arrayUseTracker.getNewAllSatisfiedList();
            // NOTE(review): Switch outputs are the input values themselves (see doExec); releasing is skipped for them
            boolean isSwitch = op.getOp() instanceof Switch;
            for (SDValue value : canClose) {
                if (log.isTraceEnabled() && value.getSdValueType() == SDValueType.TENSOR && value.getTensorValue() != null) {
                    INDArray arr = value.getTensorValue();
                    log.trace("Closing array... id={}, {}", arr.getId(), arr.shapeInfoToString());
                }
                if (isSwitch) {
                    continue;
                }
                if (value.getSdValueType() == SDValueType.TENSOR) {
                    releaseArrayOnce(value.getTensorValue());
                } else if (value.getSdValueType() == SDValueType.LIST) {
                    for (INDArray arr : value.getListValue()) {
                        releaseArrayOnce(arr);
                    }
                }
            }
        }
        return out;
    }

    /**
     * Builds the trace-level description of an op's outputs (output index, name and array id(s)).
     * Null arrays - e.g. the inactive entry of a Switch output list - are rendered as {@code null}.
     */
    private static String describeOutputsForTrace(SameDiffOp op, FrameIter frameIter, ExecutionResult out, List<String> outNames) {
        StringBuilder sb = new StringBuilder();
        sb.append(op.getName()).append(" - ").append(frameIter).append(" outputs: ");
        for (int i = 0; i < out.numResults(); ++i) {
            if (i > 0) {
                sb.append(", ");
            }
            if (out.hasSingle()) {
                INDArray arr = out.resultAt(i);
                sb.append("(").append(i).append(" - ").append(outNames.get(i)).append(" = ").append(arr == null ? null : Long.valueOf(arr.getId())).append(")");
                continue;
            }
            if (!out.hasValues()) {
                continue;
            }
            SDValue value = out.valueWithKeyAtIndex(i, false);
            String ids = null;
            if (value != null) {
                if (value.getSdValueType() == SDValueType.LIST) {
                    ids = value.getListValue().stream()
                            .map(a -> a == null ? "null" : String.valueOf(a.getId()))
                            .collect(Collectors.joining(","));
                } else {
                    INDArray tensor = value.getTensorValue();
                    ids = tensor == null ? "null" : String.valueOf(tensor.getId());
                }
            }
            sb.append("(").append(i).append(" - ").append(outNames.get(i)).append(" = ").append(ids).append(")");
        }
        return sb.toString();
    }

    /** Fires {@code opExecution} and {@code activationAvailable} on every listener active for {@code at}. */
    private void notifyListenersAfterExec(SameDiffOp op, OpContext opContext, ExecutionResult out, List<String> opOutNames, List<Listener> listeners, At at, MultiDataSet batch) {
        Map<String, INDArray> namedOuts = null;
        for (Listener l : listeners) {
            if (!l.isActive(at.operation())) {
                continue;
            }
            if (namedOuts == null) {
                // Built lazily: only needed if at least one listener is active
                Map<String, INDArray> builder = new HashMap<String, INDArray>();
                for (int i = 0; i < out.numResults(); ++i) {
                    builder.put(opOutNames.get(i), out.resultAt(i));
                }
                namedOuts = Collections.unmodifiableMap(builder);
            }
            l.opExecution(this.sameDiff, at, batch, op, opContext, out.outputsToArray(opOutNames));
            for (Map.Entry<String, INDArray> e : namedOuts.entrySet()) {
                l.activationAvailable(this.sameDiff, at, batch, op, e.getKey(), e.getValue());
            }
        }
    }

    /**
     * For each produced output of {@code op}, this method does three things:
     * <ul>
     *   <li>Registers a dependency for every consumer op inside the current subgraph.</li>
     *   <li>Pins requested outer-frame outputs with a {@code ReqOutputDep}.</li>
     *   <li>Immediately releases outputs that have no consumer at all and are not otherwise tracked.</li>
     * </ul>
     */
    private void registerOutputDependencies(SameDiffOp op, ExecutionResult out, FrameIter outputFrameIter, Set<String> allReqVariables) {
        SameDiffOp o = this.sameDiff.getOps().get(op.getName());
        List<String> outVarNames = o.getOutputsOfOp();
        for (int i = 0; i < out.numResults(); ++i) {
            // Skip outputs that were not produced; a Switch only yields one of its two outputs
            if (out.hasSingle() && out.resultAt(i) == null
                    || out.hasValues() && out.valueWithKeyAtIndex(i, false) == null && o.getOp() instanceof Switch) {
                continue;
            }
            String name = outVarNames.get(i);
            Variable v = this.sameDiff.getVariables().get(name);
            List<String> inputsForOps = v.getInputsForOp();
            if (inputsForOps != null) {
                for (String opName : inputsForOps) {
                    // Ops outside the subgraph never execute, so a dependency on them would never be satisfied
                    if (this.subgraphOps.contains(opName)) {
                        this.addToArrayTracker(out, i, consumerDependency(opName, outputFrameIter));
                    }
                }
            }
            if ("main".equals(outputFrameIter.getFrame()) && allReqVariables.contains(name)) {
                this.addToArrayTracker(out, i, new ReqOutputDep(name));
                continue;
            }
            if (inputsForOps != null && !inputsForOps.isEmpty()) {
                continue;
            }
            // No consumer at all: release right away unless something already tracks the value
            if (out.getValueOutputs() != null && !this.arrayUseTracker.hasDependency(out.valueWithKeyAtIndex(i, false))) {
                SDValue value = out.valueWithKeyAtIndex(i, false);
                releaseUnconsumedOutput(value == null ? null : value.getTensorValue(), o.getName());
            } else if (out.getOutputs() != null && !this.arrayUseTracker.hasDependency(SDValue.create(out.resultAt(i)))) {
                releaseUnconsumedOutput(out.resultAt(i), o.getName());
            }
        }
    }

    /** Returns the dependency that must be satisfied before an output consumed by {@code consumerOpName} may be released. */
    private Dep consumerDependency(String consumerOpName, FrameIter outputFrameIter) {
        DifferentialFunction consumer = this.sameDiff.getOps().get(consumerOpName).getOp();
        if (consumer instanceof Enter) {
            Enter e = (Enter) consumer;
            // A constant Enter is needed for the whole frame (all iterations): keep it until execution is done
            if (e.isConstant()) {
                return new ExecDoneDep();
            }
            return new OpDep(consumerOpName, e.getFrameName(), 0, outputFrameIter);
        }
        if (consumer instanceof NextIteration) {
            // Consumed by the NEXT iteration of the current frame
            return new OpDep(consumerOpName, outputFrameIter.getFrame(), outputFrameIter.getIteration() + 1, outputFrameIter.getParentFrame());
        }
        if (consumer instanceof Exit) {
            // Exit executes in the parent frame
            FrameIter fi = outputFrameIter.getParentFrame();
            return new OpDep(consumerOpName, fi.getFrame(), fi.getIteration(), fi.getParentFrame());
        }
        return new OpDep(consumerOpName, outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame());
    }

    /** Traces and releases an output array that no op consumes (no-op for null). */
    private void releaseUnconsumedOutput(INDArray arr, String producerOpName) {
        if (arr == null) {
            return;
        }
        if (log.isTraceEnabled()) {
            log.trace("Found array id {} (output of {}) not required anywhere, deallocating", arr.getId(), producerOpName);
        }
        releaseArrayOnce(arr);
    }

    /**
     * Releases {@code arr} to the memory manager at most once. It is skipped when null, already released,
     * or when caching is disabled - the same gate used by every other release path in this session.
     */
    private void releaseArrayOnce(INDArray arr) {
        if (arr == null || this.freedArrays.contains(arr.getId()) || !this.sameDiff.isEnableCache()) {
            return;
        }
        this.mmgr.release(arr);
        this.freedArrays.add(arr.getId());
    }

    /**
     * Records that the {@code i}-th output of {@code out} must stay alive until dependency {@code d} is satisfied.
     * Single-array results are wrapped in an {@link SDValue}; value-based results are tracked as stored.
     *
     * @param out execution result holding the output
     * @param i   output index
     * @param d   dependency to register for that output
     */
    private void addToArrayTracker(ExecutionResult out, int i, Dep d) {
        SDValue tracked = out.hasSingle()
                ? SDValue.create(out.resultOrValueAt(i, false))
                : out.valueWithKeyAtIndex(i, false);
        this.arrayUseTracker.addDependency(tracked, d);
    }

    public ExecutionResult doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, SDValue> otherPlaceHolders) {
        boolean constPhInput;
        int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size()) + (allIterInputs == null ? 0 : allIterInputs.size());
        boolean bl = constPhInput = !(opInputs != null && opInputs.size() != 0 || allIterInputs != null && allIterInputs.size() != 0);
        if (op instanceof Identity) {
            Identity i = (Identity)op;
            String[] argNames = i.argNames();
            Preconditions.checkState((argNames.length == 1 ? 1 : 0) != 0, (String)"Expected only 1 arg name in identity op, got %s", (Object)argNames);
            AbstractSession.VarId vid = outputFrameIter.toVarId(argNames[0]);
            SDValue sDValue = this.getSdValue(vid);
            return ExecutionResult.createValue(vid.getVariable(), sDValue);
        }
        if (op instanceof Switch) {
            INDArray predicate;
            Switch s = (Switch)op;
            String[] argNames = s.argNames();
            AbstractSession.VarId vidPredicate = outputFrameIter.toVarId(argNames[1]);
            SDValue sDValue = this.getSdValue(vidPredicate);
            INDArray iNDArray = predicate = sDValue.getSdValueType() == SDValueType.LIST ? sDValue.getListValue().get(0) : sDValue.getTensorValue();
            if (predicate != null && predicate.isEmpty()) {
                predicate = Nd4j.scalar(false);
            }
            if (predicate == null && !constAndPhInputs.isEmpty() && constAndPhInputs.contains(argNames[1])) {
                predicate = this.getTensorFromOutputs(new AbstractSession.VarId(argNames[1], "main", 0, null));
            }
            Preconditions.checkNotNull((Object)predicate, (String)"Error during graph execution: Predicate array was null. VarId=%s", (Object)vidPredicate);
            Preconditions.checkState((predicate.isScalar() && predicate.dataType() == DataType.BOOL ? 1 : 0) != 0, (String)"Expected boolean predicate: got %ndSInfo", (Object)predicate);
            AbstractSession.VarId vid = outputFrameIter.toVarId(argNames[0]);
            SDValue sdValue = this.getSdValue(vid);
            LinkedHashMap<String, SDValue> values = new LinkedHashMap<String, SDValue>();
            ExecutionResult.ExecutionResultBuilder executionResultBuilder = ExecutionResult.builder().valueOutputs(values);
            if (predicate.getDouble(0L) == 0.0) {
                if (vid.getVariable().equals(vidPredicate.getVariable())) {
                    SDValue sdValue1 = SDValue.create(Arrays.asList(sdValue.getTensorValue(), null));
                    values.put(vidPredicate.getVariable(), sdValue1);
                    this.putNodeValue(sdValue1, vid);
                    AbstractSession.VarId varId1 = new AbstractSession.VarId(vid.getVariable() + ":1", vid.getFrame(), vid.getIteration(), vid.getParentFrame());
                    this.putNodeValue(sdValue1, varId1);
                } else {
                    values.put(vid.getVariable(), sdValue);
                    values.put(vidPredicate.getVariable(), null);
                }
            } else if (vid.getVariable().equals(vidPredicate.getVariable())) {
                SDValue sdValue1 = SDValue.create(Arrays.asList(null, sdValue.getTensorValue()));
                values.put(vidPredicate.getVariable(), sdValue1);
                values.put(vidPredicate.getVariable() + ":1", sdValue1);
            } else {
                values.put(vid.getVariable(), null);
                values.put(vidPredicate.getVariable(), sdValue);
            }
            return executionResultBuilder.build();
        }
        if (op instanceof Enter) {
            Enter e = (Enter)op;
            String[] input = e.argNames();
            Preconditions.checkState((input.length == 1 ? 1 : 0) != 0, (String)"Expected only 1 arg name for enter op: got %s", (Object)input);
            Preconditions.checkState((totalInputs == 1 ? 1 : 0) != 0, (String)"Expected exactly 1 op input for Enter op \"%s\", got %s+%s", (Object)e.getOwnName(), opInputs, constAndPhInputs);
            AbstractSession.VarId inputVarId = constPhInput ? new AbstractSession.VarId(constAndPhInputs.iterator().next(), "main", 0, null) : (allIterInputs != null && allIterInputs.size() > 0 ? allIterInputs.iterator().next() : opInputs.iterator().next());
            inputVarId.setVariable(VariableUtils.stripVarSuffix(inputVarId.getVariable()));
            if (this.nodeValueOutputs.containsKey(inputVarId)) {
                SDValue sDValue = this.getSdValue(inputVarId);
                if (sDValue != null && sDValue.getSdValueType() == SDValueType.LIST) {
                    return ExecutionResult.createValue(inputVarId.getVariable(), sDValue);
                }
                if (sDValue != null && sDValue.getSdValueType() == SDValueType.TENSOR) {
                    INDArray inArr = this.getTensorFromOutputs(inputVarId);
                    if (inArr == null) {
                        Preconditions.throwStateEx((String)"Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", (Object[])new Object[]{op.getOwnName(), this.sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()});
                    }
                    return ExecutionResult.createFrom(Arrays.asList(inputVarId.getVariable()), new INDArray[]{inArr});
                }
                throw new IllegalStateException("Illegal value type " + sDValue.getSdValueType() + " for input " + inputVarId);
            }
            INDArray iNDArray = this.getTensorFromOutputs(inputVarId);
            if (iNDArray == null) {
                Preconditions.throwStateEx((String)"Could not find array for Enter operation %s with output %s (frame=%s, iteration=%s)", (Object[])new Object[]{op.getOwnName(), this.sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()});
            }
            return ExecutionResult.createFrom(Arrays.asList(inputVarId.getVariable()), new INDArray[]{iNDArray});
        }
        if (op instanceof Exit) {
            AbstractSession.VarId inputVarId = constPhInput ? new AbstractSession.VarId(constAndPhInputs.iterator().next(), "main", 0, null) : (allIterInputs != null && allIterInputs.size() > 0 ? allIterInputs.iterator().next() : opInputs.iterator().next());
            SDValue sdValue = this.getSdValue(inputVarId);
            return ExecutionResult.createValue(inputVarId.getVariable(), sdValue);
        }
        if (op instanceof NextIteration) {
            Preconditions.checkState((totalInputs == 1 ? 1 : 0) != 0, (String)"Expected exactly 1 op input for NextIteration: got %s+%s", opInputs, constAndPhInputs);
            AbstractSession.VarId in = allIterInputs != null && !allIterInputs.isEmpty() ? allIterInputs.iterator().next() : opInputs.iterator().next();
            Preconditions.checkState((boolean)outputFrameIter.getFrame().equals(in.getFrame()), (String)"Expected same frame for NextIteration input vs. output: got input %s, output %s", (Object)in, (Object)outputFrameIter);
            Preconditions.checkState((outputFrameIter.getIteration() == in.getIteration() + 1 ? 1 : 0) != 0, (String)"Expected output iteration for NextIteration output to be 1 larger than the input iteration. Input: %s, output %s", (Object)in, (Object)outputFrameIter);
            if (this.nodeValueOutputs.containsKey(in) && this.getSdValue(in) != null) {
                SDValue value = this.getSdValue(in);
                if (value != null && value.getSdValueType() == SDValueType.LIST) {
                    return ExecutionResult.createValue(in.getVariable(), value);
                }
                if (value != null && value.getSdValueType() == SDValueType.TENSOR) {
                    INDArray inArr = this.getTensorFromOutputs(in);
                    if (inArr == null) {
                        Preconditions.throwStateEx((String)"Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", (Object[])new Object[]{op.getOwnName(), this.sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()});
                    }
                    return ExecutionResult.createFrom(Arrays.asList(in.getVariable()), new INDArray[]{inArr});
                }
                throw new IllegalStateException("Illegal value type " + value.getSdValueType() + " for input " + in);
            }
            INDArray inArr = this.getTensorFromOutputs(in);
            if (inArr == null) {
                Preconditions.throwStateEx((String)"Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", (Object[])new Object[]{op.getOwnName(), this.sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()});
            }
            return ExecutionResult.createFrom(Arrays.asList(in.getVariable()), new INDArray[]{inArr});
        }
        if (op instanceof Merge) {
            AbstractSession.VarId vid;
            Merge m = (Merge)op;
            Object[] in = this.sameDiff.getInputsForOp(op);
            AbstractSession.VarId firstInput = outputFrameIter.toVarId(in[0]);
            AbstractSession.VarId varId = outputFrameIter.toVarId(in[1]);
            SDValue firstValue = this.getSdValue(firstInput);
            SDValue secondValue = this.getSdValue(varId);
            String s = secondValue != null ? in[1] : in[0];
            AbstractSession.VarId varId2 = vid = secondValue != null ? varId : firstInput;
            if (firstValue == null && secondValue == null) {
                throw new IllegalStateException("Merge node " + m.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(in) + ") - should not be executed at this point");
            }
            log.trace("Returning input \"{}\" for merge node \"{}\"", (Object)m.getOwnName(), (Object)s);
            SDValue value = this.getSdValue(vid);
            if (value.getSdValueType() == SDValueType.LIST) {
                return ExecutionResult.createValue(vid.getVariable(), this.getSdValue(vid));
            }
            if (value.getSdValueType() == SDValueType.TENSOR) {
                INDArray inArr = this.getTensorFromOutputs(vid);
                if (inArr == null) {
                    Preconditions.throwStateEx((String)"Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", (Object[])new Object[]{op.getOwnName(), this.sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()});
                }
                return ExecutionResult.createFrom(Arrays.asList(vid.getVariable()), new INDArray[]{inArr});
            }
            throw new IllegalStateException("Illegal value type " + value.getSdValueType() + " for input " + (String[])in);
        }
        if (op instanceof LoopCond) {
            LoopCond lc = (LoopCond)op;
            String[] argNames = lc.argNames();
            Preconditions.checkState((argNames.length == 1 ? 1 : 0) != 0, (String)"Expected only 1 arg name in LoopCond op, got %s", (Object)argNames);
            AbstractSession.VarId vid = outputFrameIter.toVarId(argNames[0]);
            SDValue sDValue = this.getSdValue(vid);
            if (sDValue.getTensorValue() == null) {
                throw new IllegalStateException("Node value output at " + vid.getVariable() + " was not a boolean tensor!");
            }
            Preconditions.checkNotNull((Object)sDValue, (String)"Input to LoopCond op must not be null");
            Preconditions.checkState((sDValue.getTensorValue().isScalar() && sDValue.getTensorValue().dataType() == DataType.BOOL ? 1 : 0) != 0, (String)"LoopCond input must be a scalar boolean, got %ndShape");
            return ExecutionResult.createValue(vid.getVariable(), sDValue);
        }
        if (op instanceof BaseTensorOp) {
            return this.getOutputsHelperTensorArrayOps(op, outputFrameIter, opInputs, allIterInputs, otherPlaceHolders);
        }
        if (op instanceof Identity) {
            ArrayList<AbstractSession.VarId> orderedInputs = new ArrayList<AbstractSession.VarId>(opInputs);
            SDValue sdValue = this.getSdValue((AbstractSession.VarId)orderedInputs.get(0));
            return ExecutionResult.createValue(op.outputVariablesNames()[0], sdValue);
        }
        if (op instanceof Assign) {
            SDValue sdValue;
            ArrayList<AbstractSession.VarId> orderedInputs = new ArrayList<AbstractSession.VarId>(opInputs);
            if (orderedInputs.size() > 1) {
                sdValue = this.getSdValue((AbstractSession.VarId)orderedInputs.get(0));
                SDValue sdValue1 = this.getSdValue((AbstractSession.VarId)orderedInputs.get(1));
                switch (sdValue.getSdValueType()) {
                    case TENSOR: {
                        Assign assign = (Assign)op;
                        Nd4j.exec(assign, opContext);
                        return ExecutionResult.createFrom(assign, opContext);
                    }
                    case LIST: {
                        return ExecutionResult.createValue(op.outputVariablesNames()[0], sdValue1);
                    }
                }
            }
            sdValue = this.getSdValue((AbstractSession.VarId)orderedInputs.get(0));
            return ExecutionResult.createValue(op.outputVariablesNames()[0], sdValue);
        }
        if (op instanceof GradientBackwardsMarker) {
            INDArray out = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(Float.valueOf(1.0f));
            return ExecutionResult.createFrom(Arrays.asList("gradientbackwardsmarker"), new INDArray[]{out});
        }
        if (op instanceof CreateView) {
            LinkedHashMap<String, AbstractSession.VarId> inputVars = new LinkedHashMap<String, AbstractSession.VarId>();
            String[] argNames = op.argNames();
            for (AbstractSession.VarId varId : opInputs) {
                inputVars.put(varId.getVariable(), varId);
            }
            SDValue sdValue = this.getSdValue((AbstractSession.VarId)inputVars.get(argNames[0]));
            if (sdValue == null) {
                sdValue = SDValue.create(opContext.getInputArray(0));
            }
            INDArray[] iNDArrayArray = new INDArray[argNames.length - 1];
            for (int i = 1; i < argNames.length; ++i) {
                iNDArrayArray[i - 1] = this.getSdValue((AbstractSession.VarId)inputVars.get(argNames[i])).getTensorValue();
            }
            INDArray from = CreateView.createFrom(sdValue.getTensorValue(), iNDArrayArray);
            return ExecutionResult.createFrom(op.outputVariablesNames()[0], from);
        }
        if (op instanceof ExternalErrorsFunction) {
            ExternalErrorsFunction fn = (ExternalErrorsFunction)op;
            String n = fn.getGradPlaceholderName();
            INDArray arr = this.getTensorFromOutputs(new AbstractSession.VarId(n, "main", 0, null));
            Preconditions.checkState((arr != null ? 1 : 0) != 0, (String)"Could not find external errors placeholder array: %s", (Object)arr);
            INDArray iNDArray = this.mmgr.allocate(false, arr.dataType(), arr.shape());
            iNDArray.assign(arr);
            return ExecutionResult.createFrom(Arrays.asList(n), new INDArray[]{iNDArray});
        }
        if (op instanceof Invoke) {
            Invoke invoke = (Invoke)op;
            boolean hasValues = false;
            for (AbstractSession.VarId varId : opInputs) {
                if (!this.nodeValueOutputs.containsKey(varId)) continue;
                hasValues = true;
                break;
            }
            if (!hasValues) {
                for (Map.Entry entry : otherPlaceHolders.entrySet()) {
                    if (!constAndPhInputs.contains(entry.getKey())) continue;
                    hasValues = true;
                    break;
                }
            }
            LinkedHashMap<String, INDArray> inputs = new LinkedHashMap<String, INDArray>();
            LinkedHashMap<String, SDValue> linkedHashMap = new LinkedHashMap<String, SDValue>();
            if (!hasValues) {
                int currInput = 0;
                for (AbstractSession.VarId opInput : opInputs) {
                    inputs.put(opInput.getVariable(), opContext.getInputArray(currInput));
                    ++currInput;
                }
            } else {
                HashMap<String, AbstractSession.VarId> varIdsByVariable = new HashMap<String, AbstractSession.VarId>();
                for (AbstractSession.VarId opInput : opInputs) {
                    varIdsByVariable.put(opInput.getVariable(), opInput);
                }
                for (int i = 0; i < invoke.getInputVarNames().length; ++i) {
                    AbstractSession.VarId opInput;
                    opInput = (AbstractSession.VarId)varIdsByVariable.get(invoke.getInputVarNames()[i]);
                    if (constAndPhInputs.contains(invoke.getInputVarNames()[i])) {
                        if (otherPlaceHolders.containsKey(invoke.getInputVarNames()[i])) {
                            linkedHashMap.put(invoke.getInputVarNames()[i], otherPlaceHolders.get(invoke.getInputVarNames()[i]));
                            continue;
                        }
                        if (!inputs.containsKey(invoke.getInputVarNames()[i])) continue;
                        linkedHashMap.put(invoke.getInputVarNames()[i], SDValue.create((INDArray)inputs.get(invoke.getInputVarNames()[i])));
                        continue;
                    }
                    if (this.sameDiff.getArrForVarName(invoke.getInputVarNames()[i]) != null) {
                        linkedHashMap.put(invoke.getInputVarNames()[i], SDValue.create(this.sameDiff.getArrForVarName(invoke.getInputVarNames()[i])));
                        continue;
                    }
                    if (this.nodeValueOutputs.containsKey(opInput)) {
                        linkedHashMap.put(opInput.getVariable(), this.getSdValue(opInput));
                        continue;
                    }
                    linkedHashMap.put(opInput.getVariable(), SDValue.create(opContext.getInputArray(i)));
                }
            }
            if (linkedHashMap.size() + inputs.size() != op.args().length) {
                throw new IllegalArgumentException("Value inputs and inputs combined did not fulfill all arguments. Inputs were: " + Arrays.toString(op.argNames()) + " for op name " + op.getOwnName());
            }
            return Invoke.doInvoke(invoke, inputs, linkedHashMap);
        }
        if (op instanceof Assert) {
            boolean condition;
            Assert a = (Assert)op;
            boolean bl2 = condition = opContext.getInputArray(0).getDouble(0L) != 0.0;
            if (!condition) {
                INDArray iNDArray;
                String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
                if (a.numInputArguments() >= 3 && (iNDArray = opContext.getInputArray(2)) != null && iNDArray.dataType() == DataType.UTF8) {
                    s = s + ": " + iNDArray.getString(0L);
                }
                if (a.numInputArguments() >= 5) {
                    INDArray iNDArray2 = opContext.getInputArray(4);
                    s = s + "\n" + iNDArray2;
                }
                throw new IllegalStateException(s);
            }
            return ExecutionResult.createFrom(a, opContext);
        }
        if (op instanceof CustomOp) {
            CustomOp c = (CustomOp)((Object)op);
            Nd4j.exec(c, opContext);
            return ExecutionResult.createFrom((DifferentialFunction)((Object)c), opContext);
        }
        if (op instanceof Op) {
            Op o = (Op)((Object)op);
            Nd4j.exec(o, opContext);
            return ExecutionResult.createFrom((DifferentialFunction)((Object)o), opContext);
        }
        throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
    }

    /**
     * Looks up the value of the same variable, frame and parent frame one iteration before {@code varId}.
     *
     * @param varId the current variable/frame/iteration
     * @return the stored value for the previous iteration, or whatever the output map yields when absent
     */
    private SDValue getPreviousValue(AbstractSession.VarId varId) {
        final int oneIterationBack = 1;
        return getPreviousValue(varId, oneIterationBack);
    }

    /**
     * Looks up the value of the same variable, frame and parent frame {@code offset} iterations before {@code varId}.
     *
     * @param varId  the current variable/frame/iteration
     * @param offset how many iterations to step back
     * @return the value stored in {@code nodeValueOutputs} for the earlier iteration
     */
    private SDValue getPreviousValue(AbstractSession.VarId varId, int offset) {
        int earlierIteration = varId.getIteration() - offset;
        AbstractSession.VarId earlier = new AbstractSession.VarId(
                varId.getVariable(), varId.getFrame(), earlierIteration, varId.getParentFrame());
        return (SDValue) this.nodeValueOutputs.get(earlier);
    }

    /**
     * Fetches the recorded value for a variable at an explicit frame/iteration/parent-frame coordinate.
     *
     * @param var         variable name
     * @param frame       frame name
     * @param iteration   iteration within the frame
     * @param parentFrame parent frame (may be null for the outermost frame)
     * @return the value stored in {@code nodeValueOutputs} for that coordinate
     */
    private SDValue getValueAtIteration(String var, String frame, int iteration, FrameIter parentFrame) {
        return (SDValue) this.nodeValueOutputs.get(new AbstractSession.VarId(var, frame, iteration, parentFrame));
    }

    /**
     * Executes one of the TensorArray-family ops (create, read, write, size, concat, gather, scatter,
     * split, remove) against the list values tracked in {@code nodeValueOutputs}.
     *
     * @param op                the TensorArray op to execute
     * @param outputFrameIter   frame/iteration that the op's outputs belong to
     * @param opInputs          non-constant inputs for the current iteration (may be null)
     * @param allIterInputs     inputs available across all iterations (may be null)
     * @param otherPlaceHolders placeholder values supplied as {@link SDValue}s
     * @return the execution result holding the op's output
     * @throws IllegalStateException    if a required input or tensor array cannot be found, or the op type is unsupported
     * @throws IllegalArgumentException if an array has the wrong shape for the tensor array, or a scatter index is out of range
     */
    public ExecutionResult getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Map<String, SDValue> otherPlaceHolders) {
        if (op instanceof TensorArray) {
            AbstractSession.VarId vid = outputFrameIter.toVarId(op.outputVariable().name());
            // Already created for this frame/iteration: hand back the existing list instead of resetting it
            if (this.nodeValueOutputs.containsKey(vid)) {
                return ExecutionResult.createValue(vid.getVariable(), (SDValue) this.nodeValueOutputs.get(vid));
            }
            ArrayList<INDArray> createList = new ArrayList<INDArray>();
            if (op.args().length > 0) {
                // Pre-fill with empty (null) slots up to the requested initial size
                int initialSize = op.arg(0).getArr().getInt(0);
                for (int i = 0; i < initialSize; ++i) {
                    createList.add(null);
                }
            }
            SDValue listValue = SDValue.create(createList);
            this.putNodeValue(listValue, vid);
            return ExecutionResult.createValue(vid.getVariable(), listValue);
        }
        if (op instanceof TensorArrayRead) {
            SDVariable idxSDV = op.arg(1);
            INDArray idxArr = this.getArray(idxSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", (Object) idxArr);
            int i = idxArr.getInt(0);
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId v = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (v == null && allIterInputs != null) {
                v = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState(v != null, "Could not find input %s", (Object) inTensorArray.name());
            TensorArray tensorArray = TensorArray.getTensorArray(this.sameDiff, inTensorArray);
            List<INDArray> list;
            if (this.nodeValueOutputs.containsKey(v)) {
                list = this.getSdValue(v).getListValue();
            } else {
                // Not among this session's outputs: fall back to the arrays tracked for the tensor array's variable
                list = this.getTensorArraysInSession(tensorArray.getVar().name());
            }
            // Must be validated before the shape check below dereferences the list
            Preconditions.checkState(list != null, "Could not find TensorList for %s", (Object) v);
            if (tensorArray.args().length > 1) {
                long[] inputShapeArr = tensorArray.requiredShape();
                for (int j = 0; j < list.size(); ++j) {
                    INDArray element = list.get(j);
                    if (element != null && !Arrays.equals(inputShapeArr, element.shape())) {
                        throw new IllegalArgumentException("Element " + j + " of list " + v.getVariable() + " did not have correct shape of " + Arrays.toString(inputShapeArr) + " was shape " + Arrays.toString(element.shape()));
                    }
                }
            }
            Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", (Object) i, (Object) list.size(), (Object) v);
            INDArray out = list.get(i);
            if (log.isTraceEnabled()) {
                log.trace("Reading item at index " + i + " for list " + v + " with value " + out + " with list of " + list);
            }
            return ExecutionResult.createFrom(v.getVariable(), out);
        }
        if (op instanceof TensorArrayWrite) {
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            if (tArr == null && inTensorArray.getVariableType() == VariableType.PLACEHOLDER) {
                tArr = new AbstractSession.VarId(inTensorArray.name(), outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame());
            }
            Preconditions.checkState(tArr != null, "Could not find input %s", (Object) inTensorArray.name());
            String idxName = op.arg(1).name();
            SDVariable idxSDV = this.sameDiff.getVariable(idxName);
            INDArray idxArr = this.getArray(idxSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", (Object) idxArr);
            int idx = idxArr.getInt(0);
            String inName = op.arg(2).name();
            SDVariable inSDV = this.sameDiff.getVariable(inName);
            INDArray arr = this.getArray(inSDV, opInputs, allIterInputs);
            Preconditions.checkState(arr != null, "Could not find array for %s", (Object) inName);
            TensorArray tArrOp = TensorArray.getTensorArray(this.sameDiff, inTensorArray);
            // The list is written under the TensorArray op's output VarId in frame "main", iteration 0
            tArr = new AbstractSession.VarId(tArrOp.outputVariable().name(), "main", 0, null);
            if (tArrOp.args().length > 1) {
                long[] shape = tArrOp.arg(1).getArr().toLongVector();
                if (!Arrays.equals(arr.shape(), shape)) {
                    throw new IllegalArgumentException("Unable to write array of shape " + Arrays.toString(arr.shape()) + " must be " + Arrays.toString(shape) + " for op " + op.getOwnName() + " and tensor array " + tArrOp.getOwnName());
                }
            }
            Preconditions.checkState(this.nodeValueOutputs.containsKey(tArr), "Tensor array does not exist for %s", (Object) tArr);
            SDValue sdValue1 = this.getSdValue(tArr);
            List<INDArray> l = sdValue1.getListValue();
            Preconditions.checkState(l != null, "Tensor array value for %s is not a list", (Object) tArr);
            if (idx < 0) {
                // Negative index: -1 maps to l.size() (the append position); an empty list always writes slot 0
                idx = l.isEmpty() ? 0 : idx + l.size() + 1;
            }
            while (l.size() <= idx) {
                l.add(null);
            }
            this.setArrayAtIndex(l, idx, arr);
            if (log.isTraceEnabled()) {
                log.trace("Setting item at index " + idx + " for list " + tArr + " with value " + arr + " with whole list of after write " + l + " and value array " + arr);
                log.trace("Writing value " + inSDV + " to list " + tArr.getVariable() + " at iteration " + tArr.getIteration());
            }
            this.arrayUseTracker.addDependency(sdValue1, new ExecDoneDep());
            return ExecutionResult.createValue(op.outputVariable().name(), sdValue1);
        }
        if (op instanceof TensorArraySize) {
            SDVariable inTensorArray = op.arg(0);
            TensorArray tensorArray = TensorArray.getTensorArray(this.sameDiff, inTensorArray);
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = this.getSdValue(tArr).getListValue();
            int size = l == null ? 0 : l.size();
            INDArray scalar = this.mmgr.allocate(false, DataType.INT, new long[0]).assign(size);
            return ExecutionResult.createFrom(tensorArray.getVar().name(), scalar);
        }
        if (op instanceof TensorArrayConcat) {
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = this.getSdValue(tArr).getListValue();
            // Slots never written (null) are skipped
            INDArray[] present = l.stream().filter(input -> input != null).toArray(INDArray[]::new);
            Concat c = new Concat(0, present);
            List<LongShapeDescriptor> shape = c.calculateOutputShape();
            INDArray out = this.mmgr.allocate(false, shape.get(0));
            c.setOutputArgument(0, out);
            Nd4j.exec(c);
            return ExecutionResult.createFrom(tArr.getVariable(), out);
        }
        if (op instanceof TensorArrayGather) {
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = this.getSdValue(tArr).getListValue();
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", (Object) tArr);
            String indicesName = op.arg(1).name();
            SDVariable indicesSDV = this.sameDiff.getVariable(indicesName);
            INDArray idxArr = indicesSDV.getArr();
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", (Object) idxArr, (Object) indicesName);
            Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", (Object) idxArr.dataType(), (Object) indicesName);
            int[] idxArrInt = idxArr.toIntVector();
            if (log.isTraceEnabled()) {
                log.trace("Gathering op " + op.getOwnName() + " from indices " + Arrays.toString(idxArrInt) + " named " + indicesName + " from list " + tArr.getVariable());
            }
            if (idxArrInt.length > 0) {
                ArrayList<INDArray> newList = new ArrayList<INDArray>();
                // NOTE(review): a single index, or a negative first index, gathers the whole list - confirm intended
                if (idxArrInt.length == 1 || idxArrInt[0] < 0) {
                    newList.addAll(l);
                } else {
                    for (int id : idxArrInt) {
                        Preconditions.checkState(id >= 0, "Index for TensorArrayGather must be >= 0, got %s", id);
                        INDArray element = l.get(id);
                        if (element == null) {
                            continue;
                        }
                        // Guarded: toStringFull() renders the entire array and is costly when trace is off
                        if (log.isTraceEnabled()) {
                            log.trace("Gathering op " + op.getOwnName() + " at index " + id + " adding value " + element.toStringFull() + " from full list " + l);
                        }
                        newList.add(element);
                    }
                }
                INDArray[] present = newList.stream().filter(input -> input != null).toArray(INDArray[]::new);
                Stack s = new Stack(present, null, 0);
                List<LongShapeDescriptor> shape = s.calculateOutputShape();
                INDArray out = this.mmgr.allocate(false, shape.get(0));
                s.setOutputArgument(0, out);
                Nd4j.exec(s);
                return ExecutionResult.createFrom(tArr.getVariable(), out);
            }
            return ExecutionResult.createFrom(tArr.getVariable(), Nd4j.zeros(op.arg().dataType(), 0L));
        }
        if (op instanceof TensorArrayScatter) {
            SDVariable inTensorArray = op.arg(0);
            TensorArray ta = TensorArray.getTensorArray(this.sameDiff, inTensorArray);
            String taOutputName = ta.outputVariablesNames()[0];
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(taOutputName, opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(taOutputName, allIterInputs, false);
            }
            SDValue retValue = this.getSdValue(tArr);
            List<INDArray> l = retValue.getListValue();
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", (Object) tArr);
            String indicesName = op.arg(1).name();
            SDVariable indicesSDV = this.sameDiff.getVariable(indicesName);
            INDArray idxArr = indicesSDV.getArr();
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", (Object) idxArr, (Object) indicesName);
            Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", (Object) idxArr.dataType(), (Object) indicesName);
            int[] idxs = idxArr.toIntVector();
            String valuesName = op.arg(2).name();
            SDVariable valuesSDV = this.sameDiff.getVariable(valuesName);
            INDArray valuesArr = this.getArray(valuesSDV, opInputs, allIterInputs);
            while (l.size() < idxs.length) {
                l.add(null);
            }
            // A single index of -1 means "scatter every slice along dimension 0"
            if (idxs.length == 1 && idxs[0] == -1) {
                idxs = ArrayUtil.range(0, (int) valuesArr.size(0));
            }
            // Validate every index before writing anything, so a bad index cannot leave the list half-updated
            long numSlices = valuesArr.size(0);
            for (int idx : idxs) {
                if (idx >= numSlices) {
                    throw new IllegalArgumentException("Unable to obtain slice from values array named " + valuesName + " with shape " + Arrays.toString(valuesArr.shape()) + " at index " + idx + " at node named " + op.getOwnName() + " with inputs " + Arrays.toString(op.argNames()));
                }
            }
            for (int i = 0; i < idxs.length; ++i) {
                INDArray get = this.mmgr.dup(valuesArr.slice(idxs[i]));
                if (ta.args().length > 1) {
                    long[] shape = ta.arg(1).getArr().toLongVector();
                    if (!Arrays.equals(get.shape(), shape)) {
                        throw new IllegalArgumentException("Unable to write array of shape " + Arrays.toString(get.shape()) + " must be " + Arrays.toString(shape) + " for op " + op.getOwnName() + " and tensor array " + ta.getOwnName());
                    }
                }
                SDValue newValue = SDValue.create(get);
                int outIdx = idxs[i];
                if (valuesArr.rank() == 1 && get.rank() > 0) {
                    get = get.reshape(new long[0]);
                }
                while (l.size() <= outIdx) {
                    l.add(null);
                }
                // Guarded: toStringFull() renders the entire values array and is costly when trace is off
                if (log.isTraceEnabled()) {
                    log.trace("Scattering item at index " + i + " for list " + tArr + " with value " + get + " from whole list of " + l + " from values array " + valuesArr.toStringFull() + " named " + valuesSDV.name());
                }
                this.setArrayAtIndex(l, outIdx, get);
                this.arrayUseTracker.addDependency(newValue, new ExecDoneDep());
            }
            return ExecutionResult.createValue(valuesName, retValue);
        }
        if (op instanceof TensorArraySplit) {
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            // Walk back through Enter ops to the tensor array variable in the enclosing frame
            while (this.sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                inTensorArray = this.sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
                tArr = tArr.getParentFrame().toVarId(inTensorArray.name());
            }
            SDValue sdValue = this.getSdValue(tArr);
            List<INDArray> l = sdValue.getListValue();
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", (Object) tArr);
            String splitName = op.arg(1).name();
            INDArray splitArr = this.getArray(this.sameDiff.getVariable(splitName), opInputs, allIterInputs);
            String sizeName = op.arg(2).name();
            SDVariable sizeSDV = this.sameDiff.getVariable(sizeName);
            INDArray sizeArr = this.getArray(sizeSDV, opInputs, allIterInputs);
            Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", (Object) sizeArr, (Object) sizeName);
            Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", (Object) sizeArr.dataType(), (Object) sizeName);
            int[] sizes = sizeArr.toIntVector();
            while (l.size() <= sizes.length) {
                l.add(null);
            }
            INDArrayIndex[] idx = (INDArrayIndex[]) ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            int soFar = 0;
            for (int i = 0; i < sizes.length; ++i) {
                // Take the next sizes[i] entries along dimension 0
                idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]);
                INDArray sub = this.mmgr.dup(splitArr.get(idx));
                SDValue subValue = SDValue.create(sub);
                this.setArrayAtIndex(l, i, sub);
                soFar += sizes[i];
                this.arrayUseTracker.addDependency(subValue, new ExecDoneDep());
            }
            return ExecutionResult.createValue(sizeName, sdValue);
        }
        if (op instanceof TensorArrayRemove) {
            SDVariable inTensorArray = op.arg(0);
            SDVariable index = op.arg(1);
            List<INDArray> l = this.getTensorArraysInSession(inTensorArray.name());
            if (l == null) {
                l = new ArrayList<INDArray>();
            } else {
                // getInt returns int, so this is List.remove(int): removal by position
                l.remove(index.getArr(true).getInt(0));
            }
            AbstractSession.VarId tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            while (this.sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                inTensorArray = this.sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
                tArr = tArr.getParentFrame().toVarId(inTensorArray.name());
            }
            this.putNodeValue(SDValue.create(l), tArr);
            return ExecutionResult.createValue(tArr.getVariable(), l);
        }
        throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName());
    }

    /**
     * Collects every recorded output value of the named variable across frames and iterations.
     *
     * @param varName the variable whose values to collect
     * @return a map keyed by (variable name, iteration) to the recorded value
     */
    private Map<Pair<String, Integer>, SDValue> valuesFor(String varName) {
        Map<Pair<String, Integer>, SDValue> matches = new HashMap<Pair<String, Integer>, SDValue>();
        for (Map.Entry entry : this.nodeValueOutputs.entrySet()) {
            AbstractSession.VarId key = (AbstractSession.VarId) entry.getKey();
            if (key.getVariable().equals(varName)) {
                Pair<String, Integer> variableAndIteration = Pair.of(key.getVariable(), key.getIteration());
                matches.put(variableAndIteration, (SDValue) entry.getValue());
            }
        }
        return matches;
    }

    /**
     * Returns the array backing a constant or a trainable (VARIABLE-type) variable.
     *
     * @param variableName name of the variable to look up
     * @return the array held by the SameDiff instance for that name
     * @throws IllegalStateException if no variable has that name, or it is neither a constant nor a VARIABLE
     */
    @Override
    public INDArray getConstantOrVariable(String variableName) {
        SDVariable v = this.sameDiff.getVariable(variableName);
        Preconditions.checkState(v != null, "No variable found with name %s", (Object) variableName);
        // Both constants and VARIABLE-type variables are accepted; the message reflects that
        Preconditions.checkState(v.isConstant() || v.getVariableType() == VariableType.VARIABLE, "Variable %s is not a constant or variable (type: %s)", (Object) variableName, (Object) v.getVariableType());
        return this.sameDiff.getArrForVarName(variableName);
    }

    /**
     * Looks up the op named {@code opName} and prepares an {@link OpContext} for executing it.
     *
     * <p>Ops that are {@code LoopCond}, {@code Enter}, {@code Exit}, {@code NextIteration},
     * {@code Merge}, {@code Switch}, {@code BaseTensorOp} or {@code Invoke} are returned with a
     * {@code null} context. For every other op this method:
     * <ol>
     *   <li>resolves each input array:
     *     <ul>
     *       <li>constants and {@code VARIABLE}-type variables are read from the graph;</li>
     *       <li>placeholders come from {@code placeholderValues}, falling back to {@code otherPlaceholders};</li>
     *       <li>anything else comes from values already produced in this session;</li>
     *     </ul>
     *   </li>
     *   <li>creates (or reuses, keyed by op name) an {@link OpContext};</li>
     *   <li>allocates output arrays through the session memory manager.</li>
     * </ol>
     *
     * @param opName            name of the op to parameterize
     * @param frameIter         frame/iteration the op executes in (not read by this implementation)
     * @param opInputs          non-constant inputs of the op, used to look up input values
     * @param allIterInputs     additional non-constant inputs, also used to look up input values
     * @param constAndPhInputs  constant/placeholder inputs; only their count is used here, for validation
     * @param placeholderValues placeholder arrays, consulted first when resolving a placeholder input
     * @param allReqVariables   requested variable names; outputs in this set are allocated with {@code isOutput=true}
     * @param otherPlaceholders fallback placeholder values, used when {@code placeholderValues} has no entry
     * @return the op together with its parameterized context (the context is {@code null} for the op types listed above)
     * @throws IllegalArgumentException if a placeholder input has no value in either placeholder map
     */
    @Override
    public Pair<SameDiffOp, OpContext> getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues, Set<String> allReqVariables, Map<String, SDValue> otherPlaceholders) {
        SameDiffOp sdo = this.sameDiff.getOps().get(opName);
        Preconditions.checkNotNull(sdo, "No op found with name \"%s\"", opName);
        DifferentialFunction df = sdo.getOp();
        Preconditions.checkNotNull(df, "No differential function found with name \"%s\"", opName);
        // These op types are executed by the session itself and do not need a prepared OpContext.
        if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp || df instanceof Invoke) {
            return new Pair<>(sdo, null);
        }

        String[] argNames = df.argNames();
        int numArgs = argNames == null ? 0 : argNames.length;
        int numNonConstIns = opInputs == null ? 0 : opInputs.size();
        int numNonConstInsAllIters = allIterInputs == null ? 0 : allIterInputs.size();
        int numConstPhIns = constAndPhInputs == null ? 0 : constAndPhInputs.size();
        // With more than one arg the counts may legitimately differ (the same variable can feed several
        // inputs), so the input-count check is only enforced for ops with zero or one arg.
        if (numArgs <= 1 && numArgs != numNonConstIns + numConstPhIns + numNonConstInsAllIters) {
            Preconditions.checkState(numArgs == numNonConstIns + numConstPhIns, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), opName, (Object) argNames, opInputs, constAndPhInputs);
        }

        INDArray[] args = null;
        if (numArgs > 0) {
            args = new INDArray[numArgs];
            for (int i = 0; i < numArgs; ++i) {
                String s = argNames[i];
                SDVariable v = this.sameDiff.getVariable(s);
                if (v.isConstant() || v.getVariableType() == VariableType.VARIABLE) {
                    args[i] = v.getArr();
                } else if (v.isPlaceHolder()) {
                    if (placeholderValues != null && placeholderValues.containsKey(s)) {
                        args[i] = placeholderValues.get(s);
                    } else if (otherPlaceholders != null && otherPlaceholders.containsKey(s)) {
                        args[i] = otherPlaceholders.get(s).getTensorValue();
                    } else {
                        // The variable name must not be used as the format string itself: a '%' in it would throw.
                        throw new IllegalArgumentException(String.format("Could not parameterize op %s: no value was provided for placeholder variable \"%s\"", opName, s));
                    }
                } else {
                    AbstractSession.VarId vid = InferenceSession.lookup(s, opInputs, allIterInputs, true);
                    SDValue value = this.getSdValue(vid);
                    if (value != null) {
                        switch (value.getSdValueType()) {
                            case TENSOR:
                                args[i] = value.getTensorValue();
                                break;
                            case LIST:
                                DifferentialFunction producer = this.sameDiff.getVariableOutputOp(s);
                                if (producer instanceof Switch && producer.argNames().length == 2 && producer.argNames()[0].equals(producer.argNames()[1])) {
                                    // Switch fed the same variable twice: take the first non-null list entry.
                                    for (int j = 0; j < value.getListValue().size(); ++j) {
                                        if (value.getListValue().get(j) != null) {
                                            args[i] = value.getListValue().get(j);
                                            break;
                                        }
                                    }
                                } else {
                                    args[i] = Nd4j.empty(DataType.FLOAT);
                                }
                                break;
                            default:
                                // Other value types leave args[i] null; the check below reports it.
                                break;
                        }
                    }
                }
                Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, (Object) i, v.name());
            }
        }

        OpContext oc = this.opContexts.get(opName);
        if (oc == null) {
            oc = Nd4j.getExecutioner().buildContext();
            this.opContexts.put(opName, oc);
        }

        if (df instanceof CustomOp) {
            DynamicCustomOp customOp = (DynamicCustomOp) df;
            if (df instanceof Identity || df instanceof CreateView) {
                // The input array is passed straight through as the output; nothing is allocated.
                Preconditions.checkState(args != null, "Could not parameterize op %s (%s): op has no input array to pass through as its output", df.getClass().getSimpleName(), opName);
                oc.setInputArrays(args);
                oc.setOutputArrays(args[0]);
                return new Pair<>(sdo, oc);
            }
            oc.setArgs(args, customOp.iArgs(), customOp.dArgs(), customOp.tArgs(), customOp.bArgs());
            if (df instanceof Assign) {
                // Assign's output array is its first input array.
                oc.setOutputArray(0, oc.getInputArray(0));
                return new Pair<>(sdo, oc);
            }
            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape(oc);
            Preconditions.checkState(outShape != null && !outShape.isEmpty(), "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
            String[] outNames = df.outputVariablesNames();
            Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation with %s outputs (number of shapes and outputs must be equal)", df.opName(), (Object) outShape.size(), (Object) outNames.length);
            for (int i = 0; i < outShape.size(); ++i) {
                LongShapeDescriptor reqShape = outShape.get(i);
                // The graph's declared dtype for the output wins over the calculated one.
                DataType dt = this.sameDiff.getVariable(outNames[i]).dataType();
                if (dt != reqShape.dataType()) {
                    reqShape = reqShape.asDataType(dt);
                }
                boolean isOutput = allReqVariables.contains(outNames[i]);
                INDArray out = this.mmgr.allocate(isOutput, reqShape);
                if (reqShape.isEmpty() && !out.isEmpty()) {
                    throw new IllegalStateException("Output shape was empty, but created array was not.");
                }
                oc.setOutputArray(i, out);
            }
            return new Pair<>(sdo, oc);
        }

        if (!(df instanceof Op)) {
            return new Pair<>(sdo, oc);
        }
        Op op = (Op) df;
        boolean axisArg = false;
        boolean emptyReduce = false;
        if (op instanceof ReduceOp && ((ReduceOp) op).getOpType() != Op.Type.REDUCE3 && df.argNames().length == 2) {
            // Second arg of a (non-REDUCE3) reduction holds the axis; it is applied as dimensions, not as an input.
            SDVariable axisArgVar = df.arg(1);
            Preconditions.checkState(axisArgVar.dataType().isIntType(), "Legacy op %s input 1 (axis) was expected to be an integer type, is %s", df.getClass(), axisArgVar.dataType());
            INDArray arr = this.getArray(axisArgVar, opInputs, allIterInputs);
            Preconditions.checkState(arr != null, "Could not get axis argument for op %s: %s", df.getOwnName(), df.getClass());
            if (!arr.isEmpty()) {
                int[] axis = arr.toIntVector();
                int rank = args[0].rank();
                axis = Shape.normalizeAxis(rank, axis);
                df.setDimensions(axis);
                ((BaseReduceOp) op).setEmptyReduce(false);
            } else {
                df.setDimensions(null);
                emptyReduce = true;
                ((BaseReduceOp) op).setEmptyReduce(true);
            }
            axisArg = true;
        } else if (op instanceof ScalarOp && df.argNames().length == 2) {
            SDVariable scalarVar = df.arg(1);
            INDArray scalar = this.getArray(scalarVar, opInputs, allIterInputs);
            Preconditions.checkState(scalar != null, "Could not get scalar argument for op %s: %s", df.getOwnName(), df.getClass());
            Preconditions.checkState(scalar.isScalar(), "Scalar argument for op %s (%s) is not a scalar: has shape %ndShape", df.getOwnName(), df.getClass(), scalar);
            ((ScalarOp) op).setScalar(scalar);
        }
        if (args != null && args.length > 0) {
            oc.setInputArray(0, args[0]);
            if (args.length == 2 && !axisArg) {
                oc.setInputArray(1, args[1]);
            }
        }
        boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
        if (emptyReduce) {
            // Empty axis: output uses the input's dtype and shape.
            INDArray z = this.mmgr.allocate(false, oc.getInputArray(0).dataType(), oc.getInputArray(0).shape());
            oc.setOutputArray(0, z);
        } else {
            List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape(oc);
            Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
            INDArray z = this.mmgr.allocate(isOutput, outputShape.get(0));
            oc.setOutputArray(0, z);
        }
        return new Pair<>(sdo, oc);
    }

    /**
     * Fetches the array currently associated with {@code sdv}.
     *
     * <p>{@code CONSTANT} and {@code VARIABLE} variables are read via {@code getConstantOrVariable};
     * any other variable is looked up among {@code opInputs}/{@code allIterInputs} and its tensor is
     * taken from the session outputs.
     *
     * @throws IllegalStateException if a non-constant variable cannot be found among the inputs
     */
    protected INDArray getArray(SDVariable sdv, Collection<AbstractSession.VarId> opInputs, Collection<AbstractSession.VarId> allIterInputs) {
        String varName = sdv.name();
        VariableType type = sdv.getVariableType();
        boolean heldByGraph = type == VariableType.CONSTANT || type == VariableType.VARIABLE;
        if (!heldByGraph) {
            AbstractSession.VarId found = InferenceSession.lookup(varName, opInputs, allIterInputs, false);
            Preconditions.checkState(found != null, "Could not find array for variable %s", (Object) sdv.name());
            return this.getTensorFromOutputs(found);
        }
        return this.getConstantOrVariable(varName);
    }

    /**
     * @return the memory manager this session allocates op output arrays from
     */
    public SessionMemMgr getMmgr() {
        return mmgr;
    }

    /**
     * Replaces the memory manager used by this session for allocating op output arrays.
     *
     * @param mmgr the new memory manager
     */
    public void setMmgr(SessionMemMgr mmgr) {
        this.mmgr = mmgr;
    }

    /**
     * @return the dependency tracker (parameterized on {@link SDValue} and {@link Dep}) held by this session
     */
    public AbstractDependencyTracker<SDValue, Dep> getArrayUseTracker() {
        return arrayUseTracker;
    }

    /**
     * Replaces the dependency tracker held by this session.
     *
     * @param arrayUseTracker the new tracker
     */
    public void setArrayUseTracker(AbstractDependencyTracker<SDValue, Dep> arrayUseTracker) {
        this.arrayUseTracker = arrayUseTracker;
    }

    /**
     * Field-less {@link Dep}; equality and hashing depend only on the inherited frame/parentFrame.
     */
    protected static class ExecDoneDep extends Dep {

        @Override
        public String toString() {
            return "InferenceSession.ExecDoneDep()";
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            boolean compatible = o instanceof ExecDoneDep && ((ExecDoneDep) o).canEqual(this);
            return compatible && super.equals(o);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof ExecDoneDep;
        }

        @Override
        public int hashCode() {
            return super.hashCode();
        }
    }

    /**
     * {@link Dep} identified (in addition to frame/parentFrame) by an output name.
     */
    protected static class ReqOutputDep extends Dep {
        protected String outputName;

        public ReqOutputDep(String outputName) {
            this.outputName = outputName;
        }

        public String getOutputName() {
            return outputName;
        }

        public void setOutputName(String outputName) {
            this.outputName = outputName;
        }

        @Override
        public String toString() {
            return "InferenceSession.ReqOutputDep(outputName=" + getOutputName() + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ReqOutputDep)) {
                return false;
            }
            ReqOutputDep that = (ReqOutputDep) o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = getOutputName();
            String theirs = that.getOutputName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof ReqOutputDep;
        }

        @Override
        public int hashCode() {
            String name = getOutputName();
            return super.hashCode() * 59 + (name == null ? 43 : name.hashCode());
        }
    }

    /**
     * {@link Dep} identified (in addition to frame/parentFrame) by a constant's name.
     */
    protected static class ConstantDep extends Dep {
        protected String constName;

        public ConstantDep(String constName) {
            this.constName = constName;
        }

        public String getConstName() {
            return constName;
        }

        public void setConstName(String constName) {
            this.constName = constName;
        }

        @Override
        public String toString() {
            return "InferenceSession.ConstantDep(constName=" + getConstName() + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ConstantDep)) {
                return false;
            }
            ConstantDep that = (ConstantDep) o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = getConstName();
            String theirs = that.getConstName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof ConstantDep;
        }

        @Override
        public int hashCode() {
            String name = getConstName();
            return super.hashCode() * 59 + (name == null ? 43 : name.hashCode());
        }
    }

    /**
     * {@link Dep} identified (in addition to frame/parentFrame) by a variable's name.
     */
    protected static class VariableDep extends Dep {
        protected String varName;

        public VariableDep(String varName) {
            this.varName = varName;
        }

        public String getVarName() {
            return varName;
        }

        public void setVarName(String varName) {
            this.varName = varName;
        }

        @Override
        public String toString() {
            return "InferenceSession.VariableDep(varName=" + getVarName() + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof VariableDep)) {
                return false;
            }
            VariableDep that = (VariableDep) o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = getVarName();
            String theirs = that.getVarName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof VariableDep;
        }

        @Override
        public int hashCode() {
            String name = getVarName();
            return super.hashCode() * 59 + (name == null ? 43 : name.hashCode());
        }
    }

    /**
     * {@link Dep} identified (in addition to frame/parentFrame) by a placeholder's name.
     */
    protected static class PlaceholderDep extends Dep {
        protected String phName;

        public PlaceholderDep(String phName) {
            this.phName = phName;
        }

        public String getPhName() {
            return phName;
        }

        public void setPhName(String phName) {
            this.phName = phName;
        }

        @Override
        public String toString() {
            return "InferenceSession.PlaceholderDep(phName=" + getPhName() + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof PlaceholderDep)) {
                return false;
            }
            PlaceholderDep that = (PlaceholderDep) o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = getPhName();
            String theirs = that.getPhName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof PlaceholderDep;
        }

        @Override
        public int hashCode() {
            String name = getPhName();
            return super.hashCode() * 59 + (name == null ? 43 : name.hashCode());
        }
    }

    /**
     * {@link Dep} identified (in addition to frame/parentFrame) by an op name and an iteration number.
     */
    public static class OpDep extends Dep {
        protected String opName;
        protected int iter;

        /**
         * Creates a fully specified dependency.
         *
         * @throws NullPointerException if {@code opName} or {@code frame} is null
         */
        protected OpDep(@NonNull String opName, @NonNull String frame, int iter, FrameIter parentFrame) {
            if (opName == null) {
                throw new NullPointerException("opName is marked non-null but is null");
            }
            if (frame == null) {
                throw new NullPointerException("frame is marked non-null but is null");
            }
            this.opName = opName;
            this.frame = frame;
            this.iter = iter;
            this.parentFrame = parentFrame;
        }

        /** Creates a dependency with only the op name and iteration set; frame and parentFrame stay null. */
        public OpDep(String opName, int iter) {
            this.opName = opName;
            this.iter = iter;
        }

        public String getOpName() {
            return opName;
        }

        public void setOpName(String opName) {
            this.opName = opName;
        }

        public int getIter() {
            return iter;
        }

        public void setIter(int iter) {
            this.iter = iter;
        }

        @Override
        public String toString() {
            String parentPart = this.parentFrame == null ? "" : ",parent=" + this.parentFrame;
            return "OpDep(" + this.opName + ",frame=" + this.frame + ",iter=" + this.iter + parentPart + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof OpDep)) {
                return false;
            }
            OpDep that = (OpDep) o;
            if (!that.canEqual(this) || !super.equals(o) || getIter() != that.getIter()) {
                return false;
            }
            String mine = getOpName();
            String theirs = that.getOpName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof OpDep;
        }

        @Override
        public int hashCode() {
            int partial = super.hashCode() * 59 + getIter();
            String name = getOpName();
            return partial * 59 + (name == null ? 43 : name.hashCode());
        }
    }

    /**
     * Base class for dependency keys: carries a frame name and an optional parent {@link FrameIter}.
     * Subclasses add their own identifying fields and extend equals/hashCode via {@link #canEqual}.
     */
    public static abstract class Dep {
        protected String frame;
        protected FrameIter parentFrame;

        public String getFrame() {
            return frame;
        }

        public void setFrame(String frame) {
            this.frame = frame;
        }

        public FrameIter getParentFrame() {
            return parentFrame;
        }

        public void setParentFrame(FrameIter parentFrame) {
            this.parentFrame = parentFrame;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Dep)) {
                return false;
            }
            Dep that = (Dep) o;
            if (!that.canEqual(this)) {
                return false;
            }
            String myFrame = getFrame();
            String theirFrame = that.getFrame();
            boolean framesMatch = myFrame == null ? theirFrame == null : myFrame.equals(theirFrame);
            if (!framesMatch) {
                return false;
            }
            // Typed as Object so the comparison always dispatches to equals(Object).
            Object myParent = getParentFrame();
            Object theirParent = that.getParentFrame();
            return myParent == null ? theirParent == null : myParent.equals(theirParent);
        }

        protected boolean canEqual(Object other) {
            return other instanceof Dep;
        }

        @Override
        public int hashCode() {
            String f = getFrame();
            Object p = getParentFrame();
            int result = 59 + (f == null ? 43 : f.hashCode());
            return result * 59 + (p == null ? 43 : p.hashCode());
        }

        @Override
        public String toString() {
            return "InferenceSession.Dep(frame=" + getFrame() + ", parentFrame=" + getParentFrame() + ")";
        }
    }
}

