/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.data;

import java.io.IOException;
import java.util.Map;
import java.util.Random;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultipleEpochsIterator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * {@link DataProvider} that hands out MNIST {@link DataSetIterator}s.
 *
 * <p>Training data is a {@link MnistDataSetIterator} wrapped in a
 * {@link MultipleEpochsIterator} configured with {@code numEpochs}; test data is a
 * plain {@link MnistDataSetIterator}. Both use the configured {@code batchSize}.
 *
 * <p>The class is a mutable bean: every field has a getter and a setter, and
 * {@code equals}/{@code hashCode}/{@code toString} are derived from the three fields.
 */
public class MnistDataProvider
implements DataProvider {
    /** Seed handed to the test iterator; it is fixed and independent of {@code rngSeed}. */
    private static final int TEST_RNG_SEED = 12345;

    /** Multiplier used when folding the fields into {@link #hashCode()}. */
    private static final int HASH_MULTIPLIER = 59;

    private int numEpochs;
    private int batchSize;
    private int rngSeed;

    /**
     * Creates a provider with all fields left at their Java default of {@code 0}.
     * Callers are expected to configure it through the setters.
     * NOTE(review): presumably kept for bean-style/JSON instantiation - confirm against callers.
     */
    public MnistDataProvider() {
    }

    /**
     * Creates a provider whose training seed is drawn from a fresh {@link Random}.
     *
     * @param numEpochs value passed to the {@link MultipleEpochsIterator} built by {@link #trainData(Map)}
     * @param batchSize batch size passed to every {@link MnistDataSetIterator} this provider creates
     */
    public MnistDataProvider(int numEpochs, int batchSize) {
        this(numEpochs, batchSize, new Random().nextInt());
    }

    /**
     * Creates a fully specified provider.
     *
     * @param numEpochs value passed to the {@link MultipleEpochsIterator} built by {@link #trainData(Map)}
     * @param batchSize batch size passed to every {@link MnistDataSetIterator} this provider creates
     * @param rngSeed   seed passed to the training {@link MnistDataSetIterator}
     */
    public MnistDataProvider(@JsonProperty(value="numEpochs") int numEpochs, @JsonProperty(value="batchSize") int batchSize, @JsonProperty(value="rngSeed") int rngSeed) {
        this.numEpochs = numEpochs;
        this.batchSize = batchSize;
        this.rngSeed = rngSeed;
    }

    /**
     * Builds the training iterator: a {@link MnistDataSetIterator} created with the
     * configured batch size, the train flag set to {@code true} and the configured seed,
     * wrapped in a {@link MultipleEpochsIterator} configured with {@code numEpochs}.
     *
     * @param dataParameters ignored by this implementation
     * @return a new {@link MultipleEpochsIterator} on every call
     * @throws RuntimeException wrapping the {@link IOException} if the MNIST iterator cannot be created
     */
    public Object trainData(Map<String, Object> dataParameters) {
        try {
            // Typed as DataSetIterator so the same MultipleEpochsIterator constructor is selected as before.
            DataSetIterator mnistTrain = new MnistDataSetIterator(this.batchSize, true, this.rngSeed);
            return new MultipleEpochsIterator(this.numEpochs, mnistTrain);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the test iterator: a {@link MnistDataSetIterator} created with the configured
     * batch size, the train flag set to {@code false} and a fixed seed. Neither
     * {@code numEpochs} nor {@code rngSeed} is used here.
     *
     * @param dataParameters ignored by this implementation
     * @return a new {@link MnistDataSetIterator} on every call
     * @throws RuntimeException wrapping the {@link IOException} if the MNIST iterator cannot be created
     */
    public Object testData(Map<String, Object> dataParameters) {
        try {
            return new MnistDataSetIterator(this.batchSize, false, TEST_RNG_SEED);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return {@code DataSetIterator.class}, the common type of the objects returned by
     *         {@link #trainData(Map)} and {@link #testData(Map)}
     */
    public Class<?> getDataType() {
        return DataSetIterator.class;
    }

    /** @return the configured number of epochs */
    public int getNumEpochs() {
        return this.numEpochs;
    }

    /** @return the configured batch size */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @return the seed used for the training iterator */
    public int getRngSeed() {
        return this.rngSeed;
    }

    /** @param numEpochs new number of epochs; affects iterators created after this call */
    public void setNumEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
    }

    /** @param batchSize new batch size; affects iterators created after this call */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** @param rngSeed new training seed; affects iterators created after this call */
    public void setRngSeed(int rngSeed) {
        this.rngSeed = rngSeed;
    }

    /**
     * Two providers are equal when both sides accept each other via {@link #canEqual(Object)}
     * and their epoch count, batch size and seed all match.
     *
     * @param o object to compare with, may be {@code null}
     * @return {@code true} if {@code o} is an equivalent provider
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MnistDataProvider)) {
            return false;
        }
        MnistDataProvider other = (MnistDataProvider)o;
        // Lets a subclass that redefines equality veto the comparison, keeping equals symmetric.
        if (!other.canEqual(this)) {
            return false;
        }
        // Values are read through the getters, as before, so overridden getters are honoured.
        return this.getNumEpochs() == other.getNumEpochs()
            && this.getBatchSize() == other.getBatchSize()
            && this.getRngSeed() == other.getRngSeed();
    }

    /**
     * Hook used by {@link #equals(Object)} to ask the other side whether it is willing
     * to be compared with this type.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is a {@code MnistDataProvider}
     */
    protected boolean canEqual(Object other) {
        return other instanceof MnistDataProvider;
    }

    /**
     * Combines the same three values that {@link #equals(Object)} compares.
     *
     * @return hash of epoch count, batch size and seed
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_MULTIPLIER + this.getNumEpochs();
        result = result * HASH_MULTIPLIER + this.getBatchSize();
        result = result * HASH_MULTIPLIER + this.getRngSeed();
        return result;
    }

    /** @return a textual form listing {@code numEpochs}, {@code batchSize} and {@code rngSeed} */
    @Override
    public String toString() {
        return "MnistDataProvider(numEpochs=" + this.getNumEpochs() + ", batchSize=" + this.getBatchSize() + ", rngSeed=" + this.getRngSeed() + ")";
    }
}

