/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformBase;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.linear.DenseMatrix;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.linear.Vec;
import jsat.utils.random.XORWOW;

/**
 * Random-projection transform based on the Johnson-Lindenstrauss construction.
 * <p>
 * The numeric features of each data point (dimension {@code d}) are multiplied by a random
 * {@code k x d} matrix, producing a {@code k}-dimensional numeric vector. Categorical values,
 * categorical metadata and the weight of each data point are carried over unchanged.
 * <p>
 * The random matrix is generated in {@link #fit(DataSet)} according to the configured
 * {@link TransformMode}. When {@link #isInMemory()} is {@code true} its entries are copied into
 * a {@link DenseMatrix}. Otherwise the seeded {@link RandomMatrix} itself is kept.
 * NOTE(review): presumably that matrix regenerates entries from its seed on access instead of
 * storing them; confirm against {@code RandomMatrix}.
 */
public class JLTransform
extends DataTransformBase {
    private static final long serialVersionUID = -8621368067861343912L;
    /** How the entries of the random matrix are drawn. Never {@code null}. */
    private TransformMode mode;
    /** The projection matrix; {@code null} until {@link #fit(DataSet)} has been called. */
    private Matrix R;
    /** Target (projected) dimension; always at least 1. */
    private int k;
    /** Whether the projection matrix is materialized as a {@link DenseMatrix} when fitting. */
    private boolean inMemory;

    /**
     * Copy constructor. Copies every setting and deep-copies the fitted matrix, if there is one.
     *
     * @param transform the transform to copy
     */
    protected JLTransform(JLTransform transform) {
        this.mode = transform.mode;
        this.k = transform.k;
        this.inMemory = transform.inMemory;
        // An unfitted transform has no matrix yet; copying it must not fail.
        this.R = transform.R == null ? null : transform.R.clone();
    }

    /**
     * Creates a transform projecting to 50 dimensions using {@link TransformMode#SPARSE},
     * with the matrix kept in memory.
     */
    public JLTransform() {
        this(50);
    }

    /**
     * Creates a transform using {@link TransformMode#SPARSE}, with the matrix kept in memory.
     *
     * @param k the projected dimension, at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public JLTransform(int k) {
        this(k, TransformMode.SPARSE);
    }

    /**
     * Creates a transform whose matrix is kept in memory.
     *
     * @param k    the projected dimension, at least 1
     * @param mode how the random matrix entries are drawn
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws NullPointerException     if {@code mode} is {@code null}
     */
    public JLTransform(int k, TransformMode mode) {
        this(k, mode, true);
    }

    /**
     * Creates a new transform.
     *
     * @param k        the projected dimension, at least 1
     * @param mode     how the random matrix entries are drawn
     * @param inMemory whether to materialize the random matrix as a dense matrix when fitting
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws NullPointerException     if {@code mode} is {@code null}
     */
    public JLTransform(int k, TransformMode mode, boolean inMemory) {
        // Static helpers instead of the (overridable) setters, so the constructor does not
        // invoke subclass code on a partially constructed object.
        this.mode = requireMode(mode);
        this.k = requirePositiveDimension(k);
        this.inMemory = inMemory;
    }

    /**
     * Generates a fresh random {@code k x d} projection matrix, where {@code d} is the number of
     * numeric features in {@code data}. Any previously fitted matrix is replaced.
     *
     * @param data the data set whose numeric dimension determines the matrix width
     */
    @Override
    public void fit(DataSet data) {
        int d = data.getNumNumericalVars();
        XORWOW rand = new XORWOW();
        Matrix projection = new RandomMatrixJL(this.k, d, rand.nextLong(), this.mode);
        if (this.inMemory) {
            // Copy the random entries into an explicit dense matrix.
            Matrix dense = new DenseMatrix(this.k, d);
            dense.mutableAdd(projection);
            projection = dense;
        }
        // Publish the new matrix only once it is completely built.
        this.R = projection;
    }

    /**
     * Sets how the random matrix entries are drawn. Takes effect on the next call to
     * {@link #fit(DataSet)}.
     *
     * @param mode the entry distribution
     * @throws NullPointerException if {@code mode} is {@code null}
     */
    public void setMode(TransformMode mode) {
        this.mode = requireMode(mode);
    }

    /**
     * Returns how the random matrix entries are drawn.
     *
     * @return the entry distribution
     */
    public TransformMode getMode() {
        return this.mode;
    }

    /**
     * Sets whether the random matrix is materialized as a dense matrix. Takes effect on the next
     * call to {@link #fit(DataSet)}.
     *
     * @param inMemory {@code true} to store the matrix densely
     */
    public void setInMemory(boolean inMemory) {
        this.inMemory = inMemory;
    }

    /**
     * Returns whether the random matrix is materialized as a dense matrix when fitting.
     *
     * @return {@code true} if the matrix is stored densely
     */
    public boolean isInMemory() {
        return this.inMemory;
    }

    /**
     * Sets the projected dimension. Takes effect on the next call to {@link #fit(DataSet)}.
     *
     * @param k the projected dimension, at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public void setProjectedDimension(int k) {
        this.k = requirePositiveDimension(k);
    }

    /**
     * Returns the projected dimension.
     *
     * @return the number of numeric features produced by {@link #transform(DataPoint)}
     */
    public int getProjectedDimension() {
        return this.k;
    }

    /**
     * Returns a log-uniform distribution of candidate projected dimensions for {@code d}.
     * The range is 10 to 100, or 100 to 1000 when {@code d} has more than 10000 numeric
     * features.
     *
     * @param d the data set the transform would be applied to
     * @return a log-uniform distribution over candidate values of {@code k}
     */
    public static Distribution guessProjectedDimension(DataSet d) {
        double max = 100.0;
        double min = 10.0;
        if (d.getNumNumericalVars() > 10000) {
            min = 100.0;
            max = 1000.0;
        }
        return new LogUniform(min, max);
    }

    /**
     * Projects the numeric features of {@code dp} with the fitted matrix.
     *
     * @param dp the data point to transform
     * @return a new data point with {@code k} numeric features and the original categorical
     *         values, categorical metadata and weight
     * @throws IllegalStateException if {@link #fit(DataSet)} has not been called
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        if (this.R == null) {
            throw new IllegalStateException("JLTransform must be fit before it can transform data");
        }
        Vec projected = this.R.multiply(dp.getNumericalValues());
        return new DataPoint(projected, dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
    }

    /**
     * Returns an independent copy of this transform, including its settings and fitted matrix.
     *
     * @return a copy of this transform
     */
    @Override
    public DataTransform clone() {
        return new JLTransform(this);
    }

    /** Validates a projected dimension, returning it unchanged when it is at least 1. */
    private static int requirePositiveDimension(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Projected dimension must be positive, not " + k);
        }
        return k;
    }

    /** Validates that a transform mode is non-null, returning it unchanged. */
    private static TransformMode requireMode(TransformMode mode) {
        if (mode == null) {
            throw new NullPointerException("mode must not be null");
        }
        return mode;
    }

    /**
     * Seeded random matrix whose entries follow the chosen {@link TransformMode}, each scaled so
     * that it has variance {@code 1 / rows}.
     */
    private static class RandomMatrixJL
    extends RandomMatrix {
        private static final long serialVersionUID = 2009377824896155918L;
        /** Magnitude scale applied to every generated entry. */
        private final double cnst;
        /** Entry distribution. */
        private final TransformMode mode;

        /**
         * @param rows    number of rows (the projected dimension)
         * @param cols    number of columns (the original numeric dimension)
         * @param XORSeed seed forwarded to {@link RandomMatrix}
         * @param mode    entry distribution
         */
        public RandomMatrixJL(int rows, int cols, long XORSeed, TransformMode mode) {
            super(rows, cols, XORSeed);
            this.mode = mode;
            switch (mode) {
                case GAUSS:
                case BINARY:
                    this.cnst = 1.0 / Math.sqrt(rows);
                    break;
                case SPARSE:
                    // Entries are non-zero with probability 1/3, so sqrt(3) restores unit variance
                    // before the 1/sqrt(rows) scaling.
                    this.cnst = Math.sqrt(3.0) / Math.sqrt(rows);
                    break;
                default:
                    throw new IllegalArgumentException("Unknown transform mode: " + mode);
            }
        }

        /**
         * Draws one matrix entry. Exactly one value is consumed from {@code rand} per entry.
         *
         * @param rand the source of randomness supplied by {@link RandomMatrix}
         * @return the entry value
         */
        @Override
        protected double getVal(Random rand) {
            switch (this.mode) {
                case GAUSS:
                    return rand.nextGaussian() * this.cnst;
                case BINARY:
                    return rand.nextBoolean() ? -this.cnst : this.cnst;
                case SPARSE:
                    // -cnst with probability 1/6, +cnst with probability 1/6, else 0.
                    switch (rand.nextInt(6)) {
                        case 0:
                            return -this.cnst;
                        case 1:
                            return this.cnst;
                        default:
                            return 0.0;
                    }
                default:
                    throw new RuntimeException("BUG: Please report");
            }
        }
    }

    /** Distribution used for the entries of the random projection matrix. */
    public static enum TransformMode {
        /** Entries are {@code N(0, 1) / sqrt(k)}. */
        GAUSS,
        /** Entries are {@code -1/sqrt(k)} or {@code +1/sqrt(k)} with equal probability. */
        BINARY,
        /**
         * Entries are {@code -sqrt(3/k)} or {@code +sqrt(3/k)} with probability 1/6 each, and
         * {@code 0} with probability 2/3.
         */
        SPARSE;

    }
}

