/*
 * Decompiled with CFR 0.152.
 */
package org.canova.image.recordreader;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.canova.api.conf.Configuration;
import org.canova.api.io.data.DoubleWritable;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.api.split.InputSplit;
import org.canova.api.split.InputStreamInputSplit;
import org.canova.api.writable.Writable;
import org.canova.common.RecordConverter;
import org.canova.image.loader.BaseImageLoader;
import org.canova.image.loader.ImageLoader;
import org.canova.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Base {@link RecordReader} for image data.
 *
 * <p>Each record is an image loaded through a {@link BaseImageLoader}, flattened to a row vector
 * and converted with {@link RecordConverter#toRecord}. When {@code appendLabel} is set, the index
 * of the image's label in {@link #labels} is appended as a trailing {@link DoubleWritable}.
 *
 * <p>Two split types are handled by {@link #initialize(InputSplit)}:
 * <ul>
 *   <li>{@link FileSplit}: image files are collected up front and decoded lazily, one per
 *       {@link #next()} call;</li>
 *   <li>{@link InputStreamInputSplit}: the single stream is decoded eagerly during
 *       initialization and returned as exactly one record.</li>
 * </ul>
 */
public abstract class BaseImageRecordReader implements RecordReader {
    /** Image files of a {@link FileSplit}; {@code null} when reading from a stream split. */
    protected Iterator<File> iter;
    protected Configuration conf;
    /** File most recently returned by {@link #next()}. */
    protected File currentFile;
    /** Known label names; a record's label value is the index of its label in this list. */
    public List<String> labels = new ArrayList<String>();
    /** Whether each record gets its label index appended as a trailing value. */
    protected boolean appendLabel = false;
    /** The single, pre-decoded record of an {@link InputStreamInputSplit}. */
    protected Collection<Writable> record;
    /** Whether the single stream record has already been handed out by {@link #next()}. */
    protected boolean hitImage = false;
    protected int height = 28;
    protected int width = 28;
    protected int channels = 1;
    protected boolean cropImage = false;
    protected BaseImageLoader imageLoader;
    /** Split passed to the last {@link #initialize(InputSplit)}; reused by {@link #reset()}. */
    protected InputSplit inputSplit;
    /** File path to label derived via {@link #pattern}; consulted first by {@link #getLabel}. */
    protected Map<String, String> fileNameMap = new LinkedHashMap<String, String>();
    /** Optional regex used to split a parent-directory name into label parts. */
    protected String pattern;
    /** Index of the part (after splitting on {@link #pattern}) that is used as the label. */
    protected int patternPosition = 0;
    /** Configuration key for the image height. */
    public static final String HEIGHT = NAME_SPACE + ".height";
    /** Configuration key for the image width. */
    public static final String WIDTH = NAME_SPACE + ".width";
    /** Configuration key for the number of image channels. */
    public static final String CHANNELS = NAME_SPACE + ".channels";
    /** Configuration key for the crop-image flag. */
    public static final String CROP_IMAGE = NAME_SPACE + ".cropimage";
    /** Configuration key selecting the loader; the value {@code "imageio"} selects {@link ImageLoader}. */
    public static final String IMAGE_LOADER = NAME_SPACE + ".imageloader";

    /** Creates a reader with default dimensions (28x28, 1 channel) and no label appending. */
    public BaseImageRecordReader() {
    }

    /** Creates a reader with the given dimensions and known labels; labels are not appended. */
    public BaseImageRecordReader(int height, int width, int channels, List<String> labels) {
        this(height, width, channels, false);
        this.labels = labels;
    }

    /** Creates a reader with the given dimensions; {@code appendLabel} controls label output. */
    public BaseImageRecordReader(int height, int width, int channels, boolean appendLabel) {
        this.appendLabel = appendLabel;
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /** Creates a reader with the given dimensions, label-append flag and known labels. */
    public BaseImageRecordReader(int height, int width, int channels, boolean appendLabel, List<String> labels) {
        this(height, width, channels, appendLabel);
        this.labels = labels;
    }

    /**
     * Creates a reader whose labels may additionally be derived by splitting the parent-directory
     * name on {@code pattern} and taking the part at {@code patternPosition}.
     */
    public BaseImageRecordReader(int height, int width, int channels, boolean appendLabel, String pattern, int patternPosition) {
        this(height, width, channels, appendLabel);
        this.pattern = pattern;
        this.patternPosition = patternPosition;
    }

    /** Combination of the labels and pattern constructors. */
    public BaseImageRecordReader(int height, int width, int channels, boolean appendLabel, List<String> labels, String pattern, int patternPosition) {
        this(height, width, channels, appendLabel, labels);
        this.pattern = pattern;
        this.patternPosition = patternPosition;
    }

    /**
     * Returns whether {@code format} (a file path or name) ends with {@code "." + ext} for one
     * of the loader's allowed formats. The comparison ignores case, so {@code IMG.JPG} matches
     * an allowed {@code jpg}.
     */
    protected boolean containsFormat(String format) {
        for (String allowed : this.imageLoader.getAllowedFormats()) {
            String suffix = "." + allowed;
            int start = format.length() - suffix.length();
            // regionMatches(true, ...) is a case-insensitive endsWith.
            if (start >= 0 && format.regionMatches(true, start, suffix, 0, suffix.length())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prepares the reader for {@code split}.
     *
     * <p>A {@link NativeImageLoader} is created if no loader is set yet. Any iterator or stream
     * record from a previous initialization is discarded, so {@link #reset()} and
     * re-initialization start from a clean state. Splits of other types leave the reader empty.
     *
     * @throws IOException if reading the stream of an {@link InputStreamInputSplit} fails
     * @throws IllegalArgumentException if a single {@link FileSplit} location does not exist
     * @throws IllegalStateException if a stream's parent directory is not a known label
     */
    public void initialize(InputSplit split) throws IOException {
        if (this.imageLoader == null) {
            this.imageLoader = new NativeImageLoader(this.height, this.width, this.channels, this.cropImage);
        }
        this.inputSplit = split;
        // Drop state left over from an earlier initialize(); otherwise reset() of a stream split
        // would report no records (hitImage stays true) and a stale iterator could shadow a new split.
        this.iter = null;
        this.record = null;
        this.hitImage = false;
        if (split instanceof FileSplit) {
            this.initializeFromFiles((FileSplit) split);
        } else if (split instanceof InputStreamInputSplit) {
            this.initializeFromStream((InputStreamInputSplit) split);
        }
    }

    /** Builds the image-file iterator (and, with appendLabel, the label list) for a file split. */
    private void initializeFromFiles(FileSplit split) {
        URI[] locations = split.locations();
        if (locations != null && locations.length >= 1) {
            if (locations.length > 1) {
                List<File> allFiles = new ArrayList<File>();
                for (URI location : locations) {
                    File imgFile = new File(location);
                    if (!imgFile.isDirectory() && this.containsFormat(imgFile.getAbsolutePath())) {
                        allFiles.add(imgFile);
                    }
                    if (this.appendLabel) {
                        this.registerLabel(imgFile);
                    }
                }
                this.iter = allFiles.iterator();
            } else {
                File curr = new File(locations[0]);
                if (!curr.exists()) {
                    throw new IllegalArgumentException("Path " + curr.getAbsolutePath() + " does not exist!");
                }
                if (curr.isDirectory()) {
                    // Like the multi-location branch, keep only files with an allowed image
                    // extension so stray non-image files are never handed to the loader.
                    List<File> images = new ArrayList<File>();
                    for (File f : FileUtils.listFiles(curr, null, true)) {
                        if (this.containsFormat(f.getAbsolutePath())) {
                            images.add(f);
                        }
                    }
                    this.iter = images.iterator();
                } else {
                    this.iter = Collections.singletonList(curr).iterator();
                }
            }
        }
        // NOTE(review): labels holds String names; if getRootDir() returns a File rather than a
        // String this removal can never match anything — TODO confirm the intended argument.
        this.labels.remove(split.getRootDir());
    }

    /**
     * Adds {@code imgFile}'s parent-directory name to {@link #labels} if it is new. When a
     * {@link #pattern} is set, also maps the file to the pattern-derived label in {@link #fileNameMap}.
     */
    private void registerLabel(File imgFile) {
        String name = imgFile.getParentFile().getName();
        if (!this.labels.contains(name)) {
            this.labels.add(name);
        }
        if (this.pattern != null) {
            this.fileNameMap.put(imgFile.toString(), name.split(this.pattern)[this.patternPosition]);
        }
    }

    /**
     * Decodes the split's single stream into {@link #record}. The stream is always closed, even on
     * failure. When appendLabel is set, the label index is appended exactly once.
     */
    private void initializeFromStream(InputStreamInputSplit split) throws IOException {
        URI[] locations = split.locations();
        try (InputStream is = split.getIs()) {
            INDArray load = this.imageLoader.asRowVector(is);
            this.record = RecordConverter.toRecord(load);
        }
        if (this.appendLabel) {
            Path path = Paths.get(locations[0]);
            String parent = path.getParent().toString();
            // The location may have come from a URI: keep only the last path segment.
            if (parent.contains("/")) {
                parent = parent.substring(parent.lastIndexOf('/') + 1);
            }
            int label = this.labels.indexOf(parent);
            if (label < 0) {
                throw new IllegalStateException("Illegal label " + parent);
            }
            this.record.add(new DoubleWritable((double) label));
        }
    }

    /**
     * Reads image settings and labels from {@code conf}, picks the image loader
     * ({@code "imageio"} selects {@link ImageLoader}, anything else {@link NativeImageLoader}),
     * then delegates to {@link #initialize(InputSplit)}.
     */
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
        this.labels = new ArrayList<String>(conf.getStringCollection(LABELS));
        this.height = conf.getInt(HEIGHT, this.height);
        this.width = conf.getInt(WIDTH, this.width);
        this.channels = conf.getInt(CHANNELS, this.channels);
        this.cropImage = conf.getBoolean(CROP_IMAGE, this.cropImage);
        this.imageLoader = "imageio".equals(conf.get(IMAGE_LOADER)) ? new ImageLoader(this.height, this.width, this.channels, this.cropImage) : new NativeImageLoader(this.height, this.width, this.channels, this.cropImage);
        this.conf = conf;
        this.initialize(split);
    }

    /**
     * Returns the next record.
     *
     * <p>For a file split, the next non-directory file is decoded and, with appendLabel, the index
     * of its parent-directory name in {@link #labels} is appended. For a stream split, the
     * pre-decoded record is returned and marked as consumed.
     *
     * @throws RuntimeException wrapping the cause if an image cannot be loaded
     * @throws IllegalStateException if neither a file iterator nor a stream record exists
     */
    public Collection<Writable> next() {
        if (this.iter != null) {
            File image = this.iter.next();
            // Skip directories iteratively (the original recursed once per directory).
            while (image.isDirectory()) {
                image = this.iter.next();
            }
            this.currentFile = image;
            try {
                INDArray row = this.imageLoader.asRowVector(image);
                Collection<Writable> ret = RecordConverter.toRecord(row);
                if (this.appendLabel) {
                    ret.add(new DoubleWritable((double) this.labels.indexOf(image.getParentFile().getName())));
                }
                return ret;
            } catch (Exception e) {
                // Fail loudly: an empty "record" would silently corrupt downstream data.
                throw new RuntimeException("Failed to load image " + image.getAbsolutePath(), e);
            }
        }
        if (this.record != null) {
            this.hitImage = true;
            return this.record;
        }
        throw new IllegalStateException("No more elements");
    }

    /**
     * Returns whether another record is available: the file iterator has more entries, or the
     * stream record has not been returned yet.
     *
     * @throws IllegalStateException if the reader was never initialized with a supported split
     */
    public boolean hasNext() {
        if (this.iter != null) {
            return this.iter.hasNext();
        }
        if (this.record != null) {
            return !this.hitImage;
        }
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }

    /** No-op: streams are closed during initialization and files are read per record. */
    public void close() throws IOException {
    }

    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    public Configuration getConf() {
        return this.conf;
    }

    /**
     * Returns the label for {@code path}: the pattern-derived entry in {@link #fileNameMap} when
     * one exists, otherwise the name of the path's parent directory.
     */
    public String getLabel(String path) {
        if (this.fileNameMap != null && this.fileNameMap.containsKey(path)) {
            return this.fileNameMap.get(path);
        }
        return new File(path).getParentFile().getName();
    }

    /** Adds the label of {@code path} (see {@link #getLabel}) to {@link #labels} if absent. */
    protected void accumulateLabel(String path) {
        String name = this.getLabel(path);
        if (!this.labels.contains(name)) {
            this.labels.add(name);
        }
    }

    public File getCurrentFile() {
        return this.currentFile;
    }

    public void setCurrentFile(File currentFile) {
        this.currentFile = currentFile;
    }

    /** Returns the live label list (not a copy). */
    public List<String> getLabels() {
        return this.labels;
    }

    /**
     * Re-initializes the reader from the split given to the last {@link #initialize(InputSplit)}.
     *
     * @throws UnsupportedOperationException if the reader was never initialized
     * @throws RuntimeException wrapping any failure from re-initialization
     */
    public void reset() {
        if (this.inputSplit == null) {
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        }
        try {
            this.initialize(this.inputSplit);
        } catch (Exception e) {
            throw new RuntimeException("Error during " + this.getClass().getSimpleName() + " reset", e);
        }
    }

    /** Returns the number of known labels. */
    public int numLabels() {
        return this.labels.size();
    }

    /**
     * Decodes one image from {@code dataInputStream}. With appendLabel, the index of
     * {@code getLabel(uri.getPath())} in {@link #labels} is appended. The stream is not closed here.
     */
    public Collection<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        INDArray row = this.imageLoader.asRowVector(dataInputStream);
        Collection<Writable> ret = RecordConverter.toRecord(row);
        if (this.appendLabel) {
            ret.add(new DoubleWritable((double) this.labels.indexOf(this.getLabel(uri.getPath()))));
        }
        return ret;
    }
}

