/*
 * Decompiled with CFR 0.152.
 */
package rbms;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import rbms.Mat;
import weka.core.Utils;

/**
 * A Restricted Boltzmann Machine with a single weight matrix {@code W} of shape
 * {@code (d + 1) x (h + 1)}, where {@code d} is the number of visible inputs and {@code h}
 * the number of hidden units. Index 0 on both sides is a bias unit: {@link #makeW} zeroes
 * row 0 and column 0 at initialisation, and {@link #prob_Z(Matrix)} / {@link #prob_X(Matrix)}
 * clamp column 0 of their outputs to 1.0.
 *
 * <p>Training uses a one-step contrastive-divergence estimate ({@link #epoch}) with momentum
 * and an L2 weight-decay term. Activations go through {@code Mat.sigma} (presumably the
 * logistic sigmoid; confirm in {@code Mat}).
 *
 * <p>Not thread-safe: training mutates {@code W}, {@code dW_} and the shared {@code Random}.
 */
public class RBM {
    /** L2 weight-decay coefficient relative to the learning rate ({@code COST = factor * LEARNING_RATE}). */
    private static final double WEIGHT_DECAY_FACTOR = 2.0E-4;

    protected double LEARNING_RATE = 0.1;
    protected double MOMENTUM = 0.1;
    /** Weight-decay penalty; kept in sync with {@code LEARNING_RATE} by {@link #setLearningRate}. */
    protected double COST = WEIGHT_DECAY_FACTOR * this.LEARNING_RATE;
    /** Number of training epochs. */
    protected int m_E = 1000;
    /** Number of hidden units, excluding the bias unit. */
    protected int m_H = 10;
    /** When true, {@link #train(double[][])} stops as soon as the reconstruction error increases. */
    private boolean m_V = false;
    /** Weight matrix ({@code (d + 1) x (h + 1)}); row/column 0 belong to the bias units. */
    protected Matrix W = null;
    /** The previous weight update, scaled by {@code MOMENTUM} and re-applied on the next update. */
    protected Matrix dW_ = null;
    protected Random m_R = new Random(0L);

    public RBM() {
    }

    /**
     * Creates an RBM configured from Weka-style command-line options.
     *
     * @param options see {@link #setOptions(String[])}
     * @throws Exception if a required option is missing or malformed
     */
    public RBM(String[] options) throws Exception {
        this.setOptions(options);
    }

    /**
     * Reads {@code -H} (hidden units), {@code -E} (epochs; negative enables early stopping),
     * {@code -r} (learning rate) and {@code -m} (momentum). All four are required.
     *
     * @param options option array as accepted by {@code weka.core.Utils.getOption}
     * @throws IllegalArgumentException if an option is missing or cannot be parsed; the
     *         underlying exception is attached as the cause
     */
    public void setOptions(String[] options) throws Exception {
        try {
            this.setH(Integer.parseInt(Utils.getOption('H', options)));
            this.setE(Integer.parseInt(Utils.getOption('E', options)));
            this.setLearningRate(Double.parseDouble(Utils.getOption('r', options)));
            this.setMomentum(Double.parseDouble(Utils.getOption('m', options)));
        }
        catch (Exception e) {
            // Propagate rather than System.exit(1): library code must not terminate the host JVM.
            throw new IllegalArgumentException("Missing option!", e);
        }
    }

    /**
     * Returns the current configuration in the format accepted by {@link #setOptions(String[])}.
     * The epoch count is emitted as a negative number when early stopping is enabled, so that the
     * setting survives a round trip.
     *
     * @return the options {@code -r, -m, -E, -H} with their values
     */
    public String[] getOptions() throws Exception {
        ArrayList<String> result = new ArrayList<>();
        result.add("-r");
        result.add(String.valueOf(this.LEARNING_RATE));
        result.add("-m");
        result.add(String.valueOf(this.MOMENTUM));
        result.add("-E");
        result.add(String.valueOf(this.m_V ? -this.getE() : this.getE()));
        result.add("-H");
        result.add(String.valueOf(this.getH()));
        return result.toArray(new String[result.size()]);
    }

    /**
     * Hidden-unit probabilities for one visible vector.
     *
     * @param x_ visible vector without the bias entry
     * @return {@code Mat.sigma([1 | x_] W)} with the bias entry removed
     */
    public double[] prob_z(double[] x_) {
        Matrix x = new Matrix(Mat.addBias(x_), 1);
        double[] z = Mat.sigma(x.times(this.W).getArray()[0]);
        return Mat.removeBias(z);
    }

    /**
     * Hidden-unit probabilities for a batch of visible rows given without bias.
     *
     * @param X_ one visible vector per row, without the bias column
     * @return hidden probabilities per row, with the bias column removed
     */
    public double[][] prob_Z(double[][] X_) {
        Matrix X = new Matrix(Mat.addBias(X_));
        return Mat.removeBias(this.prob_Z(X).getArray());
    }

