/*
 * Decompiled with CFR 0.152.
 */
package org.grouplens.lenskit.eval.metrics.predict;

import com.google.common.collect.ImmutableList;
import java.util.List;
import javax.annotation.Nonnull;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.AbstractTestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.eval.traintest.TestUser;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates rating predictions by their mean absolute error (MAE).
 *
 * <p>Two aggregate columns are reported:
 * <ul>
 *   <li>{@code MAE}: the mean absolute error over every scored prediction of every
 *       test user (total error divided by the total number of scored predictions);</li>
 *   <li>{@code MAE.ByUser}: the mean of the per-user MAE values, counting only users
 *       that had at least one scored prediction.</li>
 * </ul>
 * The per-user output has a single {@code MAE} column.
 */
public class MAEPredictMetric
extends AbstractTestUserMetric {
    private static final Logger logger = LoggerFactory.getLogger(MAEPredictMetric.class);
    /** Labels of the aggregate result columns, in the order produced by {@code finalResults()}. */
    private static final ImmutableList<String> COLUMNS = ImmutableList.of("MAE", "MAE.ByUser");
    /** Labels of the per-user result columns, in the order produced by {@code evaluate()}. */
    private static final ImmutableList<String> USER_COLUMNS = ImmutableList.of("MAE");

    /**
     * Creates a fresh, empty accumulator.
     *
     * @param algo the algorithm being evaluated (not used by this metric)
     * @param ds   the train/test data set (not used by this metric)
     * @return a new accumulator with zeroed running totals
     */
    @Override
    public TestUserMetricAccumulator makeAccumulator(AlgorithmInstance algo, TTDataSet ds) {
        return new Accum();
    }

    /**
     * Returns the aggregate column labels.
     *
     * @return {@code ["MAE", "MAE.ByUser"]}
     */
    @Override
    public List<String> getColumnLabels() {
        return COLUMNS;
    }

    /**
     * Returns the per-user column labels.
     *
     * @return {@code ["MAE"]}
     */
    @Override
    public List<String> getUserColumnLabels() {
        return USER_COLUMNS;
    }

    /**
     * Accumulates absolute prediction errors across test users.
     *
     * <p>Instances hold mutable running totals without any synchronization.
     */
    class Accum
    implements TestUserMetricAccumulator {
        /** Sum of absolute errors over all scored predictions. */
        private double totalError = 0.0;
        /** Sum of per-user MAE values. */
        private double totalUserError = 0.0;
        /** Number of scored predictions contributing to {@link #totalError}. */
        private int nratings = 0;
        /** Number of users contributing to {@link #totalUserError}. */
        private int nusers = 0;

        Accum() {
        }

        /**
         * Scores one test user's predictions against their test ratings.
         *
         * <p>Predictions whose value is NaN are skipped.
         *
         * @param user the test user to evaluate
         * @return a one-element row holding the user's MAE as a {@code Double}, or a
         *         one-element row holding {@code null} if the user had no scored predictions
         */
        @Override
        @Nonnull
        public Object[] evaluate(TestUser user) {
            SparseVector ratings = user.getTestRatings();
            SparseVector predictions = user.getPredictions();
            double err = 0.0;
            int n = 0;
            for (VectorEntry e : predictions.fast()) {
                double predicted = e.getValue();
                if (Double.isNaN(predicted)) {
                    continue;
                }
                double actual = ratings.get(e.getKey());
                // NOTE(review): if SparseVector.get reports a missing key as NaN (rather than
                // throwing), skipping it here stops one unmatched prediction from turning every
                // running total into NaN. Confirm against the SparseVector version in use.
                if (Double.isNaN(actual)) {
                    continue;
                }
                err += Math.abs(predicted - actual);
                ++n;
            }
            if (n == 0) {
                // Nothing to score: report a missing (null) per-user value and leave totals untouched.
                return new Object[1];
            }
            this.totalError += err;
            this.nratings += n;
            double errRate = err / (double) n;
            this.totalUserError += errRate;
            ++this.nusers;
            return new Object[]{errRate};
        }

        /**
         * Computes the aggregate results and logs the global MAE.
         *
         * @return a two-element row {@code [MAE, MAE.ByUser]}. Each value is NaN when no
         *         predictions (respectively users) were scored.
         */
        @Override
        @Nonnull
        public Object[] finalResults() {
            // Explicit guards make the "nothing scored" case deliberate rather than relying on 0.0/0.
            double v = this.nratings > 0 ? this.totalError / (double) this.nratings : Double.NaN;
            double uv = this.nusers > 0 ? this.totalUserError / (double) this.nusers : Double.NaN;
            logger.info("MAE: {}", v);
            return new Object[]{v, uv};
        }
    }
}

