/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.timeseries.eval;

import java.util.Arrays;
import java.util.List;
import weka.classifiers.evaluation.NumericPrediction;
import weka.classifiers.timeseries.eval.ErrorModule;
import weka.core.Instance;
import weka.core.Utils;

public class RRSEModule
extends ErrorModule {
    protected double[] m_previousActual;
    protected double[] m_sumOfSE;
    protected RRSEModule m_relativeRRSE;
    protected static final double SMALL = 1.0E-6;

    @Override
    public void reset() {
        super.reset();
        this.m_previousActual = new double[this.m_targetFieldNames.size()];
        this.m_sumOfSE = new double[this.m_targetFieldNames.size()];
        for (int i = 0; i < this.m_targetFieldNames.size(); ++i) {
            this.m_previousActual[i] = Utils.missingValue();
            this.m_sumOfSE[i] = 0.0;
        }
    }

    public void setRelativeRRSEModule(RRSEModule relative) {
        this.m_relativeRRSE = relative;
    }

    public double[] getPreviousActual() {
        return this.m_previousActual;
    }

    @Override
    public String getEvalName() {
        return "RRSE";
    }

    @Override
    public String getDescription() {
        return "Root relative squared error";
    }

    @Override
    public String getDefinition() {
        return "sqrt(sum((predicted - actual)^2) / N) / sqrt(sum(previous_target - actual)^2) / N)";
    }

    protected void evaluatePredictionForTargetForInstance(int targetIndex, NumericPrediction forecast, double actualValue) {
        double predictedValue = forecast.predicted();
        double[][] intervals = forecast.predictionIntervals();
        NumericPrediction pred = new NumericPrediction(actualValue, predictedValue, 1.0, intervals);
        ((List)this.m_predictions.get(targetIndex)).add(pred);
        int n = targetIndex;
        this.m_counts[n] = this.m_counts[n] + 1.0;
    }

    @Override
    public void evaluateForInstance(List<NumericPrediction> forecasts, Instance inst) throws Exception {
        for (int i = 0; i < this.m_targetFieldNames.size(); ++i) {
            double actualValue = this.getTargetValue((String)this.m_targetFieldNames.get(i), inst);
            if (this.m_relativeRRSE != null) {
                this.m_previousActual = this.m_relativeRRSE.getPreviousActual();
            }
            if (this.m_relativeRRSE == null && Utils.isMissingValue((double)this.m_previousActual[i])) {
                this.m_previousActual[i] = actualValue;
                continue;
            }
            if (!Utils.isMissingValue((double)actualValue) && !Utils.isMissingValue((double)this.m_previousActual[i])) {
                this.evaluatePredictionForTargetForInstance(i, forecasts.get(i), actualValue);
                int n = i;
                this.m_sumOfSE[n] = this.m_sumOfSE[n] + (this.m_previousActual[i] - actualValue) * (this.m_previousActual[i] - actualValue);
            }
            if (this.m_relativeRRSE != null) continue;
            this.m_previousActual[i] = actualValue;
        }
    }

    @Override
    public double[] calculateMeasure() throws Exception {
        int i;
        double[] result = new double[this.m_targetFieldNames.size()];
        for (i = 0; i < result.length; ++i) {
            result[i] = Utils.missingValue();
        }
        for (i = 0; i < this.m_targetFieldNames.size(); ++i) {
            double sumSE = 0.0;
            double count = 0.0;
            List preds = (List)this.m_predictions.get(i);
            for (NumericPrediction p : preds) {
                if (Utils.isMissingValue((double)p.error())) continue;
                sumSE += p.error() * p.error();
                count += 1.0;
            }
            if (this.m_sumOfSE[i] == 0.0) {
                this.m_sumOfSE[i] = 1.0E-6;
            }
            if (count == 0.0) {
                result[i] = Utils.missingValue();
                continue;
            }
            double rootMSEPrev = Math.sqrt(sumSE / count);
            double rootMSE = Math.sqrt(this.m_sumOfSE[i] / count);
            result[i] = rootMSEPrev / rootMSE * 100.0;
        }
        return result;
    }
}

