/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.graph;

import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public abstract class BaseGraphTestSetEvaluationScoreFunction
        implements ScoreFunction<ComputationGraph, MultiDataSetIterator> {

    /**
     * Runs {@code model} over every batch of the provider's test data and accumulates the
     * predictions of the graph's single output into one {@link Evaluation}.
     *
     * <p>Batches that carry mask arrays are evaluated with those masks applied to the model.
     * The masks are always cleared afterwards, even if evaluation fails. Unmasked batches with
     * rank-3 labels are evaluated as time series; all other unmasked batches use a plain
     * {@link Evaluation#eval}.
     *
     * @param model          the graph to evaluate; must have exactly one output array
     * @param dataProvider   source of the test data iterator
     * @param dataParameters parameters forwarded to {@code dataProvider.testData(...)}
     * @return the evaluation accumulated over all test batches
     * @throws IllegalStateException if {@code model} does not have exactly one output array
     */
    protected Evaluation getEvaluation(ComputationGraph model, DataProvider<MultiDataSetIterator> dataProvider, Map<String, Object> dataParameters) {
        int numOutputs = model.getNumOutputArrays();
        if (numOutputs != 1) {
            // Report the concrete subclass; this base is shared by several score functions.
            throw new IllegalStateException(getClass().getSimpleName()
                    + " can only be applied to ComputationGraphs with exactly one output. NumOutputs = "
                    + numOutputs);
        }
        MultiDataSetIterator testData = dataProvider.testData(dataParameters);
        Evaluation evaluation = new Evaluation();
        while (testData.hasNext()) {
            MultiDataSet next = testData.next();
            INDArray labels = next.getLabels(0);
            if (next.hasMaskArrays()) {
                evaluateMaskedBatch(model, next, labels, evaluation);
                continue;
            }
            INDArray out = model.output(false, next.getFeatures())[0];
            if (labels.rank() == 3) {
                evaluation.evalTimeSeries(labels, out);
            } else {
                evaluation.eval(labels, out);
            }
        }
        return evaluation;
    }

    /**
     * Evaluates one batch that has mask arrays, applying the masks to the model for the
     * duration of the forward pass only.
     */
    private static void evaluateMaskedBatch(ComputationGraph model, MultiDataSet batch, INDArray labels, Evaluation evaluation) {
        INDArray[] fMask = batch.getFeaturesMaskArrays();
        INDArray[] lMask = batch.getLabelsMaskArrays();
        model.setLayerMaskArrays(fMask, lMask);
        try {
            INDArray out = model.output(false, batch.getFeatures())[0];
            // Guard against a missing, empty, or null-entry label-mask array for output 0.
            INDArray outputMask = (lMask != null && lMask.length > 0) ? lMask[0] : null;
            if (outputMask != null) {
                evaluation.evalTimeSeries(labels, out, outputMask);
            } else {
                evaluation.evalTimeSeries(labels, out);
            }
        } finally {
            // Always clear, so masks never leak into later uses of the same model.
            model.clearLayerMaskArrays();
        }
    }

    /**
     * Returns {@code false}: scores produced by this function are not minimized.
     * NOTE(review): presumably they are maximized, as for accuracy-like metrics; confirm
     * against the {@code ScoreFunction} contract.
     */
    public boolean minimize() {
        return false;
    }

    /** Subclasses must describe themselves. */
    @Override
    public abstract String toString();
}

