package weka.dl4j.iterators;

import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Random;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.dl4j.EasyImageRecordReader;
import weka.dl4j.ScaleImagePixelsPreProcessor;
import weka.dl4j.SpecifiableFolderSplit;

/**
 * Dataset iterator that reads images from a directory and exposes them as a DL4J
 * {@link DataSetIterator}.
 *
 * <p>{@link #validate(Instances)} requires the {@link Instances} to have a string attribute at
 * index 0 and the class attribute at index 1. The string value is treated as the image's file
 * name, resolved against {@link #getImagesLocation()}. Images are read through an
 * {@link EasyImageRecordReader} configured with {@link #getWidth()}, {@link #getHeight()} and
 * {@link #getNumChannels()}. The returned iterator has a {@link ScaleImagePixelsPreProcessor}
 * attached.
 */
public class ImageDataSetIterator extends AbstractDataSetIterator {

    /** Serialization version identifier. */
    private static final long serialVersionUID = -3701309032945158130L;

    /** Desired image height handed to the record reader. */
    protected int m_height = 28;

    /** Desired image width handed to the record reader. */
    protected int m_width = 28;

    /** Desired number of image channels handed to the record reader. */
    protected int m_numChannels = 1;

    /** Directory that image file names are resolved against; defaults to the working directory. */
    protected File m_imagesLocation = new File(System.getProperty("user.dir"));

    /**
     * Sets the mini-batch size used for training.
     *
     * @param batchSize the mini-batch size
     */
    @Override
    @OptionMetadata(displayName = "size of mini batch", description = "The mini batch size to use in the iterator (default = 1).", commandLineParamName = "bs", commandLineParamSynopsis = "-bs <int>", displayOrder = 0)
    public void setTrainBatchSize(int batchSize) {
        this.m_batchSize = batchSize;
    }

    /**
     * Returns the mini-batch size used for training.
     *
     * @return the mini-batch size
     */
    @Override
    public int getTrainBatchSize() {
        return this.m_batchSize;
    }

    /**
     * Returns the directory that image file names are resolved against.
     *
     * @return the images directory
     */
    @OptionMetadata(displayName = "directory of images", description = "The directory containing the images (default = current working directory).", commandLineParamName = "imagesLocation", commandLineParamSynopsis = "-imagesLocation <string>", displayOrder = 1)
    public File getImagesLocation() {
        return this.m_imagesLocation;
    }

    /**
     * Sets the directory that image file names are resolved against.
     *
     * @param imagesLocation the images directory
     */
    public void setImagesLocation(File imagesLocation) {
        this.m_imagesLocation = imagesLocation;
    }

    /**
     * Returns the desired image width.
     *
     * @return the width
     */
    @OptionMetadata(displayName = "desired width", description = "The desired width of the images (default = 28).", commandLineParamName = "width", commandLineParamSynopsis = "-width <int>", displayOrder = 2)
    public int getWidth() {
        return this.m_width;
    }

    /**
     * Sets the desired image width.
     *
     * @param width the width
     */
    public void setWidth(int width) {
        this.m_width = width;
    }

    /**
     * Returns the desired image height.
     *
     * @return the height
     */
    @OptionMetadata(displayName = "desired height", description = "The desired height of the images (default = 28).", commandLineParamName = "height", commandLineParamSynopsis = "-height <int>", displayOrder = 3)
    public int getHeight() {
        return this.m_height;
    }

    /**
     * Sets the desired image height.
     *
     * @param height the height
     */
    public void setHeight(int height) {
        this.m_height = height;
    }

    /**
     * Returns the desired number of image channels.
     *
     * @return the number of channels
     */
    @OptionMetadata(displayName = "desired number of channels", description = "The desired number of channels (default = 1).", commandLineParamName = "numChannels", commandLineParamSynopsis = "-numChannels <int>", displayOrder = 4)
    public int getNumChannels() {
        return this.m_numChannels;
    }

