/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * A parameter-free graph vertex that combines the activations of its inputs element-wise,
 * using one of the operations in {@link Op}.
 *
 * <p>Forward pass behavior:
 * <ul>
 *   <li>With a single input, that input array is returned as-is (no copy), regardless of op.</li>
 *   <li>{@link Op#Add}: element-wise sum of all inputs, accumulated into a copy of the first.</li>
 *   <li>{@link Op#Subtract}: {@code inputs[0] - inputs[1]}; exactly two inputs are required.</li>
 *   <li>{@link Op#Product}: not yet implemented (throws).</li>
 * </ul>
 *
 * <p>NOTE(review): inputs are presumably expected to share a shape; this class performs no
 * shape check itself and defers to the ND4J element-wise operations — confirm.
 */
public class ElementWiseVertex
extends BaseGraphVertex {
    /** Element-wise operation applied to the inputs; fixed at construction time. */
    private final Op op;
    /** Number of inputs seen by the most recent forward pass; sizes the backward-pass output. */
    private int nInForwardPass;

    /**
     * Creates a vertex with no input/output vertex wiring (both passed to the full constructor
     * as {@code null}).
     *
     * @param graph       the graph this vertex belongs to
     * @param name        the vertex name
     * @param vertexIndex the index of this vertex within the graph
     * @param op          the element-wise operation to apply
     */
    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op) {
        this(graph, name, vertexIndex, null, null, op);
    }

    /**
     * Creates a vertex with explicit input/output vertex wiring.
     *
     * @param graph          the graph this vertex belongs to
     * @param name           the vertex name
     * @param vertexIndex    the index of this vertex within the graph
     * @param inputVertices  vertices feeding into this one (may be {@code null})
     * @param outputVertices vertices this one feeds into (may be {@code null})
     * @param op             the element-wise operation to apply
     */
    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Op op) {
        super(graph, name, vertexIndex, inputVertices, outputVertices);
        this.op = op;
    }

    /** @return {@code false}: this vertex wraps no layer. */
    @Override
    public boolean hasLayer() {
        return false;
    }

    /** @return {@code false}: this vertex is never an output vertex. */
    @Override
    public boolean isOutputVertex() {
        return false;
    }

    /** @return {@code null}: this vertex wraps no layer. */
    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Combines the current inputs element-wise according to {@link #op}.
     *
     * @param training unused by this vertex
     * @return the single input itself when there is exactly one input, otherwise a new array
     * @throws IllegalStateException         if {@code canDoForward()} reports the inputs are not set
     * @throws IllegalArgumentException      for {@link Op#Subtract} with other than two inputs
     * @throws UnsupportedOperationException for {@link Op#Product}
     */
    @Override
    public INDArray doForward(boolean training) {
        if (!this.canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set");
        }
        // Recorded so doBackward knows how many error arrays to produce.
        this.nInForwardPass = this.inputs.length;
        if (this.inputs.length == 1) {
            // Nothing to combine: pass the single input through without copying.
            return this.inputs[0];
        }
        switch (this.op) {
            case Add: {
                // dup() so accumulating in place with addi() does not overwrite inputs[0].
                INDArray sum = this.inputs[0].dup();
                for (int i = 1; i < this.inputs.length; ++i) {
                    sum.addi(this.inputs[i]);
                }
                return sum;
            }
            case Subtract:
                if (this.inputs.length != 2) {
                    throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
                }
                return this.inputs[0].sub(this.inputs[1]);
            case Product:
                throw new UnsupportedOperationException("ElementWise product: Not yet implemented");
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /**
     * Distributes the incoming error {@code epsilons[0]} back to each input.
     *
     * <p>The vertex has no parameters, so the gradient half of the returned pair is always
     * {@code null}.
     *
     * @param tbptt unused by this vertex
     * @return a pair of ({@code null}, per-input errors)
     * @throws IllegalStateException         if {@code canDoBackward()} reports errors are not set
     * @throws UnsupportedOperationException for {@link Op#Product}
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        if (!this.canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }
        if (this.nInForwardPass == 1) {
            // The forward pass was the identity, so the errors pass straight through.
            return new Pair<Gradient, INDArray[]>(null, this.epsilons);
        }
        switch (this.op) {
            case Add: {
                // d(sum)/d(input_i) = 1, so every input receives the same error. The first input
                // gets epsilons[0] itself; the rest get copies (presumably so in-place updates on
                // one input's error cannot affect another's).
                INDArray[] out = new INDArray[this.nInForwardPass];
                out[0] = this.epsilons[0];
                for (int i = 1; i < this.nInForwardPass; ++i) {
                    out[i] = out[0].dup();
                }
                return new Pair<Gradient, INDArray[]>(null, out);
            }
            case Subtract: {
                // d(a - b)/da = 1 and d(a - b)/db = -1.
                INDArray[] out = new INDArray[]{this.epsilons[0], this.epsilons[0].mul(-1)};
                return new Pair<Gradient, INDArray[]>(null, out);
            }
            case Product:
                throw new UnsupportedOperationException("Not yet implemented");
            default:
                throw new UnsupportedOperationException("Unknown op: " + this.op);
        }
    }

    /**
     * Rejects any non-null gradients view, since this vertex has no parameters.
     *
     * @param backpropGradientsViewArray must be {@code null}
     * @throws RuntimeException if a non-null array is supplied
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    @Override
    public String toString() {
        return "ElementWiseVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\",op=" + this.op + ")";
    }

    /** Element-wise operations supported (or reserved) by {@link ElementWiseVertex}. */
    public static enum Op {
        Add,
        Subtract,
        Product;

    }
}

