/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link GravesBidirectionalLSTM} layers.
 *
 * <p>All parameters live in a single flat row vector. Each direction (forwards first, then
 * backwards) contributes three consecutive segments:
 * <ol>
 *   <li>input weights: {@code nIn x 4*nOut} values,</li>
 *   <li>recurrent weights: {@code nOut x (4*nOut + 3)} values (the three extra columns are
 *       presumably peephole weights; confirm against the layer implementation),</li>
 *   <li>bias: {@code 4*nOut} values.</li>
 * </ol>
 * The gradient vector uses the same layout. {@link #numParams}, {@link #init} and
 * {@link #getGradientsFromFlattened} all derive their offsets from {@link #directionSizes}, so
 * they cannot drift apart.
 */
public class GravesBidirectionalLSTMParamInitializer
        implements ParamInitializer {
    private static final GravesBidirectionalLSTMParamInitializer INSTANCE = new GravesBidirectionalLSTMParamInitializer();
    public static final String RECURRENT_WEIGHT_KEY_FORWARDS = "RWF";
    public static final String BIAS_KEY_FORWARDS = "bF";
    public static final String INPUT_WEIGHT_KEY_FORWARDS = "WF";
    public static final String RECURRENT_WEIGHT_KEY_BACKWARDS = "RWB";
    public static final String BIAS_KEY_BACKWARDS = "bB";
    public static final String INPUT_WEIGHT_KEY_BACKWARDS = "WB";

    /** Returns the shared stateless instance. */
    public static GravesBidirectionalLSTMParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the total number of parameters for both directions.
     *
     * @param conf     configuration whose layer must be a {@link GravesBidirectionalLSTM}
     * @param backprop unused by this initializer
     * @return twice the per-direction parameter count
     */
    @Override
    public int numParams(NeuralNetConfiguration conf, boolean backprop) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM) conf.getLayer();
        int perDirection = 0;
        for (int size : directionSizes(layerConf.getNIn(), layerConf.getNOut())) {
            perDirection += size;
        }
        return 2 * perDirection;
    }

    /**
     * Registers the six parameter variables on {@code conf} and returns sub-arrays of
     * {@code paramsView}, keyed by the constants of this class.
     *
     * <p>When {@code initializeParams} is true:
     * <ul>
     *   <li>the second {@code nOut}-wide segment of each bias (the forget-gate segment) is set
     *       to the layer's {@code forgetGateBiasInit};</li>
     *   <li>the weight segments are initialized through {@code WeightInitUtil.initWeights};</li>
     *   <li>the remaining bias entries are left as {@code paramsView} already holds them.</li>
     * </ul>
     * Otherwise the weight segments are only reshaped to their matrix shapes via
     * {@code WeightInitUtil.reshapeWeights}.
     *
     * @param conf             configuration whose layer must be a {@link GravesBidirectionalLSTM}
     * @param paramsView       flat row vector holding all parameters, laid out as described on the
     *                         class
     * @param initializeParams whether to write initial values into {@code paramsView}
     * @return a synchronized, insertion-ordered map: WF, RWF, bF, WB, RWB, bB
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM) conf.getLayer();
        double forgetGateInit = layerConf.getForgetGateBiasInit();
        Distribution dist = Distributions.createDistribution(layerConf.getDist());
        int nL = layerConf.getNOut();
        int nLast = layerConf.getNIn();

        conf.addVariable(INPUT_WEIGHT_KEY_FORWARDS);
        conf.addVariable(RECURRENT_WEIGHT_KEY_FORWARDS);
        conf.addVariable(BIAS_KEY_FORWARDS);
        conf.addVariable(INPUT_WEIGHT_KEY_BACKWARDS);
        conf.addVariable(RECURRENT_WEIGHT_KEY_BACKWARDS);
        conf.addVariable(BIAS_KEY_BACKWARDS);

        INDArray[] segments = splitDirections(paramsView, nLast, nL);
        INDArray iwF = segments[0];
        INDArray rwF = segments[1];
        INDArray bF = segments[2];
        INDArray iwB = segments[3];
        INDArray rwB = segments[4];
        INDArray bB = segments[5];

        if (initializeParams) {
            // Bias writes happen before any weight initialization, matching the historical order.
            setForgetGateBias(bF, nL, forgetGateInit);
            setForgetGateBias(bB, nL, forgetGateInit);
        }

        // Insertion order (and therefore initWeights call order) is significant; keep it stable.
        params.put(INPUT_WEIGHT_KEY_FORWARDS, weights(initializeParams, layerConf, dist, nLast, 4 * nL, iwF));
        params.put(RECURRENT_WEIGHT_KEY_FORWARDS, weights(initializeParams, layerConf, dist, nL, 4 * nL + 3, rwF));
        params.put(BIAS_KEY_FORWARDS, bF);
        params.put(INPUT_WEIGHT_KEY_BACKWARDS, weights(initializeParams, layerConf, dist, nLast, 4 * nL, iwB));
        params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, weights(initializeParams, layerConf, dist, nL, 4 * nL + 3, rwB));
        params.put(BIAS_KEY_BACKWARDS, bB);
        return params;
    }

    /**
     * Splits a flat gradient row vector into per-parameter gradients.
     *
     * <p>The layout is the same as the parameter vector. Weight gradients are reshaped in
     * column-major ('f') order to their matrix shapes. Bias gradients are returned as the raw
     * row slices.
     *
     * @param conf         configuration whose layer must be a {@link GravesBidirectionalLSTM}
     * @param gradientView flat row vector holding all gradients
     * @return an insertion-ordered map: WF, RWF, bF, WB, RWB, bB
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM) conf.getLayer();
        int nL = layerConf.getNOut();
        int nLast = layerConf.getNIn();
        INDArray[] segments = splitDirections(gradientView, nLast, nL);

        Map<String, INDArray> out = new LinkedHashMap<String, INDArray>();
        out.put(INPUT_WEIGHT_KEY_FORWARDS, segments[0].reshape('f', nLast, 4 * nL));
        out.put(RECURRENT_WEIGHT_KEY_FORWARDS, segments[1].reshape('f', nL, 4 * nL + 3));
        out.put(BIAS_KEY_FORWARDS, segments[2]);
        out.put(INPUT_WEIGHT_KEY_BACKWARDS, segments[3].reshape('f', nLast, 4 * nL));
        out.put(RECURRENT_WEIGHT_KEY_BACKWARDS, segments[4].reshape('f', nL, 4 * nL + 3));
        out.put(BIAS_KEY_BACKWARDS, segments[5]);
        return out;
    }

    /**
     * Returns one direction's segment sizes in storage order.
     *
     * @param nLast number of layer inputs (nIn)
     * @param nL    number of layer outputs (nOut)
     * @return {input weights, recurrent weights, bias}
     */
    private static int[] directionSizes(int nLast, int nL) {
        return new int[] {nLast * (4 * nL), nL * (4 * nL + 3), 4 * nL};
    }

    /**
     * Slices {@code flat} into six consecutive segments along dimension 1 of row 0: the forward
     * segments from {@link #directionSizes}, followed by the same three for the backward
     * direction.
     *
     * <p>The results are obtained with {@code INDArray.get}. They are presumably views, which is
     * what lets in-place initialization write through to the backing array. Confirm this for the
     * ND4J version in use.
     */
    private static INDArray[] splitDirections(INDArray flat, int nLast, int nL) {
        int[] sizes = directionSizes(nLast, nL);
        INDArray[] segments = new INDArray[2 * sizes.length];
        int start = 0;
        for (int i = 0; i < segments.length; i++) {
            int end = start + sizes[i % sizes.length];
            segments[i] = flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(start, end));
            start = end;
        }
        return segments;
    }

    /** Overwrites bias entries {@code [nL, 2*nL)} (the forget-gate segment) with {@code value}. */
    private static void setForgetGateBias(INDArray bias, int nL, double value) {
        bias.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)},
                Nd4j.ones(1, nL).muli(value));
    }

    /**
     * Produces a weight matrix backed by {@code view}.
     *
     * <p>If {@code initialize} is true, it is filled via {@code WeightInitUtil.initWeights}.
     * Otherwise it is only reshaped via {@code WeightInitUtil.reshapeWeights}.
     */
    private static INDArray weights(boolean initialize, GravesBidirectionalLSTM layerConf, Distribution dist,
                                    int rows, int cols, INDArray view) {
        if (initialize) {
            return WeightInitUtil.initWeights(rows, cols, layerConf.getWeightInit(), dist, view);
        }
        return WeightInitUtil.reshapeWeights(new int[] {rows, cols}, view);
    }
}

