/*
 * Decompiled with CFR 0.152.
 */
package org.ujmp.core.doublematrix.calculation.general.missingvalues;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute;
import org.ujmp.core.exceptions.MatrixException;
import org.ujmp.core.util.MathUtil;

/**
 * Calculation that fills in the missing (NaN or infinite) cells of a source
 * matrix by iterative regression.
 *
 * <p>Starting from an initial guess (row means unless one is supplied), every
 * column is repeatedly predicted from all other columns with a least-squares
 * fit computed via the pseudo-inverse. The new prediction is blended into the
 * running guess using {@link #decay}, and the iteration stops once the
 * distance between the prediction and the running guess, divided by the number
 * of missing cells, drops below {@link #delta}.
 *
 * <p>Cells that are present in the source are always returned unchanged by
 * {@link #getDouble(long...)}; only missing cells are read from the imputed
 * matrix. The imputed matrix is built lazily on the first call.
 *
 * <p>NOTE(review): no synchronization is performed around the lazy
 * initialization, so concurrent first calls to {@code getDouble} may each
 * trigger a computation - confirm whether callers need thread-safety.
 */
public class ImputeEM
extends AbstractDoubleCalculation {
    private static final long serialVersionUID = -1272010036598212696L;

    /** Current estimate of the complete matrix; seeded by the caller or by row means. */
    private Matrix bestGuess = null;

    /** Final imputed matrix; {@code null} until a computation has completed successfully. */
    private Matrix imputed = null;

    /** Convergence threshold for the per-missing-cell distance between iterations. */
    private double delta = 1.0E-6;

    /** Weight of the previous guess when blending in a new prediction (new prediction gets {@code 1 - decay}). */
    private final double decay = 0.66;

    /**
     * Creates an imputation that starts from the row means of {@code matrix}.
     *
     * @param matrix the source matrix containing missing values
     */
    public ImputeEM(Matrix matrix) {
        super(matrix);
    }

    /**
     * Creates an imputation that starts from a caller-supplied estimate.
     *
     * @param matrix the source matrix containing missing values
     * @param firstGuess initial estimate of the complete matrix; if {@code null},
     *        row means of the source are used instead
     */
    public ImputeEM(Matrix matrix, Matrix firstGuess) {
        super(matrix);
        this.bestGuess = firstGuess;
    }

    /**
     * Creates an imputation with a caller-supplied estimate and convergence threshold.
     *
     * @param matrix the source matrix containing missing values
     * @param firstGuess initial estimate of the complete matrix; if {@code null},
     *        row means of the source are used instead
     * @param delta iteration stops once the per-missing-cell distance falls below this value
     */
    public ImputeEM(Matrix matrix, Matrix firstGuess, double delta) {
        super(matrix);
        this.bestGuess = firstGuess;
        this.delta = delta;
    }

    /**
     * Returns the source value at {@code coordinates}, or the imputed value if
     * the source value is NaN or infinite. Triggers the imputation on first use.
     *
     * @param coordinates position of the requested cell
     * @return the observed or imputed value
     * @throws MatrixException if the imputation fails
     */
    public double getDouble(long ... coordinates) throws MatrixException {
        if (this.imputed == null) {
            this.createMatrix();
        }
        double v = this.getSource().getAsDouble(coordinates);
        if (MathUtil.isNaNOrInfinite(v)) {
            return this.imputed.getAsDouble(coordinates);
        }
        return v;
    }

    /**
     * Runs the iterative imputation and stores the result in {@link #imputed}.
     *
     * <p>Progress is reported on standard output. The worker executor is always
     * shut down, and {@link #imputed} is reset to {@code null} on failure so that
     * a later call cannot read a partially filled matrix.
     *
     * @throws MatrixException if a column prediction fails, the calling thread is
     *         interrupted, the iteration produces a NaN delta, or missing values
     *         remain after convergence
     */
    private void createMatrix() {
        ExecutorService executor = Executors.newFixedThreadPool(1);
        boolean succeeded = false;
        try {
            Matrix x = this.getSource();
            double valueCount = x.getValueCount();
            long missingCount = (long)x.countMissing(Calculation.Ret.NEW, Integer.MAX_VALUE).getEuklideanValue();
            double percent = (double)((int)Math.round((double)missingCount * 1000.0 / valueCount)) / 10.0;
            System.out.println("missing values: " + missingCount + " (" + percent + "%)");
            System.out.println("============================================");

            if (missingCount == 0L) {
                // Nothing to impute. Without this guard the convergence measure
                // below divides by zero and the loop never terminates.
                this.imputed = x;
                succeeded = true;
                return;
            }

            if (this.bestGuess == null) {
                this.bestGuess = x.impute(Calculation.Ret.NEW, Impute.ImputationMethod.RowMean, new Object[0]);
            }

            int run = 0;
            double d;
            do {
                System.out.println("Iteration " + run++);
                ArrayList<Future<Long>> futures = new ArrayList<Future<Long>>();
                // PredictColumn tasks write their results into this.imputed.
                this.imputed = MatrixFactory.zeros(x.getSize());
                long t0 = System.currentTimeMillis();
                for (long c = 0L; c < x.getColumnCount(); ++c) {
                    futures.add(executor.submit(new PredictColumn(c)));
                }
                for (Future<Long> future : futures) {
                    long completedCols = future.get();
                    long elapsedTime = System.currentTimeMillis() - t0;
                    long remainingCols = x.getColumnCount() - completedCols;
                    double colsPerMillisecond = (double)(completedCols + 1L) / (double)elapsedTime;
                    long remainingTime = (long)((double)remainingCols / colsPerMillisecond / 1000.0);
                    System.out.println(String.valueOf((double)(completedCols * 1000L / x.getColumnCount()) / 10.0) + "% completed (" + remainingTime + " seconds remaining)");
                }
                d = this.imputed.euklideanDistanceTo(this.bestGuess, true) / (double)missingCount;
                System.out.println("delta: " + d);
                System.out.println("============================================");
                if (Double.isNaN(d)) {
                    // A NaN delta can never satisfy d < delta; fail instead of looping forever.
                    throw new MatrixException("Imputation did not converge: delta is NaN");
                }
                this.bestGuess = this.bestGuess.times(this.decay).plus(this.imputed.times(1.0 - this.decay));
            } while (!(d < this.delta));

            this.imputed = this.bestGuess;
            if (this.imputed.containsMissingValues()) {
                throw new MatrixException("Matrix has still missing values after imputation");
            }
            succeeded = true;
        }
        catch (MatrixException e) {
            // Already the right type; avoid wrapping it a second time.
            throw e;
        }
        catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new MatrixException(e);
        }
        catch (Exception e) {
            throw new MatrixException(e);
        }
        finally {
            if (succeeded) {
                executor.shutdown();
            } else {
                // Drop queued column tasks and discard the partial result.
                executor.shutdownNow();
                this.imputed = null;
            }
        }
    }

    /**
     * Predicts the missing cells of one column from all other columns.
     *
     * <p>A linear model with bias term is fitted on the rows where the column is
     * observed, using the remaining columns of {@code firstGuess} as inputs, and
     * then evaluated for every row. Observed cells keep their original value.
     *
     * @param original the source matrix with missing values
     * @param firstGuess the current complete estimate used as regression input
     * @param column index of the column to predict
     * @return a single-column matrix: the original column if nothing is missing,
     *         otherwise the predictions with observed values restored
     * @throws MatrixException if an underlying matrix operation fails
     */
    private static Matrix replaceInColumn(Matrix original, Matrix firstGuess, long column) throws MatrixException {
        Matrix x = firstGuess.deleteColumns(Calculation.Ret.NEW, column);
        Matrix y = original.selectColumns(Calculation.Ret.NEW, column);

        // Collected from the last row to the first, as in the original implementation.
        ArrayList<Long> missingRows = new ArrayList<Long>();
        for (long i = y.getRowCount() - 1L; i >= 0L; --i) {
            if (MathUtil.isNaNOrInfinite(y.getAsDouble(i, 0L))) {
                missingRows.add(i);
            }
        }
        if (missingRows.isEmpty()) {
            return y;
        }

        // Fit on observed rows only; the column of ones provides the bias term.
        Matrix xdel = x.deleteRows(Calculation.Ret.NEW, missingRows);
        Matrix bias1 = MatrixFactory.ones(xdel.getRowCount(), 1L);
        Matrix xtrain = MatrixFactory.horCat(xdel, bias1);
        Matrix ytrain = y.deleteRows(Calculation.Ret.NEW, missingRows);
        Matrix b = xtrain.pinv().mtimes(ytrain);

        // Predict every row, then restore the cells that were actually observed.
        Matrix bias2 = MatrixFactory.ones(x.getRowCount(), 1L);
        Matrix yPredicted = MatrixFactory.horCat(x, bias2).mtimes(b);
        for (long row = 0L; row < y.getRowCount(); ++row) {
            double v = y.getAsDouble(row, 0L);
            // Same missing-value test as above, so infinite cells keep their
            // prediction instead of being copied back.
            if (!MathUtil.isNaNOrInfinite(v)) {
                yPredicted.setAsDouble(v, row, 0L);
            }
        }
        return yPredicted;
    }

    /**
     * Task that predicts a single column and writes it into the enclosing
     * instance's {@code imputed} matrix.
     *
     * <p>NOTE(review): each task writes only to its own column; whether the
     * target matrix tolerates writes from a worker thread is not visible here.
     */
    class PredictColumn
    implements Callable<Long> {
        long column = 0L;

        /**
         * @param column index of the column this task predicts
         */
        public PredictColumn(long column) {
            this.column = column;
        }

        /**
         * Predicts the column and copies it into the imputed matrix.
         *
         * @return the index of the processed column
         * @throws Exception if the prediction fails
         */
        @Override
        public Long call() throws Exception {
            Matrix newColumn = ImputeEM.replaceInColumn(ImputeEM.this.getSource(), ImputeEM.this.bestGuess, this.column);
            for (long r = 0L; r < newColumn.getRowCount(); ++r) {
                ImputeEM.this.imputed.setAsDouble(newColumn.getAsDouble(r, 0L), r, this.column);
            }
            return this.column;
        }
    }
}

