/*
 * Decompiled with CFR 0.152.
 */
package org.canova.image.recordreader;

import com.twelvemonkeys.imageio.plugins.bmp.BMPImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.bmp.CURImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.bmp.ICOImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageWriterSpi;
import com.twelvemonkeys.imageio.plugins.psd.PSDImageReaderSpi;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
import javax.imageio.spi.IIORegistry;
import org.apache.commons.io.FileUtils;
import org.canova.api.conf.Configuration;
import org.canova.api.io.data.DoubleWritable;
import org.canova.api.io.data.Text;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.api.split.InputSplit;
import org.canova.api.split.InputStreamInputSplit;
import org.canova.api.writable.Writable;
import org.canova.common.RecordConverter;
import org.canova.image.loader.ImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Base {@link RecordReader} that turns images into records of {@link Writable} values.
 *
 * <p>Two input modes are supported by {@link #initialize(InputSplit)}:
 * <ul>
 *   <li>{@link FileSplit}: image files are iterated and each call to {@link #next()} decodes one
 *       file with {@link ImageIO} and flattens it with {@link ImageLoader#asRowVector}.</li>
 *   <li>{@link InputStreamInputSplit}: a single image is read from the stream during
 *       initialization and {@link #next()} returns that one record.</li>
 * </ul>
 * When {@code appendLabel} is set, the index (in {@link #labels}) of the image's parent
 * directory name is appended to the record as a {@link DoubleWritable}.
 */
public abstract class BaseImageRecordReader
implements RecordReader {
    /** Iterator over image files in {@link FileSplit} mode; {@code null} in stream mode. */
    protected Iterator<File> iter;
    protected Configuration conf;
    /** The file most recently taken from {@link #iter} by {@link #next()}. */
    protected File currentFile;
    /** Known labels; the appended label value of a record is its position in this list. */
    public List<String> labels = new ArrayList<String>();
    protected boolean appendLabel = false;
    /** The single record produced in {@link InputStreamInputSplit} mode. */
    protected Collection<Writable> record;
    /** File extensions accepted when a {@link FileSplit} lists individual files (compared case-insensitively). */
    protected final List<String> allowedFormats = Arrays.asList("tif", "jpg", "png", "jpeg", "bmp", "JPEG", "JPG", "TIF", "PNG");
    /** Whether the single stream-mode record has already been handed out by {@link #next()}. */
    protected boolean hitImage = false;
    protected ImageLoader imageLoader;
    protected InputSplit inputSplit;
    /** Maps a file's path string to the label extracted from its parent directory name via {@link #pattern}. */
    protected Map<String, String> fileNameMap = new LinkedHashMap<String, String>();
    /** Optional regex used to split the parent directory name when extracting a label. */
    protected String pattern;
    /** Index of the {@link #pattern}-split segment that is used as the label. */
    protected int patternPosition = 0;
    public static final String WIDTH = NAME_SPACE + ".width";
    public static final String HEIGHT = NAME_SPACE + ".height";
    public static final String CHANNELS = NAME_SPACE + ".channels";

    static {
        // Make the TwelveMonkeys plugins (JPEG, PSD, BMP/CUR/ICO) available to ImageIO.read.
        ImageIO.scanForPlugins();
        IIORegistry registry = IIORegistry.getDefaultInstance();
        registry.registerServiceProvider(new JPEGImageReaderSpi());
        registry.registerServiceProvider(new JPEGImageWriterSpi());
        registry.registerServiceProvider(new PSDImageReaderSpi());
        // registerServiceProvider(Object) would register the List object itself (a no-op);
        // registerServiceProviders(Iterator) registers each provider it yields.
        registry.registerServiceProviders(
                Arrays.asList(new BMPImageReaderSpi(), new CURImageReaderSpi(), new ICOImageReaderSpi()).iterator());
    }

    /** Creates an unconfigured reader; {@link #initialize(Configuration, InputSplit)} must set it up. */
    public BaseImageRecordReader() {
    }

    /**
     * @param width    target image width passed to {@link ImageLoader}
     * @param height   target image height passed to {@link ImageLoader}
     * @param channels number of channels passed to {@link ImageLoader}
     * @param labels   initial label list (labels are not appended to records)
     */
    public BaseImageRecordReader(int width, int height, int channels, List<String> labels) {
        this(width, height, channels, false);
        this.labels = labels;
    }

    /**
     * @param width       target image width passed to {@link ImageLoader}
     * @param height      target image height passed to {@link ImageLoader}
     * @param channels    number of channels passed to {@link ImageLoader}
     * @param appendLabel whether to append the label index to each record
     */
    public BaseImageRecordReader(int width, int height, int channels, boolean appendLabel) {
        this.appendLabel = appendLabel;
        this.imageLoader = new ImageLoader(width, height, channels);
    }

    /** As {@link #BaseImageRecordReader(int, int, int, boolean)}, with an initial label list. */
    public BaseImageRecordReader(int width, int height, int channels, boolean appendLabel, List<String> labels) {
        this(width, height, channels, appendLabel);
        this.labels = labels;
    }

    /**
     * As {@link #BaseImageRecordReader(int, int, int, boolean)}, additionally extracting labels by
     * splitting the parent directory name with {@code pattern} and taking segment {@code patternPosition}.
     */
    public BaseImageRecordReader(int width, int height, int channels, boolean appendLabel, String pattern, int patternPosition) {
        this(width, height, channels, appendLabel);
        this.pattern = pattern;
        this.patternPosition = patternPosition;
    }

    /** Combination of the label-list and pattern constructors. */
    public BaseImageRecordReader(int width, int height, int channels, boolean appendLabel, List<String> labels, String pattern, int patternPosition) {
        this(width, height, channels, appendLabel, labels);
        this.pattern = pattern;
        this.patternPosition = patternPosition;
    }

    /**
     * Returns whether {@code path} ends with "." followed by one of {@link #allowedFormats}.
     * The comparison ignores case, so e.g. {@code IMG.Bmp} is accepted.
     *
     * @param path file path or name to test
     * @return {@code true} if the extension is an allowed image format
     */
    protected boolean containsFormat(String path) {
        for (String extension : this.allowedFormats) {
            String suffix = "." + extension;
            // regionMatches returns false for a negative offset, i.e. when path is shorter than suffix.
            if (path.regionMatches(true, path.length() - suffix.length(), suffix, 0, suffix.length())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prepares the reader for the given split.
     *
     * <p>For a {@link FileSplit} an iterator over the image files is built (and, with several
     * locations and {@code appendLabel}, labels are collected from parent directory names). For an
     * {@link InputStreamInputSplit} the image is decoded immediately into {@link #record} and the
     * stream is closed.
     *
     * @param split the input to read
     * @throws IOException              if reading the stream fails
     * @throws IllegalArgumentException if a single file location does not exist
     * @throws IllegalStateException    in stream mode, if the parent directory is not a known label
     */
    public void initialize(InputSplit split) throws IOException {
        this.inputSplit = split;
        // Allow the single stream record to be served again after a re-initialization.
        this.hitImage = false;
        if (split instanceof FileSplit) {
            this.initializeFromFiles((FileSplit) split);
        } else if (split instanceof InputStreamInputSplit) {
            this.initializeFromStream((InputStreamInputSplit) split);
        }
    }

    /** Builds {@link #iter} from the split's locations and drops the split's root dir from the labels. */
    private void initializeFromFiles(FileSplit split) throws IOException {
        URI[] locations = split.locations();
        if (locations != null && locations.length >= 1) {
            if (locations.length > 1) {
                this.iter = this.collectImageFiles(locations);
            } else {
                File curr = new File(locations[0]);
                if (!curr.exists()) {
                    throw new IllegalArgumentException("Path " + curr.getAbsolutePath() + " does not exist!");
                }
                if (curr.isDirectory()) {
                    this.iter = FileUtils.iterateFiles(curr, null, true);
                } else {
                    this.iter = Collections.singletonList(curr).listIterator();
                }
            }
        }
        // NOTE(review): only has an effect if getRootDir() returns an element type stored in labels — confirm.
        this.labels.remove(split.getRootDir());
    }

    /** Keeps the non-directory locations with an allowed extension, registering labels if requested. */
    private Iterator<File> collectImageFiles(URI[] locations) {
        List<File> allFiles = new ArrayList<File>();
        for (URI location : locations) {
            File imgFile = new File(location);
            if (!imgFile.isDirectory() && this.containsFormat(imgFile.getAbsolutePath())) {
                allFiles.add(imgFile);
            }
            // Labels are taken from every location's parent directory, matching the original behaviour.
            if (this.appendLabel) {
                this.registerDirectoryLabel(imgFile);
            }
        }
        return allFiles.listIterator();
    }

    /** Adds the parent directory name of {@code file} to {@link #labels} and, if set, records its pattern label. */
    private void registerDirectoryLabel(File file) {
        File parentDir = file.getParentFile();
        if (parentDir == null) {
            return;
        }
        String name = parentDir.getName();
        if (!this.labels.contains(name)) {
            this.labels.add(name);
        }
        if (this.pattern != null) {
            String label = name.split(this.pattern)[this.patternPosition];
            this.fileNameMap.put(file.toString(), label);
        }
    }

    /** Decodes the split's stream into {@link #record}, appending the label index once if requested. */
    private void initializeFromStream(InputStreamInputSplit split) throws IOException {
        // The stream record must not be shadowed by a file iterator left from an earlier initialize.
        this.iter = null;
        URI[] locations = split.locations();
        try (InputStream is = split.getIs()) {
            INDArray load = this.imageLoader.asRowVector(is);
            this.record = RecordConverter.toRecord(load);
            if (this.appendLabel) {
                this.record.add(new DoubleWritable((double) this.streamLabelIndex(locations)));
            }
        }
    }

    /** Returns the index in {@link #labels} of the directory containing {@code locations[0]}. */
    private int streamLabelIndex(URI[] locations) {
        Path parentDir = Paths.get(locations[0]).getParent();
        // getFileName() yields the last path element independent of the platform separator.
        Path dirName = parentDir == null ? null : parentDir.getFileName();
        String parent = dirName == null ? "" : dirName.toString();
        int label = this.labels.indexOf(parent);
        if (label < 0) {
            throw new IllegalStateException("Illegal label " + parent);
        }
        return label;
    }

    /**
     * Reads the configuration keys {@code APPEND_LABEL}, {@code LABELS}, {@link #WIDTH},
     * {@link #HEIGHT} and {@link #CHANNELS} (defaults: false, empty, 28, 28, 1), then delegates to
     * {@link #initialize(InputSplit)}.
     */
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
        this.labels = new ArrayList<String>(conf.getStringCollection(LABELS));
        this.imageLoader = new ImageLoader(conf.getInt(WIDTH, 28), conf.getInt(HEIGHT, 28), conf.getInt(CHANNELS, 1));
        this.conf = conf;
        this.initialize(split);
    }

    /**
     * Returns the next record.
     *
     * <p>In file mode, directories are skipped and the next file is decoded into its flattened
     * pixel values, followed (if {@code appendLabel}) by the index of its parent directory name in
     * {@link #labels} (-1 if unknown). If the file cannot be decoded, the stack trace is printed
     * and an empty (or partially filled) record is returned. In stream mode, the single stream
     * record is returned.
     *
     * @throws java.util.NoSuchElementException if the file iterator is exhausted
     * @throws IllegalStateException            if the reader has neither a file iterator nor a record
     */
    public Collection<Writable> next() {
        if (this.iter != null) {
            File image;
            do {
                image = this.iter.next();
                this.currentFile = image;
            } while (image.isDirectory());

            Collection<Writable> ret = new ArrayList<Writable>();
            try {
                BufferedImage bimg = ImageIO.read(image);
                if (bimg == null) {
                    throw new IOException("No registered ImageIO reader could decode " + image.getAbsolutePath());
                }
                INDArray row = this.imageLoader.asRowVector(bimg);
                ret = RecordConverter.toRecord(row);
                if (this.appendLabel) {
                    ret.add(new DoubleWritable((double) this.labels.indexOf(image.getParentFile().getName())));
                }
            } catch (Exception e) {
                // Best effort: an undecodable file yields an empty/partial record instead of aborting.
                e.printStackTrace();
            }
            return ret;
        }
        if (this.record != null) {
            this.hitImage = true;
            return this.record;
        }
        throw new IllegalStateException("No more elements");
    }

    /**
     * In file mode, delegates to the file iterator; in stream mode, returns {@code true} until the
     * single record has been returned once.
     *
     * @throws IllegalStateException if the reader has neither a file iterator nor a record
     */
    public boolean hasNext() {
        if (this.iter != null) {
            return this.iter.hasNext();
        }
        if (this.record != null) {
            return !this.hitImage;
        }
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }

    /** No resources are held between calls, so there is nothing to release. */
    public void close() throws IOException {
    }

    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    public Configuration getConf() {
        return this.conf;
    }

    /** Returns the pattern-extracted label recorded for {@code path}, or {@code null} if none. */
    public String getLabel(String path) {
        return this.fileNameMap.get(path);
    }

    /** Adds the pattern-extracted label of {@code path} to {@link #labels} if known and not yet present. */
    protected void accumulateLabel(String path) {
        String name = this.getLabel(path);
        if (name != null && !this.labels.contains(name)) {
            this.labels.add(name);
        }
    }

    public File getCurrentFile() {
        return this.currentFile;
    }

    public void setCurrentFile(File currentFile) {
        this.currentFile = currentFile;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    /**
     * Re-initializes the reader with the split it was last initialized with.
     *
     * @throws UnsupportedOperationException if {@link #initialize(InputSplit)} was never called
     * @throws RuntimeException              wrapping any failure of the re-initialization
     */
    public void reset() {
        if (this.inputSplit == null) {
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        }
        try {
            this.initialize(this.inputSplit);
        }
        catch (Exception e) {
            throw new RuntimeException("Error during BaseImageRecordReader reset", e);
        }
    }

    /** Returns the number of known labels. */
    public int numLabels() {
        return this.labels.size();
    }
}

