/*
 * Decompiled with CFR 0.152.
 */
package meka.classifiers.multilabel;

import Jama.Matrix;
import java.util.Random;
import meka.classifiers.multilabel.MultilabelClassifier;
import meka.classifiers.multilabel.NN.AbstractNeuralNet;
import meka.core.M;
import meka.core.MLUtils;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Multi-label classifier built on a fully-connected feed-forward network trained by per-instance
 * back-propagation with a momentum-style term.
 *
 * <p>Every layer, including the output layer, applies {@code M.sigma} (presumably the logistic
 * sigmoid). The input and each hidden layer's activations are extended with a bias column via
 * {@code M.addBias} before being multiplied by the next weight matrix. {@link #buildClassifier}
 * uses a single hidden layer of {@code m_H} units and trains for {@code m_E} epochs.
 */
public class BPNN
extends AbstractNeuralNet {
    /** Weight matrices; {@code W[i]} maps layer {@code i} (bias column included) to layer {@code i + 1}. */
    public Matrix[] W = null;
    /** Random generator, seeded with 0 on every initialisation so weight initialisation is reproducible. */
    protected Random r = null;
    /** Weight updates applied in the previous back-propagation step; scaled by {@code m_M} in the next step. */
    protected Matrix[] dW_ = null;

    /** Creates a network that, by default, trains for 100 epochs. */
    public BPNN() {
        this.m_E = 100;
    }

    /**
     * Trains the network on {@code D}: initialises one hidden layer of {@code m_H} units and runs
     * {@code m_E} training epochs (a negative value enables early stopping, see {@link #train}).
     *
     * @param D training instances
     * @throws Exception if {@code D} fails the capability test or training fails
     */
    @Override
    public void buildClassifier(Instances D) throws Exception {
        this.testCapabilities(D);
        double[][] X_ = MLUtils.getXfromD(D);
        double[][] Y_ = MLUtils.getYfromD(D);
        this.init(X_, Y_, new int[]{this.m_H});
        this.train(X_, Y_, this.m_E);
    }

    /**
     * Returns the output-layer activations for one instance, one value per label.
     *
     * @param xy the instance to classify (its feature vector is extracted via {@code MLUtils})
     * @return output activations of the network
     */
    @Override
    public double[] distributionForInstance(Instance xy) throws Exception {
        double[] x = MLUtils.getxfromInstance(xy);
        return this.popy(x);
    }

    /**
     * Installs the given weight matrices, re-seeds {@link #r} with 0, and resets the previous-update
     * buffers {@link #dW_} to all-zero matrices of matching size.
     *
     * @param X_ unused; kept for signature compatibility
     * @param Y_ unused; kept for signature compatibility
     * @param W  weight matrices, one per layer transition; stored by reference, not copied
     */
    public void init(double[][] X_, double[][] Y_, Matrix[] W) throws Exception {
        this.r = new Random(0L);
        this.W = W;
        this.dW_ = new Matrix[W.length];
        for (int i = 0; i < this.dW_.length; ++i) {
            this.dW_[i] = new Matrix(W[i].getRowDimension(), W[i].getColumnDimension(), 0.0);
        }
    }

    /**
     * Creates randomly initialised weights for a network with the given hidden-layer sizes and
     * installs them via {@link #init(double[][], double[][], Matrix[])}.
     *
     * <p>Layer {@code n} gets a {@code (d + 1) x H[n]} matrix, where {@code d} is the width of the
     * previous layer and the extra row multiplies the bias column. The output layer gets a
     * {@code (d + 1) x L} matrix, where {@code L = Y_[0].length}. Entries are uniform in
     * [-0.05, 0.05) and drawn from a generator seeded with 0, so initialisation is reproducible.
     *
     * @param X_ input matrix; only its column count is used
     * @param Y_ label matrix; only its column count is used
     * @param H  number of units in each hidden layer
     */
    public void init(double[][] X_, double[][] Y_, int[] H) throws Exception {
        int L = Y_[0].length;
        int d = X_[0].length;
        if (this.getDebug()) {
            System.out.println("Initializing " + H.length + " hidden Layers ...");
            System.out.println("d = " + d);
            System.out.println("L = " + L);
        }
        // Seed before drawing so that the weights (not just later uses of r) are reproducible.
        this.r = new Random(0L);
        Matrix[] W = new Matrix[H.length + 1];
        for (int n = 0; n < H.length; ++n) {
            int h = H[n];
            if (this.getDebug()) {
                System.out.println("W[" + n + "] = " + (d + 1) + " x " + h);
            }
            W[n] = this.newRandomWeights(d + 1, h);
            d = h;
        }
        W[H.length] = this.newRandomWeights(d + 1, L);
        if (this.getDebug()) {
            System.out.println("W[" + H.length + "] = " + (d + 1) + " x " + L);
        }
        this.init(X_, Y_, W);
    }

    /** Returns a {@code rows x cols} matrix with entries drawn uniformly from [-0.05, 0.05) using {@link #r}. */
    private Matrix newRandomWeights(int rows, int cols) {
        Matrix w = new Matrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                w.set(i, j, (this.r.nextDouble() - 0.5) * 0.1);
            }
        }
        return w;
    }

    /**
     * Trains for {@code m_E} epochs.
     *
     * @return the accumulated error of the last epoch run (see {@link #update})
     */
    public double train(double[][] X_, double[][] Y_) throws Exception {
        return this.train(X_, Y_, this.m_E);
    }

    /**
     * Runs {@code |I|} training epochs over the data.
     *
     * <p>If {@code I} is negative, training stops after the first epoch whose error exceeds the
     * previous epoch's error. The weights from that worse epoch have already been applied.
     *
     * @param X_ input rows
     * @param Y_ target rows, aligned with {@code X_}
     * @param I  number of epochs; a negative value enables early stopping
     * @return the accumulated error of the last epoch run (0.0 if no epoch ran)
     */
    public double train(double[][] X_, double[][] Y_, int I) throws Exception {
        if (this.getDebug()) {
            System.out.println("BPNN train; For " + I + " epochs ...");
        }
        boolean breakEarly = I < 0;
        int epochs = Math.abs(I);
        double previousE = Double.MAX_VALUE;
        double E = 0.0;
        for (int e = 0; e < epochs; ++e) {
            E = this.update(X_, Y_);
            if (breakEarly && E > previousE) {
                if (this.getDebug()) {
                    System.out.println(" early stopped at epoch " + e + " ... ");
                }
                break;
            }
            previousE = E;
        }
        if (this.getDebug()) {
            System.out.println("Done.");
        }
        return E;
    }

    /**
     * Runs one epoch: back-propagates each row individually, updating the weights after every row.
     *
     * @return the sum over rows of the per-row error returned by {@link #backPropagate}
     */
    public double update(double[][] X_, double[][] Y_) throws Exception {
        int N = X_.length;
        double E = 0.0;
        for (int i = 0; i < N; ++i) {
            E += this.backPropagate(new double[][]{X_[i]}, new double[][]{Y_[i]});
        }
        return E;
    }

    /** Returns the output activations for a single input row. */
    public double[] popy(double[] x_) {
        return this.popY(new double[][]{x_})[0];
    }

    /** Returns the output activations for each input row, one result row per input row. */
    public double[][] popY(double[][] X_) {
        Matrix[] Z = this.forwardPass(X_);
        return Z[Z.length - 1].getArray();
    }

    /**
     * Propagates the inputs through the network.
     *
     * @param X_ input rows (without bias)
     * @return activations per layer: {@code Z[0]} is the input with a bias column, {@code Z[1..nW-1]}
     *         are hidden activations with a bias column, and {@code Z[nW]} is the output (no bias)
     */
    public Matrix[] forwardPass(double[][] X_) {
        int nW = this.W.length;
        Matrix[] Z = new Matrix[nW + 1];
        Z[0] = new Matrix(M.addBias(X_));
        // Hidden layers only; the output layer is computed separately below without a bias column.
        for (int i = 1; i < nW; ++i) {
            Matrix A_z = Z[i - 1].times(this.W[i - 1]);
            Z[i] = M.addBias(M.sigma(A_z));
        }
        Matrix A_y = Z[nW - 1].times(this.W[nW - 1]);
        Z[nW] = M.sigma(A_y);
        return Z;
    }

    /**
     * Performs one back-propagation step on the given rows and applies the weight update in place.
     *
     * <p>The update for each layer is {@code m_R * Z[i]^T * dZ[i+1] + m_M * dW_[i]}, i.e. a scaled
     * gradient term plus a scaled copy of the previous update.
     *
     * @param X_ input rows
     * @param Y_ target rows, aligned with {@code X_}
     * @return the Frobenius norm of the output error {@code Y - output} (the square root of the
     *         summed squared errors, not the sum itself)
     */
    public double backPropagate(double[][] X_, double[][] Y_) throws Exception {
        int nW = this.W.length;
        Matrix T = new Matrix(Y_);
        Matrix[] Z = this.forwardPass(X_);
        Matrix[] dZ = new Matrix[nW + 1];
        // Output-layer delta: error scaled element-wise by M.dsigma of the output activations.
        Matrix E_y = T.minus(Z[nW]);
        dZ[nW] = M.dsigma(Z[nW]).arrayTimes(E_y);
        // Hidden-layer deltas, propagated backwards. The bias column is dropped because the
        // weights feeding layer i have no column for it.
        for (int i = nW - 1; i > 0; --i) {
            Matrix E = dZ[i + 1].times(this.W[i].transpose());
            Matrix delta = M.dsigma(Z[i]).arrayTimes(E);
            dZ[i] = new Matrix(M.removeBias(delta.getArray()));
        }
        // All updates are computed from the pre-update weights before any weight is modified.
        Matrix[] dW = new Matrix[nW];
        for (int i = 0; i < nW; ++i) {
            dW[i] = Z[i].transpose().times(this.m_R).times(dZ[i + 1]).plus(this.dW_[i].times(this.m_M));
        }
        for (int i = 0; i < nW; ++i) {
            this.W[i].plusEquals(dW[i]);
        }
        this.dW_ = dW;
        return E_y.normF();
    }

    /** Command-line entry point; delegates to the standard MEKA evaluation routine. */
    public static void main(String[] args) throws Exception {
        MultilabelClassifier.evaluation(new BPNN(), args);
    }
}

