/*
 * Decompiled with CFR 0.152.
 */
package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import javax.annotation.Nonnull;
import javax.inject.Provider;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.eval.ExecutionInfo;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.RecommenderInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.eval.traintest.ModelMetric;
import org.grouplens.lenskit.eval.traintest.TestUser;
import org.grouplens.lenskit.eval.traintest.TrainTestJobException;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runnable job that builds one algorithm on one train-test data set, measures it, and writes
 * the results to the configured output tables.
 *
 * <p>A run builds a {@link RecommenderInstance} from the algorithm, applies every
 * {@link ModelMetric} to it, feeds each test user through every {@link TestUserMetric}
 * accumulator, and finally writes one aggregate row. Per-user rows and per-prediction rows are
 * written only when the corresponding supplier yields a non-null table.
 */
class TrainTestEvalJob implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalJob.class);

    /**
     * Passed as the last argument of {@code RecommenderInstance.getRecommendations}
     * (presumably the requested list length; confirm against that API).
     */
    private final int numRecs;
    @Nonnull
    private final AlgorithmInstance algorithm;
    @Nonnull
    private final List<TestUserMetric> evaluators;
    @Nonnull
    private final List<ModelMetric> modelMetrics;
    /** Channels whose values are appended, in order, after the prediction column. */
    @Nonnull
    private final List<Pair<Symbol, String>> channels;
    @Nonnull
    private final TTDataSet data;
    /** Supplies the table receiving the single aggregate row of this job. */
    @Nonnull
    private final Supplier<TableWriter> outputSupplier;
    /** Supplies the per-user table; the supplied table may be {@code null} to disable it. */
    @Nonnull
    private final Supplier<TableWriter> userOutputSupplier;
    /** Supplies the per-prediction table; the supplied table may be {@code null} to disable it. */
    @Nonnull
    private final Supplier<TableWriter> predictOutputSupplier;
    private final Provider<PreferenceSnapshot> snapshot;

    /**
     * Create a new evaluation job.
     *
     * @param algo     The algorithm to build and test.
     * @param evals    The per-user metrics to accumulate.
     * @param mMetrics The metrics applied once to the built recommender.
     * @param chans    The prediction channels to include in the prediction table.
     * @param ds       The train-test data set.
     * @param snap     Provider handed to the algorithm when building the recommender.
     * @param out      Supplier of the aggregate result table.
     * @param userOut  Supplier of the per-user table (may supply {@code null}).
     * @param predOut  Supplier of the prediction table (may supply {@code null}).
     * @param nrecs    Value passed through to recommendation requests.
     */
    public TrainTestEvalJob(@Nonnull AlgorithmInstance algo,
                            @Nonnull List<TestUserMetric> evals,
                            @Nonnull List<ModelMetric> mMetrics,
                            @Nonnull List<Pair<Symbol, String>> chans,
                            @Nonnull TTDataSet ds,
                            Provider<PreferenceSnapshot> snap,
                            @Nonnull Supplier<TableWriter> out,
                            @Nonnull Supplier<TableWriter> userOut,
                            @Nonnull Supplier<TableWriter> predOut,
                            int nrecs) {
        this.algorithm = algo;
        this.evaluators = evals;
        this.modelMetrics = mMetrics;
        this.channels = chans;
        this.data = ds;
        this.snapshot = snap;
        this.outputSupplier = out;
        this.userOutputSupplier = userOut;
        this.predictOutputSupplier = predOut;
        this.numRecs = nrecs;
    }

    /**
     * Run the evaluation.
     *
     * @throws TrainTestJobException wrapping any exception raised by the evaluation.
     */
    @Override
    public void run() {
        try {
            runEvaluation();
        } catch (Exception e) {
            throw new TrainTestJobException(e);
        }
    }

    /**
     * Build, measure and test the algorithm, writing every configured output table.
     *
     * @throws IOException               if an output table or the test-user cursor fails.
     * @throws RecommenderBuildException if the recommender cannot be built.
     */
    private void runEvaluation() throws IOException, RecommenderBuildException {
        Closer closer = Closer.create();
        try {
            // Optional tables: a null table simply disables that output.
            TableWriter userTable = userOutputSupplier.get();
            if (userTable != null) {
                closer.register(userTable);
            }
            TableWriter predictTable = predictOutputSupplier.get();
            if (predictTable != null) {
                closer.register(predictTable);
            }

            List<Object> outputRow = Lists.newArrayList();
            ExecutionInfo execInfo = buildExecInfo();

            logger.info("Building {}", algorithm.getName());
            StopWatch buildTimer = new StopWatch();
            buildTimer.start();
            RecommenderInstance rec = algorithm.makeTestableRecommender(data, snapshot, execInfo);
            buildTimer.stop();
            logger.info("Built {} in {}", algorithm.getName(), buildTimer);

            logger.info("Measuring {}", algorithm.getName());
            for (ModelMetric metric : modelMetrics) {
                outputRow.addAll(metric.measureAlgorithm(algorithm, data, rec.getRecommender()));
            }

            logger.info("Testing {}", algorithm.getName());
            StopWatch testTimer = new StopWatch();
            testTimer.start();

            List<TestUserMetricAccumulator> evalAccums =
                    new ArrayList<TestUserMetricAccumulator>(evaluators.size());
            // Scratch row reused for every user; it is cleared after each write.
            List<Object> userRow = new ArrayList<Object>();
            UserEventDAO testUsers = data.getTestData().getUserEventDAO();
            for (TestUserMetric eval : evaluators) {
                evalAccums.add(eval.makeAccumulator(algorithm, data));
            }

            Cursor<UserHistory<Event>> userProfiles =
                    closer.register(testUsers.streamEventsByUser());
            for (UserHistory<Event> p : userProfiles) {
                assert userRow.isEmpty();
                long uid = p.getUserId();
                userRow.add(uid);

                LongSet testItems = p.itemSet();
                Supplier<SparseVector> preds = new PredictionSupplier(rec, uid, testItems);
                Supplier<List<ScoredId>> recs = new RecommendationSupplier(rec, uid, testItems);
                Supplier<UserHistory<Event>> hist =
                        new HistorySupplier(rec.getUserEventDAO(), uid);
                Supplier<UserHistory<Event>> testHist = Suppliers.ofInstance(p);
                TestUser test = new TestUser(uid, hist, testHist, preds, recs);

                for (TestUserMetricAccumulator accum : evalAccums) {
                    Object[] ures = accum.evaluate(test);
                    if (ures != null) {
                        userRow.addAll(Arrays.asList(ures));
                    }
                }

                if (userTable != null) {
                    try {
                        userTable.writeRow(userRow);
                    } catch (IOException e) {
                        throw new RuntimeException("error writing user row", e);
                    }
                }
                userRow.clear();

                if (predictTable != null) {
                    writePredictions(predictTable, uid,
                                     RatingVectorUserHistorySummarizer.makeRatingVector(p),
                                     test.getPredictions());
                }
            }
            testTimer.stop();
            logger.info("Tested {} in {}", algorithm.getName(), testTimer);

            writeOutput(buildTimer, testTimer, outputRow, evalAccums);
        } catch (Throwable th) {
            throw closer.rethrow(th, RecommenderBuildException.class);
        } finally {
            closer.close();
        }
    }

    /**
     * Collect the algorithm and data set names and attributes into an {@link ExecutionInfo}.
     *
     * @return The execution info handed to the algorithm when building the recommender.
     */
    private ExecutionInfo buildExecInfo() {
        ExecutionInfo.Builder bld = new ExecutionInfo.Builder();
        bld.setAlgoName(algorithm.getName())
           .setAlgoAttributes(algorithm.getAttributes())
           .setDataName(data.getName())
           .setDataAttributes(data.getAttributes());
        return bld.build();
    }

    /**
     * Write one row per rated test item of a user to the prediction table.
     *
     * <p>Column layout: 0 = user, 1 = item, 2 = rating, 3 = prediction, 4 and up = one column
     * per configured channel. A missing prediction or channel value is written as {@code null}.
     *
     * @param predictTable The table to write to.
     * @param uid          The user ID.
     * @param ratings      The user's test ratings; one row is written per entry.
     * @param predictions  The predictions (and channels) for the user.
     * @throws IOException if the table cannot be written.
     */
    private void writePredictions(TableWriter predictTable, long uid, SparseVector ratings,
                                  SparseVector predictions) throws IOException {
        final int ncols = predictTable.getLayout().getColumnCount();
        // Declared as String[] so the static type matches the array actually created.
        final String[] row = new String[ncols];
        row[0] = Long.toString(uid);
        for (VectorEntry e : ratings.fast()) {
            long iid = e.getKey();
            row[1] = Long.toString(iid);
            row[2] = Double.toString(e.getValue());
            row[3] = predictions.containsKey(iid) ? Double.toString(predictions.get(iid)) : null;

            int col = 4;
            for (Pair<Symbol, String> chan : channels) {
                Symbol sym = chan.getLeft();
                if (predictions.hasChannelVector(sym)
                        && predictions.getChannelVector(sym).containsKey(iid)) {
                    row[col] = Double.toString(predictions.getChannelVector(sym).get(iid));
                } else {
                    // The row array is reused, so stale values must be overwritten.
                    row[col] = null;
                }
                col++;
            }
            predictTable.writeRow((Object[]) row);
        }
    }

    /**
     * Write the aggregate result row and close the output table.
     *
     * <p>Column layout: 0 = build time, 1 = test time, then the model-metric values, then the
     * final results of each accumulator (accumulators returning {@code null} add no columns).
     *
     * @param build    The stopped build timer.
     * @param test     The stopped test timer.
     * @param measures The model-metric values.
     * @param accums   The per-user metric accumulators.
     * @throws IOException if the table cannot be written or closed.
     */
    private void writeOutput(StopWatch build, StopWatch test, List<Object> measures,
                             List<TestUserMetricAccumulator> accums) throws IOException {
        // Closer keeps a write failure from being replaced by a failure in close().
        Closer closer = Closer.create();
        try {
            TableWriter output = closer.register(outputSupplier.get());
            Object[] row = new Object[output.getLayout().getColumnCount()];
            row[0] = build.getTime();
            row[1] = test.getTime();
            int col = 2;
            for (Object measure : measures) {
                row[col] = measure;
                col++;
            }
            for (TestUserMetricAccumulator acc : accums) {
                Object[] results = acc.finalResults();
                if (results == null) {
                    continue;
                }
                System.arraycopy(results, 0, row, col, results.length);
                col += results.length;
            }
            output.writeRow(row);
        } catch (Throwable th) {
            throw closer.rethrow(th);
        } finally {
            closer.close();
        }
    }

    /** Supplies a user's event history by querying a DAO each time it is asked. */
    private static class HistorySupplier implements Supplier<UserHistory<Event>> {
        private final UserEventDAO userEventDAO;
        private final long user;

        public HistorySupplier(UserEventDAO dao, long id) {
            this.userEventDAO = dao;
            this.user = id;
        }

        @Override
        public UserHistory<Event> get() {
            return userEventDAO.getEventsForUser(user);
        }
    }

    /**
     * Supplies recommendations for a user. Stays an inner (non-static) class because it reads
     * {@code numRecs} from the enclosing job.
     */
    private class RecommendationSupplier implements Supplier<List<ScoredId>> {
        private final RecommenderInstance recommender;
        private final long user;
        private final LongSet items;

        public RecommendationSupplier(RecommenderInstance rec, long id, LongSet is) {
            this.recommender = rec;
            this.user = id;
            this.items = is;
        }

        /**
         * @throws IllegalArgumentException if there is no recommender or it returns {@code null}.
         */
        @Override
        public List<ScoredId> get() {
            if (recommender == null) {
                throw new IllegalArgumentException("cannot compute recommendations without a recommender");
            }
            List<ScoredId> recs = recommender.getRecommendations(user, items, numRecs);
            if (recs == null) {
                throw new IllegalArgumentException("no recommendations");
            }
            return recs;
        }
    }

    /** Supplies predictions for a user over a fixed item set. */
    private static class PredictionSupplier implements Supplier<SparseVector> {
        private final RecommenderInstance predictor;
        private final long user;
        private final LongSet items;

        public PredictionSupplier(RecommenderInstance pred, long id, LongSet is) {
            this.predictor = pred;
            this.user = id;
            this.items = is;
        }

        /**
         * @throws IllegalArgumentException if there is no predictor or it returns {@code null}.
         */
        @Override
        public SparseVector get() {
            if (predictor == null) {
                throw new IllegalArgumentException("cannot compute predictions without a predictor");
            }
            SparseVector preds = predictor.getPredictions(user, items);
            if (preds == null) {
                throw new IllegalArgumentException("no predictions");
            }
            return preds;
        }
    }
}

