/*
 * Decompiled with CFR 0.152.
 */
package rbms;

import Jama.Matrix;
import java.util.Arrays;
import rbms.RBM;

/**
 * A stack of {@link RBM} layers trained greedily, one layer at a time: layer {@code i}
 * is trained on the output of layer {@code i - 1}'s {@code prob_Z}
 * (see {@link #train(double[][], int)}).
 *
 * <p>The per-layer hidden sizes live in {@link #h}; the trained layers live in
 * {@link #rbm}, which stays {@code null} until a call to {@code train} completes.
 */
public class DBM
extends RBM {
    /** Trained layers, bottom first; {@code null} until training has completed. */
    protected RBM[] rbm = null;
    /** Hidden size for each layer, bottom first; set via one of the {@code setH} overloads. */
    protected int[] h = null;

    /**
     * Creates a DBM configured from the given option strings.
     *
     * @param options options forwarded unchanged to {@code RBM.setOptions}
     * @throws Exception propagated from {@code RBM.setOptions}
     */
    public DBM(String[] options) throws Exception {
        super.setOptions(options);
    }

    /**
     * Returns the trained layers, bottom first.
     *
     * @return the internal layer array (not a copy), or {@code null} before training
     */
    public RBM[] getRBMs() {
        return this.rbm;
    }

    /**
     * Feeds {@code z} through every trained layer's {@code prob_z} in order and returns
     * the output of the top layer.
     *
     * @param z input to the bottom layer
     * @return the top layer's output, or {@code null} if the model has not been trained
     */
    @Override
    public double[] prob_z(double[] z) {
        if (this.rbm == null) {
            return null;
        }
        // Iterate over the trained layers themselves rather than h: h may have been
        // changed via setH() after training and no longer match rbm.length.
        for (RBM layer : this.rbm) {
            z = layer.prob_z(z);
        }
        return z;
    }

    /**
     * Feeds {@code X_} through every trained layer's {@code prob_Z} in order and returns
     * the output of the top layer.
     *
     * @param X_ input to the bottom layer
     * @return the top layer's output, or {@code null} if the model has not been trained
     */
    @Override
    public double[][] prob_Z(double[][] X_) {
        if (this.rbm == null) {
            return null;
        }
        for (RBM layer : this.rbm) {
            X_ = layer.prob_Z(X_);
        }
        return X_;
    }

    /**
     * Sets the hidden size of each layer explicitly, bottom first. The array is stored
     * by reference; it takes effect on the next call to {@code train}.
     *
     * @param h per-layer hidden sizes
     */
    public void setH(int[] h) {
        this.h = h;
    }

    /**
     * Configures {@code N} layers: the first {@code N - 1} have size {@code H} and the
     * last has size {@code L}.
     *
     * @param H size of every layer except the last
     * @param L size of the last layer
     * @param N number of layers; must be at least 1
     * @throws IllegalArgumentException if {@code N < 1}
     */
    public void setH(int H, int L, int N) {
        if (N < 1) {
            throw new IllegalArgumentException("N must be at least 1, got " + N);
        }
        int[] sizes = new int[N];
        Arrays.fill(sizes, 0, N - 1, H);
        sizes[N - 1] = L;
        this.h = sizes;
    }

    /**
     * Configures {@code N} layers, all of size {@code H}.
     *
     * @param H size of every layer
     * @param N number of layers; must be non-negative
     * @throws IllegalArgumentException if {@code N < 0}
     */
    public void setH(int H, int N) {
        if (N < 0) {
            throw new IllegalArgumentException("N must be non-negative, got " + N);
        }
        int[] sizes = new int[N];
        Arrays.fill(sizes, H);
        this.h = sizes;
    }

    /**
     * Configures two layers, both of size {@code H}.
     *
     * @param H size of each of the two layers
     */
    @Override
    public void setH(int H) {
        this.setH(H, 2);
    }

    /**
     * Collects each trained layer's weight matrix ({@code RBM.getW()}), bottom first.
     *
     * @return one matrix per layer, or {@code null} if the model has not been trained
     *         (consistent with {@link #prob_z(double[])})
     */
    @Override
    public Matrix[] getWs() {
        if (this.rbm == null) {
            return null;
        }
        Matrix[] W = new Matrix[this.rbm.length];
        for (int i = 0; i < W.length; ++i) {
            W[i] = this.rbm[i].getW();
        }
        return W;
    }

    /**
     * Trains every layer without a batch size; equivalent to {@code train(X_, 0)}.
     *
     * @param X_ training input for the bottom layer
     * @return always {@code 1.0}
     * @throws Exception propagated from a layer's training
     */
    @Override
    public double train(double[][] X_) throws Exception {
        return this.train(X_, 0);
    }

    /**
     * Greedily trains a fresh stack of layers, one per entry of {@link #h}. Each layer is
     * an {@code RBM} built from this model's options, trained on the current input, and
     * that input is then replaced by the layer's {@code prob_Z} output for the next layer.
     * The new stack replaces {@link #rbm} only once every layer has trained successfully.
     *
     * @param X_        training input for the bottom layer
     * @param batchSize passed to each layer's {@code train(X_, batchSize)}; {@code 0} means
     *                  call {@code train(X_)} without a batch size instead
     * @return always {@code 1.0}; per-layer training results are not aggregated
     * @throws IllegalStateException if no layer sizes have been set
     * @throws Exception             propagated from a layer's training
     */
    @Override
    public double train(double[][] X_, int batchSize) throws Exception {
        if (this.h == null) {
            throw new IllegalStateException("Layer sizes not set; call setH(...) before train(...)");
        }
        int N = this.h.length;
        // Build into a local array so a failure part-way through does not leave
        // this.rbm holding a partially populated (null-containing) stack.
        RBM[] layers = new RBM[N];
        for (int i = 0; i < N; ++i) {
            RBM layer = new RBM(this.getOptions());
            layer.setH(this.h[i]);
            if (batchSize == 0) {
                layer.train(X_);
            } else {
                layer.train(X_, batchSize);
            }
            layers[i] = layer;
            // The top layer's output is not needed by anything, so skip computing it.
            if (i + 1 < N) {
                X_ = layer.prob_Z(X_);
            }
        }
        this.rbm = layers;
        return 1.0;
    }

    /**
     * Updates each trained layer with {@code RBM.update(Matrix)}, feeding the layer's
     * {@code prob_Z} output to the next layer.
     *
     * @param X input for the bottom layer
     * @throws IllegalStateException if the model has not been trained, or if propagating
     *                               through a layer fails (the cause is attached)
     */
    @Override
    public void update(Matrix X) {
        RBM[] layers = this.trainedLayers();
        for (int i = 0; i < layers.length; ++i) {
            layers[i].update(X);
            if (i + 1 < layers.length) {
                X = propagate(layers[i], i, X);
            }
        }
    }

    /**
     * Updates each trained layer with {@code RBM.update(Matrix, double)}, passing
     * {@code s} unchanged to every layer and feeding the layer's {@code prob_Z} output to
     * the next layer.
     *
     * @param X input for the bottom layer
     * @param s forwarded unchanged to each layer's {@code update}
     * @throws IllegalStateException if the model has not been trained, or if propagating
     *                               through a layer fails (the cause is attached)
     */
    @Override
    public void update(Matrix X, double s) {
        RBM[] layers = this.trainedLayers();
        for (int i = 0; i < layers.length; ++i) {
            layers[i].update(X, s);
            if (i + 1 < layers.length) {
                X = propagate(layers[i], i, X);
            }
        }
    }

    /**
     * Updates each trained layer with {@code RBM.update(double[][])}, feeding the layer's
     * {@code prob_Z} output to the next layer.
     *
     * @param X_ input for the bottom layer
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public void update(double[][] X_) {
        RBM[] layers = this.trainedLayers();
        for (int i = 0; i < layers.length; ++i) {
            layers[i].update(X_);
            if (i + 1 < layers.length) {
                // prob_Z(double[][]) declares no checked exception (prob_Z above calls it
                // without a try), so any failure propagates directly instead of being
                // swallowed and the stale input passed to the next layer.
                X_ = layers[i].prob_Z(X_);
            }
        }
    }

    /** Returns the trained layers, or throws if {@code train} has not completed yet. */
    private RBM[] trainedLayers() {
        if (this.rbm == null) {
            throw new IllegalStateException("DBM has not been trained; call train(...) first");
        }
        return this.rbm;
    }

    /** Runs {@code X} through {@code layer.prob_Z}, wrapping any failure with the layer index. */
    private static Matrix propagate(RBM layer, int index, Matrix X) {
        try {
            return layer.prob_Z(X);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to propagate input through layer " + index, e);
        }
    }
}

