/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.activations.impl;

import java.util.Arrays;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Parametric ReLU (PReLU) activation backed by the native {@code "prelu"} (forward) and
 * {@code "prelu_bp"} (backward) custom ops.
 *
 * <p>This class only wires arrays into the native ops; the element-wise math is done natively.
 * The {@code alpha} array and the optional {@code sharedAxes} are held by reference and forwarded
 * to the ops on every call.
 */
public class ActivationPReLU
extends BaseActivationFunction {
    /** Slope array, passed as the second input of both native ops. */
    private INDArray alpha;
    /** Axes forwarded to the native ops as integer arguments; {@code null} means none are passed. */
    private long[] sharedAxes;

    /**
     * Creates a PReLU activation.
     *
     * @param alpha      slope array handed to the native ops (stored by reference, not copied)
     * @param sharedAxes axes handed to the native ops as integer arguments, or {@code null} to pass
     *                   none (stored by reference, not copied)
     */
    public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
        this.alpha = alpha;
        this.sharedAxes = sharedAxes;
    }

    /**
     * Applies PReLU to {@code in} in place via the native {@code "prelu"} op.
     *
     * @param in       input array; it is registered as both an input and the output of the op,
     *                 so its contents are overwritten with the activated values
     * @param training ignored by this implementation
     * @return the same {@code in} instance
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        DynamicCustomOp.DynamicCustomOpsBuilder prelu = DynamicCustomOp.builder("prelu")
                .addOutputs(in)
                .addInputs(in, this.alpha);
        this.addSharedAxes(prelu);
        Nd4j.getExecutioner().execAndReturn(prelu.build());
        return in;
    }

    /**
     * Runs the native {@code "prelu_bp"} op and returns its two outputs.
     *
     * <p>The op is given inputs {@code (in, alpha, epsilon)} and writes into two freshly allocated
     * arrays shaped like {@code in} and {@code alpha}. The first output is then copied over
     * {@code in}. Presumably the outputs are the gradients with respect to the input and to
     * {@code alpha}; confirm against the {@code prelu_bp} op definition.
     *
     * @param in      pre-activation input; overwritten with the op's first output
     * @param epsilon incoming gradient; must pass {@code assertShape(in, epsilon)}
     * @return a pair of ({@code in} after the overwrite, the op's second output shaped like
     *         {@code alpha})
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        this.assertShape(in, epsilon);
        INDArray dLdalpha = this.alpha.ulike();
        INDArray outTemp = in.ulike();
        DynamicCustomOp.DynamicCustomOpsBuilder preluBp = DynamicCustomOp.builder("prelu_bp")
                .addInputs(in, this.alpha, epsilon)
                .addOutputs(outTemp, dLdalpha);
        this.addSharedAxes(preluBp);
        Nd4j.exec(preluBp.build());
        // The first output goes to a scratch array and is only then copied over "in".
        // Presumably this keeps the op from reading and writing the same buffer.
        in.assign(outTemp);
        return new Pair<>(in, dLdalpha);
    }

    /**
     * Appends each configured shared axis to {@code builder} as an integer argument, in order.
     * Does nothing when {@code sharedAxes} is {@code null}.
     */
    private void addSharedAxes(DynamicCustomOp.DynamicCustomOpsBuilder builder) {
        if (this.sharedAxes == null) {
            return;
        }
        for (long axis : this.sharedAxes) {
            builder.addIntegerArguments(axis);
        }
    }

    /** Returns the fixed name {@code "prelu"}. */
    @Override
    public String toString() {
        return "prelu";
    }

    /**
     * Two instances are equal when their {@code alpha} arrays are equal (via
     * {@link INDArray#equals}, or both {@code null}) and their {@code sharedAxes} arrays are equal
     * element-wise (via {@link Arrays#equals(long[], long[])}).
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActivationPReLU)) {
            return false;
        }
        ActivationPReLU other = (ActivationPReLU) o;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray thisAlpha = this.getAlpha();
        INDArray otherAlpha = other.getAlpha();
        if (thisAlpha == null ? otherAlpha != null : !thisAlpha.equals(otherAlpha)) {
            return false;
        }
        return Arrays.equals(this.getSharedAxes(), other.getSharedAxes());
    }

    /**
     * Symmetry hook for {@link #equals(Object)}: lets subclasses refuse equality with a
     * base-class instance.
     */
    protected boolean canEqual(Object other) {
        return other instanceof ActivationPReLU;
    }

    /** Hash consistent with {@link #equals(Object)}, built from {@code alpha} and {@code sharedAxes}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        INDArray alphaValue = this.getAlpha();
        // 43 is the fixed contribution used when alpha is null.
        result = result * prime + (alphaValue == null ? 43 : alphaValue.hashCode());
        result = result * prime + Arrays.hashCode(this.getSharedAxes());
        return result;
    }

    /** Returns the slope array held by this instance (the internal reference, not a copy). */
    public INDArray getAlpha() {
        return this.alpha;
    }

    /** Returns the shared axes (the internal array, not a copy), or {@code null} if none were set. */
    public long[] getSharedAxes() {
        return this.sharedAxes;
    }
}