    /**
     * Hidden-unit probabilities for bias-augmented visible rows.
     *
     * @param X visible rows that already include the bias column
     * @return {@code Mat.sigma(X W)} with column 0 (hidden bias unit) clamped to 1.0
     */
    public Matrix prob_Z(Matrix X) {
        Matrix P_Z = Mat.sigma(X.times(this.W));
        Mat.fillCol(P_Z.getArray(), 0, 1.0);
        return P_Z;
    }

    /**
     * Deterministic hidden representation: hidden probabilities thresholded at 0.5.
     *
     * @param X_ visible rows without the bias column
     * @return thresholded hidden activations (bias column removed)
     */
    public double[][] propUp(double[][] X_) {
        return Mat.threshold(this.prob_Z(X_), 0.5);
    }

    /** Samples hidden states for bias-augmented visible rows using {@code m_R}. */
    public Matrix sample_Z(Matrix X) {
        Matrix P_Z = this.prob_Z(X);
        return Mat.sample(P_Z, this.m_R);
    }

    /** Samples a hidden vector (without bias) for one visible vector using {@code m_R}. */
    public double[] sample_z(double[] x_) {
        double[] p = this.prob_z(x_);
        return Mat.sample(p, this.m_R);
    }

    /** Samples a visible vector (without bias) for one hidden vector using {@code m_R}. */
    public double[] sample_x(double[] z_) {
        double[] p_x = this.prob_x(z_);
        return Mat.sample(p_x, this.m_R);
    }

    /** Samples visible states for bias-augmented hidden rows using {@code m_R}. */
    public Matrix sample_X(Matrix Z) {
        Matrix P_X = this.prob_X(Z);
        return Mat.sample(P_X, this.m_R);
    }

    /**
     * Visible-unit probabilities for one hidden vector.
     *
     * @param z_ hidden vector without the bias entry
     * @return {@code Mat.sigma([1 | z_] W^T)} with the bias entry removed
     */
    public double[] prob_x(double[] z_) {
        Matrix z = new Matrix(Mat.addBias(z_), 1);
        double[] x = Mat.sigma(z.times(this.W.transpose()).getArray()[0]);
        return Mat.removeBias(x);
    }

    /**
     * Visible-unit probabilities for bias-augmented hidden rows.
     *
     * @param Z hidden rows that include the bias column
     * @return {@code Mat.sigma(Z W^T)} with column 0 (visible bias unit) clamped to 1.0
     */
    public Matrix prob_X(Matrix Z) {
        Matrix X = new Matrix(Mat.sigma(Z.times(this.W.transpose()).getArray()));
        Mat.fillCol(X.getArray(), 0, 1.0);
        return X;
    }

    /**
     * Builds an initial weight matrix of shape {@code (d + 1) x (h + 1)} from {@code 0.2 * randn}
     * with the bias row and bias column set to zero.
     *
     * @param d number of visible inputs (excluding bias)
     * @param h number of hidden units (excluding bias)
     * @param r random source
     * @return the new weight matrix
     */
    public static Matrix makeW(int d, int h, Random r) {
        double[][] W_ = Mat.multiply(Mat.randn(d + 1, h + 1, r), 0.2);
        Mat.fillRow(W_, 0, 0.0);
        Mat.fillCol(W_, 0, 0.0);
        return new Matrix(W_);
    }

    /** Same as {@link #makeW(int, int, Random)} using this model's random source. */
    protected Matrix makeW(int d, int h) {
        return RBM.makeW(d, h, this.m_R);
    }

    /** Initialises weights for data whose rows are visible vectors without the bias column. */
    private void initWeights(double[][] X_) {
        if (X_ == null || X_.length == 0) {
            throw new IllegalArgumentException("cannot initialise weights from empty training data");
        }
        this.initWeights(X_[0].length, this.m_H);
    }

    /** Creates fresh weights and resets the momentum buffer to zeros of the same shape. */
    private void initWeights(int d, int h) {
        this.W = this.makeW(d, h);
        this.dW_ = new Matrix(this.W.getRowDimension(), this.W.getColumnDimension());
    }

    /**
     * Initialises weights for {@code d} visible inputs and the configured number of hidden units.
     *
     * @param d number of visible inputs (excluding bias)
     */
    public void initWeights(int d) {
        this.initWeights(d, this.m_H);
    }

    /**
     * Performs one unscaled gradient step on bias-augmented data. Equivalent to
     * {@code update(X, 1.0)}.
     *
     * @param X visible rows including the bias column
     */
    public void update(Matrix X) {
        this.update(X, 1.0);
    }