    /**
     * Sets the desired number of image channels.
     *
     * @param numChannels the number of channels
     */
    public void setNumChannels(int numChannels) {
        this.m_numChannels = numChannels;
    }

    /**
     * Checks that the images directory exists and that the data has the expected layout.
     *
     * @param instances the data to check
     * @throws Exception if the images location is null or not a directory, or if the data does
     *     not have a string attribute at index 0 and the class attribute at index 1
     */
    public void validate(Instances instances) throws Exception {
        File location = getImagesLocation();
        if (location == null || !location.isDirectory()) {
            throw new Exception("Directory not valid: " + location);
        }
        // Check the attribute count first so malformed data yields the descriptive error
        // below rather than an index-out-of-bounds from attribute(0).
        if (instances.numAttributes() < 2
                || !instances.attribute(0).isString()
                || instances.classIndex() != 1) {
            throw new Exception("An ARFF is required with a string attribute and a class attribute");
        }
    }

    /**
     * Returns the number of input channels, which this iterator reports as its attribute count.
     *
     * @param instances the data (unused)
     * @return {@link #getNumChannels()}
     */
    @Override
    public int getNumAttributes(Instances instances) {
        return getNumChannels();
    }

    /**
     * Builds and initializes a record reader over the images referenced by {@code instances}.
     *
     * @param instances data whose attribute 0 holds image file names relative to
     *     {@link #getImagesLocation()}, and whose class value labels each image
     * @param seed seed passed through to the {@link EasyImageRecordReader}
     * @return an initialized record reader
     * @throws Exception if the reader cannot be initialized
     */
    protected EasyImageRecordReader getImageRecordReader(Instances instances, int seed) throws Exception {
        int numInstances = instances.numInstances();
        ArrayList<File> imageFiles = new ArrayList<>(numInstances);
        ArrayList<String> classLabels = new ArrayList<>(numInstances);
        URI[] locations = new URI[numInstances];
        // NOTE(review): accumulated as int (wrapping) to match the existing setLength call;
        // totals above Integer.MAX_VALUE bytes overflow. Switch to long if setLength accepts it.
        int totalLength = 0;
        for (int idx = 0; idx < numInstances; idx++) {
            String fileName = instances.attribute(0).value((int) instances.get(idx).value(0));
            File imageFile = new File(getImagesLocation() + File.separator + fileName);
            imageFiles.add(imageFile);
            classLabels.add(String.valueOf(instances.get(idx).classValue()));
            // Give the split real locations that agree with the files handed to the reader.
            locations[idx] = imageFile.toURI();
            totalLength = (int) (totalLength + imageFile.length());
        }
        EasyImageRecordReader reader = new EasyImageRecordReader(
                getWidth(), getHeight(), getNumChannels(), imageFiles, classLabels, seed);
        SpecifiableFolderSplit split = new SpecifiableFolderSplit();
        split.setFiles(locations);
        split.setLength(totalLength);
        reader.initialize(split);
        return reader;
    }

    /**
     * Validates and shuffles the data, then wraps an image record reader in a DL4J iterator.
     *
     * @param instances the data; note it is shuffled in place
     * @param seed seed used for shuffling and for the record reader
     * @param batchSize mini-batch size of the returned iterator
     * @return an iterator over the images with a {@link ScaleImagePixelsPreProcessor} attached
     * @throws Exception if validation fails or the record reader cannot be created
     */
    @Override
    public DataSetIterator getIterator(Instances instances, int seed, int batchSize) throws Exception {
        validate(instances);
        // Mutates the caller's Instances: order is shuffled deterministically for this seed.
        instances.randomize(new Random(seed));
        EasyImageRecordReader reader = getImageRecordReader(instances, seed);
        // -1 is passed as the label index; NOTE(review): confirm this matches how
        // EasyImageRecordReader emits labels.
        RecordReaderDataSetIterator iterator =
                new RecordReaderDataSetIterator(reader, batchSize, -1, instances.numClasses());
        iterator.setPreProcessor(new ScaleImagePixelsPreProcessor());
        return iterator;
    }
}
