/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.impl;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Base {@link ScoreFunction} for DL4J networks.
 *
 * <p>It fetches the test data from a {@link DataProvider} and dispatches to one of four
 * type-specific {@code score} overloads. The overload is chosen by the runtime model type
 * ({@link MultiLayerNetwork} or {@link ComputationGraph}) and the runtime data type
 * ({@link DataSetIterator} or {@link MultiDataSetIterator}). Subclasses implement only
 * those four overloads.
 */
public abstract class BaseNetScoreFunction
implements ScoreFunction {

    /**
     * Scores {@code model} on the test data supplied by {@code dataProvider}.
     *
     * @param model          the model to score; must be a {@link MultiLayerNetwork} or a
     *                       {@link ComputationGraph}
     * @param dataProvider   source of the test data; its {@code testData(dataParameters)} result
     *                       must be a {@link DataSetIterator} or a {@link MultiDataSetIterator}
     * @param dataParameters parameters forwarded unchanged to {@code dataProvider.testData}
     * @return the score computed by the matching type-specific overload
     * @throws IllegalArgumentException if the model or the test data is {@code null} or of an
     *                                  unsupported type
     */
    public double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters) {
        // Fetch the data first so the provider is queried exactly as before.
        Object testData = dataProvider.testData(dataParameters);
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) model;
            // DataSetIterator is checked first, so it takes precedence if an object implements both.
            if (testData instanceof DataSetIterator) {
                return this.score(network, (DataSetIterator) testData);
            }
            if (testData instanceof MultiDataSetIterator) {
                return this.score(network, (MultiDataSetIterator) testData);
            }
            throw unsupportedData(testData);
        }
        if (model instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) model;
            if (testData instanceof DataSetIterator) {
                return this.score(graph, (DataSetIterator) testData);
            }
            if (testData instanceof MultiDataSetIterator) {
                return this.score(graph, (MultiDataSetIterator) testData);
            }
            throw unsupportedData(testData);
        }
        throw new IllegalArgumentException("Unsupported model type: " + typeName(model)
                + " (expected MultiLayerNetwork or ComputationGraph)");
    }

    /**
     * Returns the model classes {@link #score(Object, DataProvider, Map)} can dispatch on.
     *
     * @return a fixed-size list containing {@link MultiLayerNetwork} and {@link ComputationGraph}
     */
    public List<Class<?>> getSupportedModelTypes() {
        return Arrays.asList(MultiLayerNetwork.class, ComputationGraph.class);
    }

    /**
     * Returns the test-data classes {@link #score(Object, DataProvider, Map)} can dispatch on.
     *
     * @return a fixed-size list containing {@link DataSetIterator} and {@link MultiDataSetIterator}
     */
    public List<Class<?>> getSupportedDataTypes() {
        return Arrays.asList(DataSetIterator.class, MultiDataSetIterator.class);
    }

    /**
     * Scores a {@link MultiLayerNetwork} on single-input/single-output data.
     *
     * @param model    the network to score
     * @param testData the test data iterator
     * @return the score
     */
    public abstract double score(MultiLayerNetwork model, DataSetIterator testData);

    /**
     * Scores a {@link MultiLayerNetwork} on {@link MultiDataSetIterator} data.
     *
     * @param model    the network to score
     * @param testData the test data iterator
     * @return the score
     */
    public abstract double score(MultiLayerNetwork model, MultiDataSetIterator testData);

    /**
     * Scores a {@link ComputationGraph} on {@link DataSetIterator} data.
     *
     * @param model    the graph to score
     * @param testData the test data iterator
     * @return the score
     */
    public abstract double score(ComputationGraph model, DataSetIterator testData);

    /**
     * Scores a {@link ComputationGraph} on {@link MultiDataSetIterator} data.
     *
     * @param model    the graph to score
     * @param testData the test data iterator
     * @return the score
     */
    public abstract double score(ComputationGraph model, MultiDataSetIterator testData);

    /** Builds the error raised when the provider returns data of an unsupported type. */
    private static IllegalArgumentException unsupportedData(Object testData) {
        return new IllegalArgumentException("Unsupported test data type: " + typeName(testData)
                + " (expected DataSetIterator or MultiDataSetIterator)");
    }

    /** Returns the runtime class name of {@code o}, or {@code "null"} when it is null. */
    private static String typeName(Object o) {
        return o == null ? "null" : o.getClass().getName();
    }
}

