/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.impl;

import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.arbiter.scoring.impl.BaseNetScoreFunction;
import org.deeplearning4j.arbiter.scoring.util.ScoreUtil;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Score function that evaluates a network on a test-set iterator using
 * {@link RegressionEvaluation} and reports a single regression metric, selected by
 * {@link RegressionValue}.
 *
 * <p>Supports {@link MultiLayerNetwork} with a {@link DataSetIterator}, and
 * {@link ComputationGraph} with either a {@link DataSetIterator} or a
 * {@link MultiDataSetIterator}. Scoring a {@link MultiLayerNetwork} on a
 * {@link MultiDataSetIterator} is not supported.
 */
public class TestSetRegressionScoreFunction extends BaseNetScoreFunction {

    /** The regression metric extracted from the evaluation. */
    private final RegressionValue regressionValue;

    /**
     * Creates a score function that reports the given regression metric.
     *
     * @param regressionValue the metric to extract from the regression evaluation
     */
    public TestSetRegressionScoreFunction(RegressionValue regressionValue) {
        this.regressionValue = regressionValue;
    }

    /**
     * Returns whether lower scores are better for the configured metric.
     *
     * @return {@code false} only for {@link RegressionValue#CorrCoeff}, which this class
     *     treats as higher-is-better; {@code true} for every other metric
     */
    @Override
    public boolean minimize() {
        return regressionValue != RegressionValue.CorrCoeff;
    }

    @Override
    public String toString() {
        // Plain concatenation already uses String.valueOf, so a null metric renders as "null".
        return "TestSetRegressionScoreFunction(type=" + regressionValue + ")";
    }

    /**
     * Evaluates a {@link MultiLayerNetwork} on the given iterator and returns the configured metric.
     *
     * @param net the network to evaluate
     * @param iterator the test data
     * @return the selected regression metric
     */
    @Override
    public double score(MultiLayerNetwork net, DataSetIterator iterator) {
        RegressionEvaluation e = net.evaluateRegression(iterator);
        return ScoreUtil.getScoreFromRegressionEval(e, regressionValue);
    }

    /**
     * Not supported: a {@link MultiLayerNetwork} cannot consume a {@link MultiDataSetIterator}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) {
        throw new UnsupportedOperationException("Cannot evaluate MultiLayerNetwork on MultiDataSetIterator");
    }

    /**
     * Evaluates a {@link ComputationGraph} on a single-input/output iterator and returns the
     * configured metric.
     *
     * @param graph the graph to evaluate
     * @param iterator the test data
     * @return the selected regression metric
     */
    @Override
    public double score(ComputationGraph graph, DataSetIterator iterator) {
        RegressionEvaluation e = graph.evaluateRegression(iterator);
        return ScoreUtil.getScoreFromRegressionEval(e, regressionValue);
    }

    /**
     * Evaluates a {@link ComputationGraph} on a multi-data-set iterator and returns the
     * configured metric.
     *
     * @param graph the graph to evaluate
     * @param iterator the test data
     * @return the selected regression metric
     */
    @Override
    public double score(ComputationGraph graph, MultiDataSetIterator iterator) {
        // Kept separate from the DataSetIterator overload: evaluateRegression is overloaded per iterator type.
        RegressionEvaluation e = graph.evaluateRegression(iterator);
        return ScoreUtil.getScoreFromRegressionEval(e, regressionValue);
    }
}