    /**
     * Performs one gradient step on bias-augmented data:
     * {@code dW = s * LEARNING_RATE * (CD - COST * W)}, then
     * {@code W += dW + MOMENTUM * dW_previous}.
     *
     * @param X visible rows including the bias column
     * @param s extra scale applied to this step (e.g. {@code 1 / numBatches})
     */
    public void update(Matrix X, double s) {
        Matrix CD = this.epoch(X);
        Matrix dW = CD.minusEquals(this.W.times(this.COST)).timesEquals(this.LEARNING_RATE);
        dW.timesEquals(s);
        // dW_ is overwritten below, so scaling it in place is safe.
        this.W.plusEquals(dW.plus(this.dW_.timesEquals(this.MOMENTUM)));
        this.dW_ = dW;
    }

    /** One gradient step on visible rows given without the bias column. */
    public void update(double[][] X_) {
        Matrix X = new Matrix(Mat.addBias(X_));
        this.update(X);
    }

    /** One gradient step on a single visible vector given without the bias entry. */
    public void update(double[] x_) {
        this.update(new double[][]{x_});
    }

    /** One gradient step on a single visible vector (without bias), scaled by {@code s}. */
    public void update(double[] x_, double s) {
        Matrix X = new Matrix(Mat.addBias(new double[][]{x_}));
        this.update(X, s);
    }

    /**
     * Full-batch training for {@code m_E} epochs after re-initialising the weights. When early
     * stopping is enabled (negative {@code -E}), training stops at the first epoch whose
     * reconstruction error exceeds the previous one.
     *
     * @param X_ training rows without the bias column
     * @return the last recorded reconstruction error, or {@code Double.MAX_VALUE} when early
     *         stopping is disabled (no error is computed in that case)
     */
    public double train(double[][] X_) throws Exception {
        this.initWeights(X_);
        Matrix X = new Matrix(Mat.addBias(X_));
        double _error = Double.MAX_VALUE;
        for (int e = 0; e < this.m_E; ++e) {
            if (this.m_V) {
                double err_now = this.calculateError(X);
                if (_error < err_now) {
                    System.out.println("broken out @" + e);
                    break;
                }
                _error = err_now;
            }
            this.update(X);
        }
        return _error;
    }

    /**
     * Mini-batch training: each epoch visits every consecutive batch once, scaling each step
     * by {@code 1 / numBatches}. If {@code batchSize} equals the number of rows, this delegates
     * to {@link #train(double[][])}.
     *
     * @param X_ training rows without the bias column
     * @param batchSize rows per batch; must be positive
     * @return {@code 1.0} (no error is computed), or the result of {@link #train(double[][])}
     *         in the full-batch case
     * @throws IllegalArgumentException if {@code batchSize <= 0} or {@code X_} is empty
     */
    public double train(double[][] X_, int batchSize) throws Exception {
        checkBatchSize(batchSize);
        if (batchSize == X_.length) {
            // Pass the ORIGINAL bias-free rows: train(double[][]) adds the bias column and
            // initialises W itself. Passing already-augmented data would add the bias twice.
            return this.train(X_);
        }
        this.initWeights(X_);
        Matrix[] batches = makeBatches(Mat.addBias(X_), batchSize);
        double scale = 1.0 / (double) batches.length;
        for (int e = 0; e < this.m_E; ++e) {
            for (Matrix X : batches) {
                this.update(X, scale);
            }
        }
        return 1.0;
    }

    /**
     * Stochastic mini-batch training: each epoch performs {@code numBatches} unscaled updates,
     * each on a batch drawn uniformly with replacement using {@code r}.
     *
     * @param X_ training rows without the bias column
     * @param batchSize rows per batch; must be positive
     * @param r random source used to pick batches
     * @return {@code 1.0} (no error is computed)
     * @throws IllegalArgumentException if {@code batchSize <= 0} or {@code X_} is empty
     */
    public double train(double[][] X_, int batchSize, Random r) throws Exception {
        checkBatchSize(batchSize);
        this.initWeights(X_);
        Matrix[] batches = makeBatches(Mat.addBias(X_), batchSize);
        for (int e = 0; e < this.m_E; ++e) {
            for (int i = 0; i < batches.length; ++i) {
                this.update(batches[r.nextInt(batches.length)]);
            }
        }
        return 1.0;
    }

