/*
 * Decompiled with CFR 0.152.
 */
package org.canova.image.recordreader;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.canova.api.conf.Configuration;
import org.canova.api.io.data.DoubleWritable;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.InputSplit;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.canova.image.mnist.MnistManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.NDArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MNISTRecordReader
implements RecordReader {
    private static Logger log = LoggerFactory.getLogger(MNISTRecordReader.class);
    private URI[] locations;
    private int currIndex = 0;
    private Iterator<String> iter;
    private transient MnistManager man = new MnistManager(MNIST_ROOT + "images-idx1-ubyte", MNIST_ROOT + "labels-idx1-ubyte");
    public static final int NUM_EXAMPLES = 60000;
    private int numOutcomes = 10;
    private int totalExamples = 60000;
    private int cursor = 1;
    private int inputColumns = 0;
    protected DataSet curr = null;
    private File fileDir;
    private static final String trainingFilesURL = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz";
    private static final String trainingFilesFilename = "images-idx1-ubyte.gz";
    public static final String trainingFilesFilename_unzipped = "images-idx1-ubyte";
    private static final String trainingFileLabelsURL = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz";
    private static final String trainingFileLabelsFilename = "labels-idx1-ubyte.gz";
    public static final String trainingFileLabelsFilename_unzipped = "labels-idx1-ubyte";
    private static final String LOCAL_DIR_NAME = "MNIST";
    private static final String TEMP_ROOT = System.getProperty("user.home");
    private static final String MNIST_ROOT = TEMP_ROOT + File.separator + "MNIST" + File.separator;
    private boolean binarize = this.binarize;
    protected InputSplit inputSplit;

    public MNISTRecordReader() throws IOException {
        int[][] image;
        this.man.setCurrent(this.cursor);
        try {
            image = this.man.readImage();
        }
        catch (IOException e) {
            throw new IllegalStateException("Unable to read image");
        }
        this.inputColumns = ArrayUtil.flatten((int[][])image).length;
    }

    public void initialize(InputSplit split) throws IOException, InterruptedException {
        this.inputSplit = split;
        this.locations = split.locations();
        if (this.locations != null && this.locations.length > 0) {
            this.iter = IOUtils.lineIterator((Reader)new InputStreamReader(this.locations[0].toURL().openStream()));
        }
    }

    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        this.initialize(split);
    }

    public Collection<Writable> next() {
        if (!this.fetchNext()) {
            return null;
        }
        DataSet currentRecord = this.curr;
        ArrayList<Writable> ret = new ArrayList<Writable>();
        INDArray data = currentRecord.get(0).getFeatureMatrix();
        INDArray labels = currentRecord.get(0).getLabels();
        ret.add((Writable)new NDArrayWritable(data));
        for (int i = 0; i < labels.length(); ++i) {
            if (!(labels.getDouble(i) > 0.0)) continue;
            ret.add((Writable)new DoubleWritable((double)i));
            break;
        }
        return ret;
    }

    public boolean hasNext() {
        return this.cursor < this.totalExamples;
    }

    public void close() {
    }

    public void setConf(Configuration conf) {
    }

    public Configuration getConf() {
        return null;
    }

    public boolean fetchNext() {
        if (!this.hasNext()) {
            return false;
        }
        ArrayList<DataSet> toConvert = new ArrayList<DataSet>();
        this.man.setCurrent(this.cursor);
        try {
            INDArray in = NDArrayUtil.toNDArray((int[])ArrayUtil.flatten((int[][])this.man.readImage()));
            if (this.binarize) {
                for (int d = 0; d < in.length(); ++d) {
                    if (!this.binarize) continue;
                    if (in.getDouble(d) > 30.0) {
                        in.putScalar(d, 1);
                        continue;
                    }
                    in.putScalar(d, 0);
                }
            } else {
                in.divi((Number)255);
            }
            INDArray out = this.createOutputVector(this.man.readLabel());
            boolean found = false;
            for (int col = 0; col < out.length(); ++col) {
                if (!(out.getDouble(col) > 0.0)) continue;
                found = true;
                break;
            }
            if (!found) {
                throw new IllegalStateException("Found a matrix without an outcome");
            }
            toConvert.add(new DataSet(in, out));
        }
        catch (IOException e) {
            throw new IllegalStateException("Unable to read image");
        }
        ++this.cursor;
        this.initializeCurrFromList(toConvert);
        return true;
    }

    protected INDArray createOutputVector(int outcomeLabel) {
        return FeatureUtil.toOutcomeVector((int)outcomeLabel, (int)this.numOutcomes);
    }

    protected INDArray createInputMatrix(int numRows) {
        return Nd4j.create((int)numRows, (int)this.inputColumns);
    }

    protected INDArray createOutputMatrix(int numRows) {
        return Nd4j.create((int)numRows, (int)this.numOutcomes);
    }

    protected void initializeCurrFromList(List<DataSet> examples) {
        if (examples.isEmpty()) {
            log.warn("Warning: empty dataset from the fetcher");
        }
        this.curr = null;
        INDArray inputs = this.createInputMatrix(examples.size());
        INDArray labels = this.createOutputMatrix(examples.size());
        for (int i = 0; i < examples.size(); ++i) {
            INDArray data = examples.get(i).getFeatureMatrix();
            INDArray label = examples.get(i).getLabels();
            inputs.putRow(i, data);
            labels.putRow(i, label);
        }
        this.curr = new DataSet(inputs, labels);
        examples.clear();
    }

    public List<String> getLabels() {
        return null;
    }

    public void reset() {
    }

    public Collection<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        throw new UnsupportedOperationException("Loading MNIST data via DataInputStream not supported.");
    }
}

