/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.graph;

import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class GraphTestSetRegressionScoreFunctionDataSet
implements ScoreFunction<ComputationGraph, DataSetIterator> {
    private final RegressionValue regressionValue;

    public GraphTestSetRegressionScoreFunctionDataSet(RegressionValue regressionValue) {
        this.regressionValue = regressionValue;
    }

    public double score(ComputationGraph model, DataProvider<DataSetIterator> dataProvider, Map<String, Object> dataParameters) {
        DataSetIterator testSet = (DataSetIterator)dataProvider.testData(dataParameters);
        RegressionEvaluation evaluation = new RegressionEvaluation(new String[0]);
        while (testSet.hasNext()) {
            DataSet next = (DataSet)testSet.next();
            INDArray labels = next.getLabels();
            if (next.hasMaskArrays()) {
                INDArray[] iNDArrayArray;
                INDArray[] iNDArrayArray2;
                INDArray fMask = next.getFeaturesMaskArray();
                INDArray lMask = next.getLabelsMaskArray();
                if (fMask == null) {
                    iNDArrayArray2 = null;
                } else {
                    INDArray[] iNDArrayArray3 = new INDArray[1];
                    iNDArrayArray2 = iNDArrayArray3;
                    iNDArrayArray3[0] = fMask;
                }
                INDArray[] fMasks = iNDArrayArray2;
                if (lMask == null) {
                    iNDArrayArray = null;
                } else {
                    INDArray[] iNDArrayArray4 = new INDArray[1];
                    iNDArrayArray = iNDArrayArray4;
                    iNDArrayArray4[0] = lMask;
                }
                INDArray[] lMasks = iNDArrayArray;
                model.setLayerMaskArrays(fMasks, lMasks);
                INDArray[] outputs = model.output(false, new INDArray[]{next.getFeatures()});
                if (lMasks != null && lMasks[0] != null) {
                    evaluation.evalTimeSeries(labels, outputs[0], lMasks[0]);
                } else {
                    evaluation.evalTimeSeries(labels, outputs[0]);
                }
                model.clearLayerMaskArrays();
                continue;
            }
            INDArray[] outputs = model.output(false, new INDArray[]{next.getFeatures()});
            if (labels.rank() == 3) {
                evaluation.evalTimeSeries(labels, outputs[0]);
                continue;
            }
            evaluation.eval(labels, outputs[0]);
        }
        double sum = 0.0;
        int nColumns = evaluation.numColumns();
        switch (this.regressionValue) {
            case MSE: {
                for (int j = 0; j < nColumns; ++j) {
                    sum += evaluation.meanSquaredError(j);
                }
                break;
            }
            case MAE: {
                for (int j = 0; j < nColumns; ++j) {
                    sum += evaluation.meanAbsoluteError(j);
                }
                break;
            }
            case RMSE: {
                for (int j = 0; j < nColumns; ++j) {
                    sum += evaluation.rootMeanSquaredError(j);
                }
                break;
            }
            case RSE: {
                for (int j = 0; j < nColumns; ++j) {
                    sum += evaluation.relativeSquaredError(j);
                }
                break;
            }
            case CorrCoeff: {
                for (int j = 0; j < nColumns; ++j) {
                    sum += evaluation.correlationR2(j);
                }
                sum /= (double)nColumns;
            }
        }
        return sum;
    }

    public boolean minimize() {
        return this.regressionValue != RegressionValue.CorrCoeff;
    }

    public String toString() {
        return "GraphTestSetRegressionScoreFunctionDataSet(type=" + (Object)((Object)this.regressionValue) + ")";
    }
}