    /** Rejects non-positive batch sizes, which would otherwise loop forever or mis-size arrays. */
    private static void checkBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
    }

    /** Splits rows into consecutive batches of at most {@code batchSize} rows (last may be smaller). */
    private static Matrix[] makeBatches(double[][] rows, int batchSize) {
        int total = rows.length;
        // Ceiling division written so that it cannot overflow for large batchSize.
        int count = total / batchSize + (total % batchSize == 0 ? 0 : 1);
        Matrix[] batches = new Matrix[count];
        for (int i = 0; i < count; ++i) {
            int from = i * batchSize;
            int to = from + Math.min(batchSize, total - from);
            batches[i] = new Matrix(Arrays.copyOfRange(rows, from, to));
        }
        return batches;
    }

    /**
     * Mean squared error between bias-augmented data and its mean-field reconstruction
     * {@code prob_X(prob_Z(X))}.
     *
     * @param X visible rows including the bias column
     * @return the value of {@code Mat.meanSquaredError} on the two arrays
     */
    public double calculateError(Matrix X) {
        Matrix Z_up = this.prob_Z(X);
        Matrix X_down = this.prob_X(Z_up);
        return Mat.meanSquaredError(X.getArray(), X_down.getArray());
    }

    /**
     * One-step contrastive-divergence gradient using probabilities (no sampling):
     * {@code (X0^T Z0 - X1^T Z1) / N}, where {@code Z0 = prob_Z(X0)}, {@code X1 = prob_X(Z0)}
     * and {@code Z1 = prob_Z(X1)}.
     *
     * <p>NOTE(review): also prints the reconstruction MSE to stdout on every call.
     *
     * @param X_0 visible rows including the bias column
     * @return a new gradient matrix with the same shape as {@code W}
     */
    public Matrix epoch(Matrix X_0) {
        int N = X_0.getArray().length;
        Matrix Z_0 = this.prob_Z(X_0);
        Matrix E_pos = X_0.transpose().times(Z_0);
        Matrix X_1 = this.prob_X(Z_0);
        Matrix pZ_1 = this.prob_Z(X_1);
        Matrix E_neg = X_1.transpose().times(pZ_1);
        double _Err = Mat.meanSquaredError(X_0.getArray(), X_1.getArray());
        System.out.println("" + _Err);
        Matrix CD = E_pos.minusEquals(E_neg).times(1.0 / (double) N);
        return CD;
    }

    /**
     * Like {@link #epoch}, but the hidden states {@code Z0} and the reconstruction {@code X1}
     * are sampled with {@code m_R}; only {@code Z1} uses probabilities.
     *
     * <p>NOTE(review): also prints the reconstruction MSE to stdout on every call.
     *
     * @param X_0 visible rows including the bias column
     * @return a new gradient matrix with the same shape as {@code W}
     */
    public Matrix sample_epoch(Matrix X_0) {
        int N = X_0.getArray().length;
        Matrix Z_0 = this.sample_Z(X_0);
        Matrix E_pos = X_0.transpose().times(Z_0);
        Matrix X_1 = this.sample_X(Z_0);
        Matrix pZ_1 = this.prob_Z(X_1);
        Matrix E_neg = X_1.transpose().times(pZ_1);
        double _Err = Mat.meanSquaredError(X_0.getArray(), X_1.getArray());
        System.out.println("" + _Err);
        Matrix CD = E_pos.minusEquals(E_neg).times(1.0 / (double) N);
        return CD;
    }

    /** Sets the number of hidden units (excluding bias); takes effect at the next weight initialisation. */
    public void setH(int h) {
        this.m_H = h;
    }

    public int getH() {
        return this.m_H;
    }

    /**
     * Sets the number of epochs. A negative value {@code -n} means {@code n} epochs with early
     * stopping in {@link #train(double[][])}; a non-negative value disables early stopping.
     *
     * @param n epoch count, negative to enable early stopping
     */
    public void setE(int n) {
        // Assign the flag unconditionally so a later non-negative call turns early stopping off.
        this.m_V = n < 0;
        this.m_E = n < 0 ? -n : n;
    }

    /** @return the (non-negative) number of epochs */
    public int getE() {
        return this.m_E;
    }

    /** Sets the learning rate and rescales the weight-decay cost {@code COST} accordingly. */
    public void setLearningRate(double r) {
        this.LEARNING_RATE = r;
        this.COST = WEIGHT_DECAY_FACTOR * this.LEARNING_RATE;
    }

    public double getLearningRate() {
        return this.LEARNING_RATE;
    }

    public void setMomentum(double m) {
        this.MOMENTUM = m;
    }

    public double getMomentum() {
        return this.MOMENTUM;
    }

    /** Replaces the random source used for weight initialisation and sampling. */
    public void setSeed(int seed) {
        this.m_R = new Random(seed);
    }

    /** @return a one-element array holding this model's weight matrix */
    public Matrix[] getWs() {
        return new Matrix[]{this.W};
    }

    /** @return the live (not copied) weight matrix, or {@code null} before initialisation */
    public Matrix getW() {
        return this.W;
    }

    @Override
    public String toString() {
        return Mat.toString(this.getW());
    }

    public static void main(String[] argv) throws Exception {
    }
}

