/*
 * Decompiled with CFR 0.152.
 */
package org.canova.image.recordreader;

import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.imageio.ImageIO;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.canova.api.io.data.DoubleWritable;
import org.canova.api.io.data.Text;
import org.canova.api.split.FileSplit;
import org.canova.api.split.InputSplit;
import org.canova.api.writable.Writable;
import org.canova.common.RecordConverter;
import org.canova.image.loader.ImageLoader;
import org.canova.image.recordreader.BaseImageRecordReader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Record reader that turns image files into records of pixel values, optionally followed by a
 * numeric label index.
 *
 * <p>Labels are resolved in one of two ways:
 * <ul>
 *   <li><b>Pattern mode</b> ({@code appendLabel} set and no file-name map): the image's base name
 *       is split with {@code pattern}, the token at {@code patternPosition} is looked up in the
 *       label file, and the position of the resulting label in {@code labels} is appended.</li>
 *   <li><b>Eval mode</b> (a file-name map path was supplied): the image's file name is looked up in
 *       the file-name map, the result is looked up in the label file, and the position of that
 *       label in {@code labels} is appended.</li>
 * </ul>
 *
 * <p>Both the label file and the file-name map are parsed as comma-separated text where the first
 * field of each line is the key and the second field is the value.
 *
 * <p>Fields such as {@code imageLoader}, {@code appendLabel}, {@code pattern},
 * {@code patternPosition}, {@code labels}, {@code fileNameMap}, {@code iter}, {@code record} and
 * {@code hitImage} are inherited from {@link BaseImageRecordReader}; their defaults are defined
 * there and are not visible in this file.
 */
