/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * Custom op registered under the op name {@code "cross"} and the TensorFlow
 * name {@code "Cross"}.
 *
 * <p>The op takes two inputs ({@code a}, {@code b}) and produces one output.
 * The backward pass in {@link #doDiff(List)} treats the forward computation
 * as the vector cross product {@code a x b}.
 */
public class Cross extends DynamicCustomOp {
    /** Creates an op instance with no inputs or outputs configured. */
    public Cross() {
    }

    /**
     * Creates the op as part of a {@link SameDiff} graph.
     *
     * @param sameDiff graph the op belongs to
     * @param args     input variables; {@link #doDiff(List)} reads them through
     *                 {@code larg()} / {@code rarg()}, so presumably exactly two
     *                 are expected (left and right operand) - confirm against callers
     */
    public Cross(SameDiff sameDiff, SDVariable[] args) {
        // NOTE(review): the leading null and trailing false are forwarded unchanged to
        // DynamicCustomOp; presumably "no explicit op name" and "not in-place" - confirm.
        super(null, sameDiff, args, false);
    }

    /**
     * Creates the op directly on arrays.
     *
     * @param a   first input array
     * @param b   second input array
     * @param out array passed to the superclass as the single output
     */
    public Cross(INDArray a, INDArray b, INDArray out) {
        // The two trailing nulls are forwarded unchanged to DynamicCustomOp.
        super(null, new INDArray[]{a, b}, new INDArray[]{out}, null, (int[])null);
    }

    /**
     * Returns the name of this op.
     *
     * @return the constant {@code "cross"}
     */
    @Override
    public String opName() {
        return "cross";
    }

    /**
     * Returns the TensorFlow name of this op.
     *
     * @return the constant {@code "Cross"}
     */
    @Override
    public String tensorflowName() {
        return "Cross";
    }

    /**
     * Builds the gradients of the loss with respect to both inputs.
     *
     * <p>With {@code c = a x b} and upstream gradient {@code g = dL/dc}, the
     * scalar triple product identity {@code g . (a x b) = a . (b x g) = b . (g x a)}
     * yields
     * <pre>
     *   dL/da = b x g
     *   dL/db = g x a
     * </pre>
     *
     * @param gradients upstream gradients; only the first element is used
     * @return a two-element list {@code [dL/da, dL/db]}
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        SDVariable grad = gradients.get(0);
        SDVariable a = this.larg();
        SDVariable b = this.rarg();
        // The vector-Jacobian product of a cross product is itself a cross product with the
        // upstream gradient; an element-wise product with cross(b, ones) is not equivalent.
        SDVariable gradLeft = this.sameDiff.cross(b, grad);
        SDVariable gradRight = this.sameDiff.cross(grad, a);
        return Arrays.asList(gradLeft, gradRight);
    }
}

