/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.data;

import com.jmatio.io.MatFileReader;
import com.jmatio.types.MLArray;
import com.jmatio.types.MLCell;
import com.jmatio.types.MLChar;
import com.jmatio.types.MLDouble;
import com.jmatio.types.MLSparse;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.openimaj.ml.linear.data.MatrixDataGenerator;
import org.openimaj.util.filter.FilterUtils;
import org.openimaj.util.function.Predicate;
import org.openimaj.util.pair.Pair;

public class BillMatlabFileDataGenerator
implements MatrixDataGenerator<Matrix> {
    private Map<String, MLArray> content;
    private List<Fold> folds;
    private int ndays;
    private int nusers;
    private int nwords;
    private List<Matrix> dayWords;
    private List<Matrix> dayPolls;
    private int currentIndex;
    private int ntasks;
    private int[] indexes;
    private Map<Integer, String> voc;
    private String[] tasks;
    private Set<Integer> keepIndex;
    private Map<Integer, Integer> indexToVoc;
    private boolean filter;
    String mainMatrixKey = "user_vsr_for_polls";

    public BillMatlabFileDataGenerator(File matfile, int ndays, boolean filter) throws IOException {
        MatFileReader reader = new MatFileReader(matfile);
        this.ndays = ndays;
        this.content = reader.getContent();
        this.currentIndex = 0;
        this.filter = filter;
        this.prepareVocabulary();
        this.prepareFolds();
        this.prepareDayUserWords();
        this.prepareDayPolls();
    }

    public BillMatlabFileDataGenerator(File matfile, String mainMatrixName, File polls, int ndays, boolean filter) throws IOException {
        MatFileReader reader = new MatFileReader(matfile);
        this.mainMatrixKey = mainMatrixName;
        this.ndays = ndays;
        this.content = reader.getContent();
        this.currentIndex = 0;
        this.filter = filter;
        this.prepareVocabulary();
        this.prepareFolds();
        this.prepareDayUserWords();
        reader = new MatFileReader(polls);
        this.content = reader.getContent();
        this.prepareDayPolls();
        this.content = null;
    }

    public BillMatlabFileDataGenerator(File matfile, String mainMatrixName, File polls, int ndays, boolean filter, List<Fold> folds) throws IOException {
        MatFileReader reader = new MatFileReader(matfile);
        this.mainMatrixKey = mainMatrixName;
        this.ndays = ndays;
        this.content = reader.getContent();
        this.currentIndex = 0;
        this.filter = filter;
        this.prepareVocabulary();
        this.folds = folds;
        this.prepareDayUserWords();
        reader = new MatFileReader(polls);
        this.content = reader.getContent();
        this.prepareDayPolls();
        this.content = null;
    }

    public Map<Integer, String> getVocabulary() {
        return this.voc;
    }

    private void prepareVocabulary() {
        MLCell vocLoaded;
        this.keepIndex = new HashSet<Integer>();
        MLDouble keepIndex = (MLDouble)this.content.get("voc_keep_terms_index");
        if (keepIndex != null) {
            double[] filterIndexArr;
            for (double d : filterIndexArr = keepIndex.getArray()[0]) {
                this.keepIndex.add((int)d - 1);
            }
        }
        if ((vocLoaded = (MLCell)this.content.get("voc")) != null) {
            this.indexToVoc = new HashMap<Integer, Integer>();
            ArrayList vocArr = vocLoaded.cells();
            int index = 0;
            int vocIndex = 0;
            this.voc = new HashMap<Integer, String>();
            for (MLArray vocArrItem : vocArr) {
                MLChar vocChar = (MLChar)vocArrItem;
                String vocString = vocChar.getString(0);
                if (this.filter && this.keepIndex.contains(index)) {
                    this.voc.put(vocIndex, vocString);
                    this.indexToVoc.put(index, vocIndex);
                    ++vocIndex;
                }
                ++index;
            }
        }
    }

    public void setFold(int fold, Mode mode) {
        if (fold == -1) {
            this.indexes = new int[this.dayWords.size()];
            for (int i = 0; i < this.indexes.length; ++i) {
                this.indexes[i] = i;
            }
        } else {
            Fold f = this.folds.get(fold);
            this.indexes = mode.indexes(f);
        }
        this.currentIndex = 0;
    }

    private void prepareDayPolls() {
        ArrayList pollKeys = FilterUtils.filter(this.content.keySet(), (Predicate)new Predicate<String>(){

            public boolean test(String object) {
                return object.endsWith("per_unique_extended");
            }
        });
        this.ntasks = pollKeys.size();
        this.dayPolls = new ArrayList<Matrix>();
        for (int i = 0; i < this.ndays; ++i) {
            this.dayPolls.add((Matrix)SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, this.ntasks));
        }
        this.tasks = new String[this.ntasks];
        for (int t = 0; t < this.ntasks; ++t) {
            String pollKey;
            this.tasks[t] = pollKey = (String)pollKeys.get(t);
            MLDouble arr = (MLDouble)this.content.get(pollKey);
            for (int i = 0; i < this.ndays; ++i) {
                Matrix dayPoll = this.dayPolls.get(i);
                dayPoll.setElement(0, t, ((Double)arr.get(i, 0)).doubleValue());
            }
        }
    }

    public String[] getTasks() {
        return this.tasks;
    }

    private void prepareDayUserWords() {
        int i;
        MLSparse arr = (MLSparse)this.content.get(this.mainMatrixKey);
        Double[] realVals = arr.exportReal();
        int[] rows = arr.getIR();
        int[] cols = arr.getIC();
        this.nwords = this.voc == null ? arr.getN() : this.voc.size();
        this.nusers = arr.getM() / this.ndays;
        this.dayWords = new ArrayList<Matrix>();
        for (i = 0; i < this.ndays; ++i) {
            SparseMatrix userWord = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(this.nwords, this.nusers);
            this.dayWords.add((Matrix)userWord);
        }
        for (i = 0; i < rows.length; ++i) {
            if (this.filter && !this.keepIndex.contains(cols[i])) continue;
            int wordIndex = cols[i];
            if (this.indexToVoc != null) {
                wordIndex = this.indexToVoc.get(wordIndex);
            }
            int dayIndex = rows[i] / this.nusers;
            int userIndex = rows[i] - dayIndex * this.nusers;
            this.dayWords.get(dayIndex).setElement(wordIndex, userIndex, realVals[i].doubleValue());
        }
    }

    private void prepareFolds() {
        MLArray setfolds = this.content.get("set_fold");
        if (setfolds == null) {
            return;
        }
        if (setfolds.isCell()) {
            this.folds = new ArrayList<Fold>();
            MLCell foldcells = (MLCell)setfolds;
            int nfolds = foldcells.getM();
            System.out.println(String.format("Found %d folds", nfolds));
            for (int i = 0; i < nfolds; ++i) {
                MLDouble training = (MLDouble)foldcells.get(i, 0);
                MLDouble test = (MLDouble)foldcells.get(i, 1);
                MLDouble validation = (MLDouble)foldcells.get(i, 2);
                Fold f = new Fold(this.toIntArray(training), this.toIntArray(test), this.toIntArray(validation));
                this.folds.add(f);
            }
        } else {
            throw new RuntimeException("Can't find set_folds in expected format");
        }
    }

    private int[] toIntArray(MLDouble training) {
        int[] arr = new int[training.getN()];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = ((Double)training.get(0, i)).intValue();
        }
        return arr;
    }

    @Override
    public Pair<Matrix> generate() {
        if (this.currentIndex >= this.indexes.length) {
            return null;
        }
        int dayIndex = this.indexes[this.currentIndex];
        Pair pair = new Pair((Object)this.dayWords.get(dayIndex), (Object)this.dayPolls.get(dayIndex));
        ++this.currentIndex;
        return pair;
    }

    public int nFolds() {
        return this.folds.size();
    }

    public List<Pair<Matrix>> generateAll() {
        Pair<Matrix> pair;
        ArrayList<Pair<Matrix>> ret = new ArrayList<Pair<Matrix>>();
        while ((pair = this.generate()) != null) {
            ret.add(pair);
        }
        return ret;
    }

    public static enum Mode {
        TRAINING{

            @Override
            public int[] indexes(Fold fold) {
                return fold.training;
            }
        }
        ,
        TEST{

            @Override
            public int[] indexes(Fold fold) {
                return fold.test;
            }
        }
        ,
        VALIDATION{

            @Override
            public int[] indexes(Fold fold) {
                return fold.validation;
            }
        }
        ,
        ALL{

            @Override
            public int[] indexes(Fold fold) {
                return null;
            }
        };


        public abstract int[] indexes(Fold var1);
    }

    public static class Fold {
        int[] training;
        int[] test;
        int[] validation;

        public Fold(int[] training, int[] test, int[] validation) {
            this.training = training;
            this.test = test;
            this.validation = validation;
        }
    }
}

