/*
 * Decompiled with CFR 0.152.
 */
package org.grouplens.lenskit.eval.metrics.predict;

import com.google.common.collect.ImmutableList;
import java.util.List;
import javax.annotation.Nonnull;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.AbstractTestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.eval.traintest.TestUser;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test-user metric that reports the root mean squared error (RMSE) of
 * predictions against test ratings.
 *
 * <p>Two aggregate columns are produced:
 * <ul>
 *   <li>{@code RMSE.ByRating} &mdash; square root of the summed squared error
 *       over every scored prediction, divided by the number of scored
 *       predictions;</li>
 *   <li>{@code RMSE.ByUser} &mdash; arithmetic mean of the per-user RMSE values,
 *       taken over the users that had at least one scored prediction.</li>
 * </ul>
 * The single per-user column, {@code RMSE}, holds that user's own RMSE.
 */
public class RMSEPredictMetric extends AbstractTestUserMetric {
    private static final Logger logger = LoggerFactory.getLogger(RMSEPredictMetric.class);

    // The element type must be String; casting the arguments to Object makes
    // the inferred list type incompatible with ImmutableList<String>.
    private static final ImmutableList<String> COLUMNS =
            ImmutableList.of("RMSE.ByRating", "RMSE.ByUser");
    private static final ImmutableList<String> USER_COLUMNS = ImmutableList.of("RMSE");

    /**
     * Creates a fresh accumulator; neither argument influences the result.
     *
     * @param algo the algorithm being evaluated (unused)
     * @param ds   the data set being evaluated (unused)
     * @return a new accumulator with all counters at zero
     */
    @Override
    public TestUserMetricAccumulator makeAccumulator(AlgorithmInstance algo, TTDataSet ds) {
        return new Accum();
    }

    /**
     * @return the aggregate column labels, {@code RMSE.ByRating} then {@code RMSE.ByUser}
     */
    @Override
    public List<String> getColumnLabels() {
        return COLUMNS;
    }

    /**
     * @return the per-user column labels (just {@code RMSE})
     */
    @Override
    public List<String> getUserColumnLabels() {
        return USER_COLUMNS;
    }

    /**
     * Running totals for one evaluation. Declared static because it never
     * touches the enclosing metric instance.
     */
    static class Accum implements TestUserMetricAccumulator {
        /** Sum of squared errors over every scored prediction seen so far. */
        private double sse = 0.0;
        /** Sum of the per-user RMSE values of the users counted in {@link #nusers}. */
        private double totalRMSE = 0.0;
        /** Number of scored (non-NaN) predictions seen so far. */
        private int nratings = 0;
        /** Number of users that had at least one scored prediction. */
        private int nusers = 0;

        Accum() {
        }

        /**
         * Measures one user's prediction error and folds it into the totals.
         *
         * <p>Predictions whose value is NaN are skipped entirely: they add
         * nothing to the squared error and are not counted.
         *
         * @param user the test user whose predictions and test ratings are compared
         * @return a one-element array holding the user's RMSE, or a one-element
         *         array holding {@code null} when the user had no scored prediction
         */
        @Override
        @Nonnull
        public Object[] evaluate(TestUser user) {
            SparseVector ratings = user.getTestRatings();
            SparseVector predictions = user.getPredictions();

            double userSse = 0.0;
            int scored = 0;
            for (VectorEntry e : predictions.fast()) {
                double predicted = e.getValue();
                if (Double.isNaN(predicted)) {
                    continue;
                }
                double err = predicted - ratings.get(e.getKey());
                userSse += err * err;
                ++scored;
            }

            this.sse += userSse;
            this.nratings += scored;

            if (scored == 0) {
                // Nothing to measure: the user is left out of the by-user
                // average and its per-user cell stays null.
                return new Object[1];
            }

            double rmse = Math.sqrt(userSse / (double) scored);
            this.totalRMSE += rmse;
            ++this.nusers;
            return new Object[]{rmse};
        }

        /**
         * Computes the aggregate results and logs the by-rating RMSE at INFO level.
         *
         * <p>If no prediction was ever scored, both divisions are 0.0 / 0 and
         * both returned values are NaN.
         *
         * @return a two-element array: by-rating RMSE, then mean per-user RMSE
         */
        @Override
        @Nonnull
        public Object[] finalResults() {
            double v = Math.sqrt(this.sse / (double) this.nratings);
            logger.info("RMSE: {}", v);
            return new Object[]{v, this.totalRMSE / (double) this.nusers};
        }
    }
}