public class ImageNetRecordReader
extends BaseImageRecordReader {
    protected static Logger log = LoggerFactory.getLogger(ImageNetRecordReader.class);
    /** Key-to-label mapping loaded from {@link #labelPath}; insertion order follows the file. */
    protected Map<String, String> labelFileIdMap = new LinkedHashMap<String, String>();
    /** Path of the comma-separated label file; may be {@code null}, in which case no labels are loaded. */
    protected String labelPath;
    /** Path of the comma-separated file-name map used in eval mode; {@code null} outside eval mode. */
    protected String fileNameMapPath = null;
    /** {@code true} when the reader was constructed with a file-name map. */
    protected boolean eval = false;

    /**
     * Creates a reader that leaves label handling at the base-class defaults.
     *
     * @param width     image width handed to the {@link ImageLoader}
     * @param height    image height handed to the {@link ImageLoader}
     * @param channels  channel count handed to the {@link ImageLoader}
     * @param labelPath path of the comma-separated label file
     */
    public ImageNetRecordReader(int width, int height, int channels, String labelPath) {
        this.imageLoader = new ImageLoader(width, height, channels);
        this.labelPath = labelPath;
    }

    /**
     * Creates a reader with an explicit {@code appendLabel} setting.
     *
     * @param width       image width handed to the {@link ImageLoader}
     * @param height      image height handed to the {@link ImageLoader}
     * @param channels    channel count handed to the {@link ImageLoader}
     * @param labelPath   path of the comma-separated label file
     * @param appendLabel whether a label index is appended to each record in pattern mode
     */
    public ImageNetRecordReader(int width, int height, int channels, String labelPath, boolean appendLabel) {
        this(width, height, channels, labelPath);
        this.appendLabel = appendLabel;
    }

    /**
     * Creates a pattern-mode reader with a custom split pattern.
     *
     * @param width       image width handed to the {@link ImageLoader}
     * @param height      image height handed to the {@link ImageLoader}
     * @param channels    channel count handed to the {@link ImageLoader}
     * @param labelPath   path of the comma-separated label file
     * @param appendLabel whether a label index is appended to each record
     * @param pattern     regular expression used to split the image's base name
     */
    public ImageNetRecordReader(int width, int height, int channels, String labelPath, boolean appendLabel, String pattern) {
        this(width, height, channels, labelPath, appendLabel);
        this.pattern = pattern;
    }

    /**
     * Creates a pattern-mode reader with a custom split pattern and token position.
     *
     * @param width           image width handed to the {@link ImageLoader}
     * @param height          image height handed to the {@link ImageLoader}
     * @param channels        channel count handed to the {@link ImageLoader}
     * @param labelPath       path of the comma-separated label file
     * @param appendLabel     whether a label index is appended to each record
     * @param pattern         regular expression used to split the image's base name
     * @param patternPosition index of the token (after splitting) that is the label-file key
     */
    public ImageNetRecordReader(int width, int height, int channels, String labelPath, boolean appendLabel, String pattern, int patternPosition) {
        this(width, height, channels, labelPath, appendLabel, pattern);
        this.patternPosition = patternPosition;
    }

    /**
     * Creates an eval-mode reader that resolves labels through a file-name map.
     *
     * @param width           image width handed to the {@link ImageLoader}
     * @param height          image height handed to the {@link ImageLoader}
     * @param channels        channel count handed to the {@link ImageLoader}
     * @param labelPath       path of the comma-separated label file
     * @param fileNameMapPath path of the comma-separated file-name map
     * @param appendLabel     stored on the reader; in eval mode a label is appended regardless
     */
    public ImageNetRecordReader(int width, int height, int channels, String labelPath, String fileNameMapPath, boolean appendLabel) {
        this(width, height, channels, labelPath, appendLabel);
        this.fileNameMapPath = fileNameMapPath;
        this.eval = true;
    }

    /**
     * Creates an eval-mode reader and additionally records a pattern and token position.
     *
     * @param width           image width handed to the {@link ImageLoader}
     * @param height          image height handed to the {@link ImageLoader}
     * @param channels        channel count handed to the {@link ImageLoader}
     * @param labelPath       path of the comma-separated label file
     * @param fileNameMapPath path of the comma-separated file-name map
     * @param appendLabel     stored on the reader; in eval mode a label is appended regardless
     * @param pattern         regular expression stored on the reader
     * @param patternPosition token position stored on the reader
     */
    public ImageNetRecordReader(int width, int height, int channels, String labelPath, String fileNameMapPath, boolean appendLabel, String pattern, int patternPosition) {
        this(width, height, channels, labelPath, fileNameMapPath, appendLabel);
        this.pattern = pattern;
        this.patternPosition = patternPosition;
    }

    /**
     * Parses a comma-separated file into an insertion-ordered map of first field to second field.
     *
     * <p>Empty lines are skipped. Fields beyond the second are ignored. If a key occurs more than
     * once, the last value wins.
     *
     * @param path path of the file to parse
     * @return map from the first field of each line to its second field, in file order
     * @throws IOException if the file cannot be read or a non-empty line has fewer than two fields
     */
    private Map<String, String> defineLabels(String path) throws IOException {
        LinkedHashMap<String, String> tmpMap = new LinkedHashMap<String, String>();
        // try-with-resources guarantees the reader is closed even when parsing fails.
        try (BufferedReader br = new BufferedReader(new FileReader(path))) {
            String line;
            int lineNumber = 0;
            while ((line = br.readLine()) != null) {
                ++lineNumber;
                if (line.isEmpty()) {
                    // Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue;
                }
                String[] row = line.split(",");
                if (row.length < 2) {
                    throw new IOException("Malformed line " + lineNumber + " in " + path
                            + ": expected at least two comma-separated fields but got \"" + line + "\"");
                }
                tmpMap.put(row[0], row[1]);
            }
        }
        return tmpMap;
    }

    /**
     * Loads the label file and file-name map (each only once, while the corresponding map is still
     * empty) and builds the file iterator for the given split.
     *
     * <p>With more than one location, only non-directory locations accepted by
     * {@code containsFormat} are iterated. With exactly one location, a directory is walked
     * recursively without any extension filter and a plain file is iterated on its own. With no
     * locations the existing iterator is left untouched.
     *
     * @param split the input split; must be a {@link FileSplit}
     * @throws IOException              if the label file or file-name map cannot be parsed
     * @throws IllegalArgumentException if the single location does not exist
     * @throws UnsupportedClassVersionError if {@code split} is not a {@link FileSplit}
     */
    @Override
    public void initialize(InputSplit split) throws IOException {
        this.inputSplit = split;
        if (this.labelPath != null && this.labelFileIdMap.isEmpty()) {
            this.labelFileIdMap = this.defineLabels(this.labelPath);
            this.labels = new ArrayList<String>(this.labelFileIdMap.values());
        }
        if (this.fileNameMapPath != null && this.fileNameMap.isEmpty()) {
            this.fileNameMap = this.defineLabels(this.fileNameMapPath);
        }
        if (!(split instanceof FileSplit)) {
            // NOTE(review): UnsupportedClassVersionError is an odd choice for argument validation
            // (it normally signals a class-file version mismatch). It is kept so that existing
            // callers observe the same throwable; consider IllegalArgumentException in a
            // breaking release.
            throw new UnsupportedClassVersionError("Split needs to be an instance of FileSplit for this record reader.");
        }
        URI[] locations = split.locations();
        if (locations == null || locations.length == 0) {
            return;
        }
        if (locations.length > 1) {
            ArrayList<File> allFiles = new ArrayList<File>();
            for (URI location : locations) {
                File candidate = new File(location);
                if (!candidate.isDirectory() && this.containsFormat(candidate.getAbsolutePath())) {
                    allFiles.add(candidate);
                }
            }
            this.iter = allFiles.listIterator();
            return;
        }
        File curr = new File(locations[0]);
        if (!curr.exists()) {
            throw new IllegalArgumentException("Path " + curr.getAbsolutePath() + " does not exist!");
        }
        this.iter = curr.isDirectory() ? FileUtils.iterateFiles(curr, null, true) : Collections.singletonList(curr).listIterator();
    }

    /**
     * Reads the next image and converts it to a record.
     *
     * <p>Directories returned by the iterator are skipped. The record holds the image's row vector
     * and, when a label is expected (pattern mode with {@code appendLabel}, or eval mode), the
     * label index as a trailing {@link DoubleWritable}.
     *
     * <p>This method is best-effort: if the image cannot be decoded or its label cannot be
     * resolved, the failure is logged and whatever was built so far is returned (an empty
     * collection if decoding failed, the pixel values without a label if label lookup failed).
     *
     * <p>When no file iterator has been set up, the inherited {@code record} is returned if it is
     * non-null.
     *
     * @return the record for the next image
     * @throws IllegalStateException if there is neither a file iterator nor a stored record
     * @throws java.util.NoSuchElementException if the file iterator is exhausted
     */
    @Override
    public Collection<Writable> next() {
        if (this.iter == null) {
            if (this.record != null) {
                this.hitImage = true;
                return this.record;
            }
            throw new IllegalStateException("No more elements");
        }

        // Skip directories iteratively rather than recursing once per directory entry.
        File image = (File)this.iter.next();
        while (image.isDirectory()) {
            image = (File)this.iter.next();
        }

        // A label is only looked up when one of the two labelling modes is active; otherwise the
        // record is just the pixel values and no lookup (or failure report) is needed.
        boolean labelExpected = (this.appendLabel && this.fileNameMapPath == null) || this.eval;

        Collection<Writable> ret = new ArrayList<Writable>();
        try {
            BufferedImage raw = ImageIO.read(image);
            if (raw == null) {
                // ImageIO.read returns null when no registered reader can decode the file.
                throw new IOException("Unable to decode image " + image.getAbsolutePath());
            }
            BufferedImage bimg = this.imageLoader.centerCropIfNeeded(raw);
            INDArray row = this.imageLoader.asRowVector(bimg);
            ret = RecordConverter.toRecord(row);
            if (labelExpected) {
                int labelId = this.resolveLabelId(image);
                if (labelId < 0) {
                    throw new IllegalStateException("Illegal label " + labelId);
                }
                ret.add(new DoubleWritable((double)labelId));
            }
        }
        catch (Exception e) {
            // Deliberately best-effort: report the problem and return the partial record.
            log.error("Failed to build a complete record for image " + image.getAbsolutePath(), e);
        }
        return ret;
    }

    /**
     * Looks up the label index for an image using the active labelling mode.
     *
     * @param image the image file whose label is wanted
     * @return the index of the label in {@code labels}, or a negative value if it is unknown
     */
    private int resolveLabelId(File image) {
        if (this.appendLabel && this.fileNameMapPath == null) {
            // Pattern mode: the label-file key is a token of the base name (no extension).
            String wnid = FilenameUtils.getBaseName(image.getName()).split(this.pattern)[this.patternPosition];
            return this.labels.indexOf(this.labelFileIdMap.get(wnid));
        }
        // Eval mode: the file-name map is keyed by the full file name, extension included.
        String fileName = FilenameUtils.getName(image.getName());
        return this.labels.indexOf(this.labelFileIdMap.get(this.fileNameMap.get(fileName)));
    }
}

