/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import java.util.ArrayList;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.supervised.attribute.PLSFilter;
import weka.filters.unsupervised.attribute.Center;
import weka.filters.unsupervised.attribute.Standardize;

public class PLSFilterExtended
extends PLSFilter {
    private static final long serialVersionUID = 3908648580492805218L;
    protected double[] m_Means = new double[0];
    protected double[] m_StdDevs = new double[0];

    public Matrix getxWeights() {
        return this.m_PLS1_W;
    }

    public Matrix getX(Instances instances) {
        int numberXattributes = instances.classIndex();
        double[][] x = new double[instances.numInstances()][numberXattributes];
        for (int i = 0; i < instances.numInstances(); ++i) {
            for (int j = 0; j < numberXattributes; ++j) {
                x[i][j] = instances.instance(i).value(j);
            }
        }
        Matrix result = new Matrix(x);
        return result;
    }

    public Matrix getY(Instances instances) {
        int numberYattributes = instances.numAttributes() - instances.classIndex();
        double[][] y = new double[instances.numInstances()][numberYattributes];
        for (int i = 0; i < instances.numInstances(); ++i) {
            for (int j = 0; j < numberYattributes; ++j) {
                y[i][j] = instances.instance(i).value(instances.classIndex() + j);
            }
        }
        Matrix result = new Matrix(y);
        return result;
    }

    public Matrix getbHat() {
        return this.m_PLS1_b_hat;
    }

    public Matrix getRegVector() {
        return this.m_PLS1_RegVector;
    }

    public double[] means() {
        return this.m_Means;
    }

    public double[] stdDevs() {
        return this.m_StdDevs;
    }

    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        int i;
        int numberPredictedY = inputFormat.numAttributes() - inputFormat.classIndex();
        int numberAttributes = this.getNumComponents() + numberPredictedY;
        ArrayList<Attribute> atts = new ArrayList<Attribute>(numberAttributes);
        String prefix = this.getAlgorithm().getSelectedTag().getReadable();
        for (i = 0; i < this.getNumComponents(); ++i) {
            atts.add(new Attribute(prefix + "_" + (i + 1)));
        }
        for (i = 0; i < numberPredictedY; ++i) {
            atts.add(new Attribute("Class" + (i + 1)));
        }
        Instances result = new Instances(prefix, atts, 0);
        result.setClassIndex(this.getNumComponents());
        return result;
    }

    protected Instances toInstances(Instances header, Matrix x, Matrix y) {
        Instances result = new Instances(header, 0);
        int rows = x.getRowDimension();
        int colsX = x.getColumnDimension();
        int colsY = y.getColumnDimension();
        for (int i = 0; i < rows; ++i) {
            int k = -1;
            int l = -1;
            double[] values = new double[colsX + colsY];
            for (int n = 0; n < values.length; ++n) {
                values[n] = n < colsX ? x.get(i, ++k) : y.get(i, ++l);
            }
            result.add((Instance)new DenseInstance(1.0, values));
        }
        return result;
    }

    protected Instances process(Instances instances) throws Exception {
        Instances result = null;
        Object instancesInput = !this.getPerformPrediction() ? instances : null;
        if (!this.isFirstBatchDone()) {
            if (this.m_ReplaceMissing) {
                this.m_Missing.setInputFormat(instances);
            }
            int numberYattributes = instances.numAttributes() - instances.classIndex();
            this.m_Means = new double[numberYattributes];
            this.m_StdDevs = new double[numberYattributes];
            switch (this.m_Preprocessing) {
                case 1: {
                    int i;
                    for (i = 0; i < numberYattributes; ++i) {
                        this.m_Means[i] = instances.meanOrMode(instances.classIndex() + i);
                        this.m_StdDevs[i] = 1.0;
                    }
                    this.m_Filter = new Center();
                    ((Center)this.m_Filter).setIgnoreClass(true);
                    break;
                }
                case 2: {
                    int i;
                    for (i = 0; i < numberYattributes; ++i) {
                        this.m_Means[i] = instances.meanOrMode(instances.classIndex() + i);
                        this.m_StdDevs[i] = StrictMath.sqrt(instances.variance(instances.classIndex() + i));
                    }
                    this.m_Filter = new Standardize();
                    ((Standardize)this.m_Filter).setIgnoreClass(true);
                    break;
                }
                default: {
                    int i;
                    for (i = 0; i < numberYattributes; ++i) {
                        this.m_Means[i] = 0.0;
                        this.m_StdDevs[i] = 1.0;
                    }
                    this.m_Filter = null;
                }
            }
            if (this.m_Filter != null) {
                this.m_Filter.setInputFormat(instances);
            }
        }
        if (this.m_ReplaceMissing) {
            instances = Filter.useFilter((Instances)instances, (Filter)this.m_Missing);
        }
        if (this.m_Filter != null) {
            instances = Filter.useFilter((Instances)instances, (Filter)this.m_Filter);
        }
        switch (this.m_Algorithm) {
            case 1: {
                result = this.processSIMPLS(instances);
                break;
            }
            case 2: {
                result = this.processPLS1(instances);
                break;
            }
            default: {
                throw new IllegalStateException("Algorithm type '" + this.m_Algorithm + "' is not recognized!");
            }
        }
        for (int i = 0; i < result.numInstances(); ++i) {
            int numberYattributes = instances.numAttributes() - instances.classIndex();
            for (int j = 0; j < numberYattributes; ++j) {
                double value;
                if (!this.getPerformPrediction()) {
                    value = instancesInput.instance(i).value(instances.classIndex() + j);
                    result.instance(i).setValue(this.getNumComponents() + j, value);
                    continue;
                }
                value = result.instance(i).value(this.getNumComponents() + j);
                result.instance(i).setValue(this.getNumComponents() + j, value * this.m_StdDevs[j] + this.m_Means[j]);
            }
        }
        return result;
    }

    protected Instances processSIMPLS(Instances instances) throws Exception {
        Instances result;
        if (!this.isFirstBatchDone()) {
            Matrix T;
            Matrix X = this.getX(instances);
            Matrix X_trans = X.transpose();
            Matrix Y = this.getY(instances);
            Matrix A = X_trans.times(Y);
            Matrix M = X_trans.times(X);
            Matrix C = Matrix.identity((int)X.getColumnDimension(), (int)X.getColumnDimension());
            Matrix W = new Matrix(X.getColumnDimension(), this.getNumComponents());
            Matrix P = new Matrix(X.getColumnDimension(), this.getNumComponents());
            Matrix Q = new Matrix(Y.getColumnDimension(), this.getNumComponents());
            for (int h = 0; h < this.getNumComponents(); ++h) {
                Matrix A_trans = A.transpose();
                Matrix q = this.getDominantEigenVector(A_trans.times(A));
                Matrix w = A.times(q);
                Matrix c = w.transpose().times(M).times(w);
                w = w.times(1.0 / StrictMath.sqrt(c.get(0, 0)));
                this.setVector(w, W, h);
                Matrix p = M.times(w);
                Matrix p_trans = p.transpose();
                this.setVector(p, P, h);
                q = A_trans.times(w);
                this.setVector(q, Q, h);
                Matrix v = C.times(p);
                this.normalizeVector(v);
                Matrix v_trans = v.transpose();
                C = C.minus(v.times(v_trans));
                M = M.minus(p.times(p_trans));
                A = C.times(A);
            }
            this.m_SIMPLS_W = W;
            Matrix X_new = T = X.times(this.m_SIMPLS_W);
            this.m_SIMPLS_B = W.times(Q.transpose());
            Matrix y = this.getPerformPrediction() ? T.times(P.transpose()).times(this.m_SIMPLS_B) : this.getY(instances);
            result = this.toInstances(this.getOutputFormat(), X_new, y);
        } else {
            result = new Instances(this.getOutputFormat());
            Matrix X = this.getX(instances);
            Matrix X_new = X.times(this.m_SIMPLS_W);
            Matrix y = this.getPerformPrediction() ? X.times(this.m_SIMPLS_B) : this.getY(instances);
            result = this.toInstances(this.getOutputFormat(), X_new, y);
        }
        return result;
    }
}

