/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;

/**
 * Principal Component Analysis transform computed with the iterative NIPALS
 * procedure. Each component is found by alternating projections between a
 * score vector {@code t} and a unit-length loading vector {@code p}. The
 * component is then removed from the residual matrix (deflation) before the
 * next one is extracted.
 * <p>
 * {@link #transform(DataPoint)} replaces the numerical features of a point
 * with their projections onto the learned loadings. Categorical values,
 * categorical metadata and the weight are copied unchanged.
 * <p>
 * This implementation does not subtract column means before fitting. Center
 * the numerical features beforehand if conventional (covariance-based)
 * principal components are wanted.
 */
public class PCA
implements DataTransform {
    private static final long serialVersionUID = 8736609877239941617L;
    /** Upper bound on the number of power-iteration steps spent on one component. */
    private static final int MAX_ITERATIONS = 100;
    /** Loading matrix: row {@code i} is the unit-length loading of component {@code i}; {@code null} until fit. */
    private Matrix P;
    /** Maximum number of components to extract. */
    private int maxPCs;
    /** Relative tolerance on the change of {@code t'*t} between iterations used to detect convergence. */
    private double threshold;

    /**
     * Creates an unfitted PCA that keeps at most 50 components and uses a
     * convergence threshold of {@code 1e-4}.
     */
    public PCA() {
        this(50);
    }

    /**
     * Creates a PCA fit to {@code dataSet} with no explicit limit on the
     * number of components and a convergence threshold of {@code 1e-4}.
     *
     * @param dataSet the data to fit
     */
    public PCA(DataSet dataSet) {
        this(dataSet, Integer.MAX_VALUE);
    }

    /**
     * Creates a PCA fit to {@code dataSet} with a convergence threshold of
     * {@code 1e-4}.
     *
     * @param dataSet the data to fit
     * @param maxPCs  the maximum number of components to extract
     */
    public PCA(DataSet dataSet, int maxPCs) {
        this(dataSet, maxPCs, 1.0E-4);
    }

    /**
     * Creates an unfitted PCA with a convergence threshold of {@code 1e-4}.
     *
     * @param maxPCs the maximum number of components to extract
     */
    public PCA(int maxPCs) {
        this(maxPCs, 1.0E-4);
    }

    /**
     * Creates an unfitted PCA.
     *
     * @param maxPCs    the maximum number of components to extract (must be positive)
     * @param threshold relative convergence tolerance (must be in (0, Inf))
     * @throws IllegalArgumentException if either argument is out of range
     */
    public PCA(int maxPCs, double threshold) {
        this.setMaxPCs(maxPCs);
        this.setThreshold(threshold);
    }

    /**
     * Creates a PCA and immediately fits it to {@code dataSet}.
     *
     * @param dataSet   the data to fit
     * @param maxPCs    the maximum number of components to extract (must be positive)
     * @param threshold relative convergence tolerance (must be in (0, Inf))
     */
    public PCA(DataSet dataSet, int maxPCs, double threshold) {
        this(maxPCs, threshold);
        this.fit(dataSet);
    }

    /**
     * Copy constructor used by {@link #clone()}. The loading matrix is deep-copied.
     */
    private PCA(PCA other) {
        if (other.P != null) {
            this.P = other.P.clone();
        }
        this.maxPCs = other.maxPCs;
        this.threshold = other.threshold;
    }

    /**
     * Learns principal components from the numerical features of {@code dataSet}.
     * <p>
     * At most {@code min(sampleSize, numNumericalVars, maxPCs)} components are
     * extracted. Extraction stops early if the residual matrix becomes exactly
     * zero, because no further component can then be defined.
     *
     * @param dataSet the data to fit
     * @throws ArithmeticException if the data matrix has no non-zero column
     */
    @Override
    public void fit(DataSet dataSet) {
        Matrix E = dataSet.getDataMatrix();
        int PCs = Math.min(dataSet.getSampleSize(), dataSet.getNumNumericalVars());
        PCs = Math.min(this.maxPCs, PCs);

        // The first seed must exist; an all-zero matrix is rejected here.
        Vec t = PCA.getColumn(E);
        ArrayList<Vec> loadings = new ArrayList<Vec>();
        for (int i = 0; i < PCs; ++i) {
            if (i > 0) {
                // Re-seed from a column of the *current* residual. Reusing the
                // previous score vector would start from the direction that was
                // just deflated away: E'*t is then ~0 (exactly 0 when the last
                // loading converged exactly), and normalising it produces NaN.
                t = PCA.findNonZeroColumn(E);
                if (t == null) {
                    break; // residual is exactly zero: nothing left to explain
                }
            }
            loadings.add(this.extractComponent(E, t));
        }

        // Every loading has E.cols() entries; stack them as the rows of P.
        this.P = new DenseMatrix(loadings.size(), E.cols());
        for (int row = 0; row < loadings.size(); ++row) {
            Vec loading = loadings.get(row);
            for (int col = 0; col < loading.length(); ++col) {
                this.P.set(row, col, loading.get(col));
            }
        }
    }

    /**
     * Runs the NIPALS iteration for one component starting from {@code seed},
     * then deflates {@code E} in place by subtracting {@code t * p'}.
     *
     * @param E    the residual matrix; modified in place
     * @param seed initial score vector, expected to be a non-zero column of {@code E}
     * @return the unit-length loading vector of the extracted component
     */
    private Vec extractComponent(Matrix E, Vec seed) {
        Vec t = seed;
        double tauOld = t.dot(t);
        Vec p = new DenseVector(E.cols());
        for (int iter = 0; iter < MAX_ITERATIONS; ++iter) {
            // p = E' * t / (t' * t), then scale p to unit length
            p.zeroOut();
            E.transposeMultiply(1.0, t, p);
            p.mutableDivide(tauOld);
            p.mutableMultiply(Math.pow(p.dot(p), -0.5));
            // t = E * p / (p' * p)
            t = E.multiply(p);
            t.mutableDivide(p.dot(p));
            double tauNew = t.dot(t);
            // Iteration 0 compares against the seed's norm rather than a
            // previous tau, so convergence is only tested from iteration 1.
            if (iter > 0 && Math.abs(tauNew - tauOld) <= this.threshold * tauNew) {
                break;
            }
            tauOld = tauNew;
        }
        Matrix.OuterProductUpdate(E, t, p, -1.0);
        return p;
    }

    /**
     * Sets the maximum number of principal components to extract.
     *
     * @param maxPCs a positive component count
     * @throws IllegalArgumentException if {@code maxPCs <= 0}
     */
    public void setMaxPCs(int maxPCs) {
        if (maxPCs <= 0) {
            throw new IllegalArgumentException("number of principal components must be a positive number, not " + maxPCs);
        }
        this.maxPCs = maxPCs;
    }

    /**
     * @return the maximum number of principal components to extract
     */
    public int getMaxPCs() {
        return this.maxPCs;
    }

    /**
     * Sets the relative convergence tolerance. A component is considered
     * converged once {@code |tauNew - tauOld| <= threshold * tauNew}, where
     * {@code tau = t' * t}.
     *
     * @param threshold a finite, strictly positive tolerance
     * @throws IllegalArgumentException if {@code threshold} is not in (0, Inf)
     */
    public void setThreshold(double threshold) {
        if (threshold <= 0.0 || Double.isInfinite(threshold) || Double.isNaN(threshold)) {
            throw new IllegalArgumentException("threshold must be in the range (0, Inf), not " + threshold);
        }
        this.threshold = threshold;
    }

    /**
     * @return the relative convergence tolerance
     */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Returns the first column of {@code x} whose squared norm is strictly
     * positive, or {@code null} if there is none (columns whose squared norm
     * is NaN are skipped as well).
     */
    private static Vec findNonZeroColumn(Matrix x) {
        for (int i = 0; i < x.cols(); ++i) {
            Vec col = x.getColumn(i);
            if (col.dot(col) > 0.0) {
                return col;
            }
        }
        return null;
    }

    /**
     * Returns the first column of {@code x} with a strictly positive squared norm.
     *
     * @throws ArithmeticException if no such column exists
     */
    private static Vec getColumn(Matrix x) {
        Vec col = PCA.findNonZeroColumn(x);
        if (col == null) {
            throw new ArithmeticException("Matrix is essentially zero");
        }
        return col;
    }

    /**
     * Projects the numerical features of {@code dp} onto the learned loadings.
     * Categorical values, categorical metadata and the weight are copied.
     *
     * @param dp the point to transform
     * @return a new point whose numerical vector is {@code P * x}
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        if (this.P == null) {
            throw new IllegalStateException("PCA has not been fit to a data set");
        }
        Vec projected = this.P.multiply(dp.getNumericalValues());
        return new DataPoint(projected, Arrays.copyOf(dp.getCategoricalValues(), dp.numCategoricalValues()), CategoricalData.copyOf(dp.getCategoricalData()), dp.getWeight());
    }

    /**
     * @return an independent copy of this transform, including any fitted loadings
     */
    @Override
    public DataTransform clone() {
        return new PCA(this);
    }
}

