/*
 * Decompiled with CFR 0.152.
 */
package weka.dl4j.iterators;

import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Random;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.dl4j.EasyImageRecordReader;
import weka.dl4j.ScaleImagePixelsPreProcessor;
import weka.dl4j.SpecifiableFolderSplit;
import weka.dl4j.iterators.AbstractDataSetIterator;

/**
 * Dataset iterator that loads image files referenced by a Weka {@link Instances} object.
 *
 * <p>The data must have a string attribute at index 0 holding each image's file name, resolved
 * against {@link #getImagesLocation()}, and the class attribute at index 1 (see
 * {@link #validate(Instances)}). Images are read through an {@link EasyImageRecordReader} at the
 * configured width, height and channel count. Each batch is then passed through a
 * {@link ScaleImagePixelsPreProcessor}.
 */
public class ImageDataSetIterator
extends AbstractDataSetIterator {
    private static final long serialVersionUID = -3701309032945158130L;

    /** Height the images are loaded at. */
    protected int m_height = 28;

    /** Width the images are loaded at. */
    protected int m_width = 28;

    /** Number of channels the images are loaded with. */
    protected int m_numChannels = 1;

    /** Directory that image file names are resolved against; defaults to {@code user.dir}. */
    protected File m_imagesLocation = new File(System.getProperty("user.dir"));

    /**
     * Sets the mini batch size used for training.
     *
     * @param trainBatchSize the batch size
     */
    @Override
    @OptionMetadata(displayName="size of mini batch", description="The mini batch size to use in the iterator (default = 1).", commandLineParamName="bs", commandLineParamSynopsis="-bs <int>", displayOrder=0)
    public void setTrainBatchSize(int trainBatchSize) {
        this.m_batchSize = trainBatchSize;
    }

    /**
     * Returns the mini batch size used for training.
     *
     * @return the batch size
     */
    @Override
    public int getTrainBatchSize() {
        return this.m_batchSize;
    }

    /**
     * Returns the directory that image file names are resolved against.
     *
     * @return the images directory
     */
    @OptionMetadata(displayName="directory of images", description="The directory containing the images (default = current working directory).", commandLineParamName="imagesLocation", commandLineParamSynopsis="-imagesLocation <string>", displayOrder=1)
    public File getImagesLocation() {
        return this.m_imagesLocation;
    }

    /**
     * Sets the directory that image file names are resolved against.
     *
     * @param imagesLocation the images directory
     */
    public void setImagesLocation(File imagesLocation) {
        this.m_imagesLocation = imagesLocation;
    }

    /**
     * Returns the width the images are loaded at.
     *
     * @return the image width
     */
    @OptionMetadata(displayName="desired width", description="The desired width of the images (default = 28).", commandLineParamName="width", commandLineParamSynopsis="-width <int>", displayOrder=2)
    public int getWidth() {
        return this.m_width;
    }

    /**
     * Sets the width the images are loaded at.
     *
     * @param width the image width
     */
    public void setWidth(int width) {
        this.m_width = width;
    }

    /**
     * Returns the height the images are loaded at.
     *
     * @return the image height
     */
    @OptionMetadata(displayName="desired height", description="The desired height of the images (default = 28).", commandLineParamName="height", commandLineParamSynopsis="-height <int>", displayOrder=3)
    public int getHeight() {
        return this.m_height;
    }

    /**
     * Sets the height the images are loaded at.
     *
     * @param height the image height
     */
    public void setHeight(int height) {
        this.m_height = height;
    }

    /**
     * Returns the number of channels the images are loaded with.
     *
     * @return the channel count
     */
    @OptionMetadata(displayName="desired number of channels", description="The desired number of channels (default = 1).", commandLineParamName="numChannels", commandLineParamSynopsis="-numChannels <int>", displayOrder=4)
    public int getNumChannels() {
        return this.m_numChannels;
    }

    /**
     * Sets the number of channels the images are loaded with.
     *
     * @param numChannels the channel count
     */
    public void setNumChannels(int numChannels) {
        this.m_numChannels = numChannels;
    }

    /**
     * Checks that the images directory exists and that the data has the expected layout.
     *
     * @param data the dataset to check
     * @throws Exception if the images location is null or not a directory, or if attribute 0 is
     *     not a string attribute or the class index is not 1
     */
    public void validate(Instances data) throws Exception {
        File location = this.getImagesLocation();
        if (location == null || !location.isDirectory()) {
            throw new Exception("Directory not valid: " + location);
        }
        // Test the class index first: classIndex() == 1 guarantees attribute 0 exists before it is
        // inspected, so malformed data gets this message instead of an IndexOutOfBoundsException.
        if (data.classIndex() != 1 || !data.attribute(0).isString()) {
            throw new Exception("An ARFF is required with a string attribute and a class attribute");
        }
    }

    /**
     * Returns the configured number of channels as the attribute count.
     *
     * @param data the dataset (unused)
     * @return {@link #getNumChannels()}
     */
    @Override
    public int getNumAttributes(Instances data) {
        return this.getNumChannels();
    }

    /**
     * Builds and initializes a record reader over the image files referenced by {@code data}.
     *
     * <p>For every instance, the string value of attribute 0 is appended to the images directory
     * to form the file path. The class value is passed as a label via
     * {@code String.valueOf(classValue())}.
     *
     * @param data the dataset whose instances reference the images
     * @param seed the seed forwarded to the record reader
     * @return the initialized reader
     * @throws Exception if the reader cannot be initialized
     */
    protected EasyImageRecordReader getImageRecordReader(Instances data, int seed) throws Exception {
        int numInstances = data.numInstances();
        URI[] locations = new URI[numInstances];
        ArrayList<File> filenames = new ArrayList<File>(numInstances);
        ArrayList<String> classes = new ArrayList<String>(numInstances);
        // Summed in a long: accumulating file sizes in an int silently wraps past 2 GiB.
        long totalBytes = 0L;
        for (int i = 0; i < numInstances; ++i) {
            String location = data.attribute(0).value((int)data.get(i).value(0));
            File imageFile = new File(this.getImagesLocation() + File.separator + location);
            filenames.add(imageFile);
            // Previously never filled, which left the split with only null locations.
            locations[i] = imageFile.toURI();
            classes.add(String.valueOf(data.get(i).classValue()));
            totalBytes += imageFile.length();
        }
        EasyImageRecordReader reader = new EasyImageRecordReader(this.getWidth(), this.getHeight(), this.getNumChannels(), filenames, classes, seed);
        SpecifiableFolderSplit fs = new SpecifiableFolderSplit();
        fs.setFiles(locations);
        // Clamp instead of wrapping to a negative value; an int argument fits either an int or a long parameter.
        fs.setLength((int)Math.min(totalBytes, (long)Integer.MAX_VALUE));
        reader.initialize((InputSplit)fs);
        return reader;
    }

    /**
     * Validates the data and returns a DL4J iterator over its images.
     *
     * <p>Note: {@code data} is shuffled in place with {@code new Random(seed)} before the reader is
     * built, so the caller's instance order is changed.
     *
     * @param data the dataset referencing the images
     * @param seed the seed used for shuffling and forwarded to the record reader
     * @param batchSize the mini batch size
     * @return an iterator with a {@link ScaleImagePixelsPreProcessor} attached
     * @throws Exception if validation or reader initialization fails
     */
    @Override
    public DataSetIterator getIterator(Instances data, int seed, int batchSize) throws Exception {
        this.validate(data);
        data.randomize(new Random(seed));
        EasyImageRecordReader reader = this.getImageRecordReader(data, seed);
        RecordReaderDataSetIterator tmpIter = new RecordReaderDataSetIterator((RecordReader)reader, batchSize, -1, data.numClasses());
        tmpIter.setPreProcessor((DataSetPreProcessor)new ScaleImagePixelsPreProcessor());
        return tmpIter;
    }
}

