/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.weights;

import java.util.Arrays;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers for generating weight values according to a {@link WeightInit} scheme and
 * writing them into a caller-supplied parameter view.
 *
 * <p>The typical flow is:
 * <ol>
 *   <li>sample an array of the requested shape;</li>
 *   <li>flatten it in the requested order;</li>
 *   <li>copy it into the flat {@code paramView};</li>
 *   <li>return the view reshaped to the requested shape.</li>
 * </ol>
 */
public class WeightInitUtil {
    /** Flattening/reshaping order used by the overloads that do not take an explicit order. */
    public static final char DEFAULT_WEIGHT_INIT_ORDER = 'f';

    private WeightInitUtil() {
    }

    /**
     * Creates a new array of the given shape filled with random values between {@code min} and
     * {@code max}, drawn via {@code Nd4j.rand} using the global ND4J random generator.
     *
     * @param shape shape of the array to create
     * @param min   lower bound of the sampled values
     * @param max   upper bound of the sampled values
     * @return a newly allocated random array
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        // Explicit widening keeps the (int[], double, double, Random) overload selected.
        return Nd4j.rand(shape, (double) min, (double) max, Nd4j.getRandom());
    }

    /**
     * Initializes {@code paramView} with weights using {@link #DEFAULT_WEIGHT_INIT_ORDER}.
     *
     * @see #initWeights(double, double, int[], WeightInit, Distribution, char, INDArray)
     */
    public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
                    Distribution dist, INDArray paramView) {
        return initWeights(fanIn, fanOut, shape, initScheme, dist, DEFAULT_WEIGHT_INIT_ORDER, paramView);
    }

    /**
     * Generates weights for {@code initScheme}, copies them (flattened in {@code order}) into
     * {@code paramView}, and returns {@code paramView} reshaped to {@code shape}.
     *
     * <p>NOTE(review): schemes that divide by {@code fanIn} or {@code fanIn + fanOut} presumably
     * expect those values to be positive; non-positive values yield non-finite weights. The values
     * are not validated here.
     *
     * @param fanIn      fan-in used by the fan-based schemes
     * @param fanOut     fan-out used by the fan-based schemes
     * @param shape      shape of the weight array
     * @param initScheme initialization scheme to apply
     * @param dist       distribution to sample from; only read for {@link WeightInit#DISTRIBUTION}
     * @param order      order used to flatten the generated weights and to reshape the returned view
     * @param paramView  flat view that receives the weights; its length must equal the number of
     *                   generated elements
     * @return {@code paramView} reshaped to {@code shape} in {@code order}
     * @throws NullPointerException     if {@code shape}, {@code initScheme} or {@code paramView} is
     *                                  null, or {@code dist} is null for {@code DISTRIBUTION}
     * @throws IllegalArgumentException if {@code shape} has too few dimensions for the scheme, or
     *                                  the generated weights do not fit {@code paramView}
     * @throws IllegalStateException    if {@code initScheme} is not handled
     */
    public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
                    Distribution dist, char order, INDArray paramView) {
        Objects.requireNonNull(shape, "shape must not be null");
        Objects.requireNonNull(initScheme, "initScheme must not be null");
        // Fail before doing any (possibly large) random sampling.
        Objects.requireNonNull(paramView, "paramView must not be null");

        INDArray ret;
        switch (initScheme) {
            case DISTRIBUTION:
                Objects.requireNonNull(dist, "dist must not be null when using WeightInit.DISTRIBUTION");
                ret = dist.sample(shape);
                break;
            case NORMALIZED:
                requireRank(shape, 1, initScheme);
                // Nd4j.rand samples, shifted by -0.5 and divided by the first dimension.
                ret = Nd4j.rand(order, shape);
                ret.subi(0.5).divi(shape[0]);
                break;
            case RELU:
                ret = Nd4j.randn(order, shape).muli(Math.sqrt(2.0 / fanIn));
                break;
            case RELU_UNIFORM:
                ret = uniform(shape, Math.sqrt(6.0 / fanIn));
                break;
            case SIZE:
            case SIGMOID_UNIFORM:
                ret = uniform(shape, 4.0 * Math.sqrt(6.0 / (fanIn + fanOut)));
                break;
            case UNIFORM:
                ret = uniform(shape, 1.0 / Math.sqrt(fanIn));
                break;
            case XAVIER:
                ret = Nd4j.randn(order, shape).muli(Math.sqrt(2.0 / (fanIn + fanOut)));
                break;
            case VI:
            case XAVIER_UNIFORM:
                ret = uniform(shape, Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut));
                break;
            case XAVIER_FAN_IN:
                ret = Nd4j.randn(order, shape).divi(Math.sqrt(fanIn));
                break;
            case XAVIER_LEGACY:
                // Uses the first two dimensions of the shape rather than fanIn/fanOut.
                requireRank(shape, 2, initScheme);
                ret = Nd4j.randn(order, shape).divi(Math.sqrt(shape[0] + shape[1]));
                break;
            case ZERO:
                ret = Nd4j.create(shape, order);
                break;
            default:
                throw new IllegalStateException("Illegal weight init value: " + initScheme);
        }

        INDArray flat = Nd4j.toFlattened(order, new INDArray[] {ret});
        if (flat.length() != paramView.length()) {
            throw new IllegalArgumentException(
                            "ParamView length does not match initialized weights length (view length: "
                                            + paramView.length() + ", view shape: "
                                            + Arrays.toString(paramView.shape()) + "; flattened length: "
                                            + flat.length() + ")");
        }
        paramView.assign(flat);
        return paramView.reshape(order, shape);
    }

    /**
     * Reshapes {@code paramsView} to {@code shape} using {@link #DEFAULT_WEIGHT_INIT_ORDER}.
     *
     * @param shape      target shape
     * @param paramsView flat parameter view
     * @return the reshaped view
     */
    public static INDArray reshapeWeights(int[] shape, INDArray paramsView) {
        return reshapeWeights(shape, paramsView, DEFAULT_WEIGHT_INIT_ORDER);
    }

    /**
     * Reshapes {@code paramsView} to {@code shape} using the given order.
     *
     * @param shape           target shape
     * @param paramsView      flat parameter view
     * @param flatteningOrder order passed to {@code reshape}
     * @return the reshaped view
     */
    public static INDArray reshapeWeights(int[] shape, INDArray paramsView, char flatteningOrder) {
        return paramsView.reshape(flatteningOrder, shape);
    }

    /** Samples an array of {@code shape} from a uniform distribution between -bound and bound. */
    private static INDArray uniform(int[] shape, double bound) {
        Distribution distribution = Nd4j.getDistributions().createUniform(-bound, bound);
        return Nd4j.rand(shape, distribution);
    }

    /** Throws if {@code shape} has fewer than {@code minRank} dimensions, since {@code scheme} indexes them. */
    private static void requireRank(int[] shape, int minRank, WeightInit scheme) {
        if (shape.length < minRank) {
            throw new IllegalArgumentException("Weight init scheme " + scheme + " requires a shape with at least "
                            + minRank + " dimension(s), got " + Arrays.toString(shape));
        }
    }
}

