/*
 * Decompiled with CFR 0.152.
 */
package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.inject.Provider;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.ExternalAlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.LenskitAlgorithmInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.Metric;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.traintest.FunctionModelMetric;
import org.grouplens.lenskit.eval.traintest.FunctionMultiModelMetric;
import org.grouplens.lenskit.eval.traintest.ModelMetric;
import org.grouplens.lenskit.eval.traintest.SharedPreferenceSnapshot;
import org.grouplens.lenskit.eval.traintest.TrainTestEvalJob;
import org.grouplens.lenskit.eval.traintest.TrainTestJobException;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.parallel.TaskGroupRunner;
import org.grouplens.lenskit.util.table.Table;
import org.grouplens.lenskit.util.table.TableBuilder;
import org.grouplens.lenskit.util.table.TableLayout;
import org.grouplens.lenskit.util.table.TableLayoutBuilder;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.MultiplexedTableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Train-test evaluation task.
 *
 * <p>For every configured data set this task builds one {@link TrainTestEvalJob} per configured
 * algorithm, runs the jobs on a fixed-size thread pool, and routes their results into up to three
 * tables: the aggregate table (always kept in memory, optionally written to {@link #setOutput}),
 * a per-user table ({@link #setUserOutput}), and a per-prediction table
 * ({@link #setPredictOutput}). Every row in those tables is prefixed with the algorithm name plus
 * the data set and algorithm attribute columns.
 */
public class TrainTestEvalTask
extends AbstractTask<Table> {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalTask.class);

    // Evaluation configuration, populated through the fluent add*/set* methods.
    private final List<TTDataSet> dataSets = new ArrayList<TTDataSet>();
    private final List<AlgorithmInstance> algorithms = new ArrayList<AlgorithmInstance>();
    private final List<TestUserMetric> metrics = new ArrayList<TestUserMetric>();
    private final List<ModelMetric> modelMetrics = new ArrayList<ModelMetric>();
    /** Extra prediction channels written to the prediction table, as (symbol, column label). */
    private final List<Pair<Symbol, String>> predictChannels = new ArrayList<Pair<Symbol, String>>();
    private boolean isolate = false;
    private File outputFile = new File("train-test-results.csv");
    private File userOutputFile;
    private File predictOutputFile;
    private int numRecs = 5;

    // Table layout state, computed by setupTableLayouts() at the start of perform().
    private int commonColumnCount;
    private TableLayout masterLayout;
    private TableLayout outputLayout;
    private TableLayout userLayout;
    private TableLayout predictLayout;
    /** Index of each data set attribute among the common (prefix) columns. */
    private Map<String, Integer> dataColumns;
    /** Index of each algorithm attribute among the common (prefix) columns. */
    private Map<String, Integer> algoColumns;
    private List<TestUserMetric> predictMetrics;

    // Output writers. prepareEval() opens them and cleanUp() clears them.
    private TableWriter output;
    private TableBuilder outputInMemory;
    private TableWriter userOutput;
    private TableWriter predictOutput;

    /**
     * Create a train-test task named {@code "train-test"}.
     */
    public TrainTestEvalTask() {
        this("train-test");
    }

    /**
     * Create a train-test task with the given name. Aggregate results are written to
     * {@code train-test-results.csv} unless {@link #setOutput(File)} says otherwise.
     *
     * @param name the task name.
     */
    public TrainTestEvalTask(String name) {
        super(name);
    }

    /**
     * Add a train-test data set to evaluate on.
     *
     * @param source the data set.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addDataset(TTDataSet source) {
        dataSets.add(source);
        return this;
    }

    /**
     * Add a LensKit algorithm to evaluate.
     *
     * @param algorithm the algorithm instance.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addAlgorithm(LenskitAlgorithmInstance algorithm) {
        algorithms.add(algorithm);
        return this;
    }

    /**
     * Add an external algorithm to evaluate.
     *
     * @param algorithm the external algorithm instance.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addExternalAlgorithm(ExternalAlgorithmInstance algorithm) {
        algorithms.add(algorithm);
        return this;
    }

    /**
     * Add a test-user metric.
     *
     * @param metric the metric.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addMetric(TestUserMetric metric) {
        metrics.add(metric);
        return this;
    }

    /**
     * Instantiate a test-user metric through its no-argument constructor and add it.
     *
     * @param metricClass the metric class.
     * @return this task, for chaining.
     * @throws IllegalAccessException if the constructor is not accessible.
     * @throws InstantiationException if the class cannot be instantiated.
     */
    public TrainTestEvalTask addMetric(Class<? extends TestUserMetric> metricClass) throws IllegalAccessException, InstantiationException {
        return addMetric(metricClass.newInstance());
    }

    /**
     * Add a model metric, computed by applying {@code metric} to a {@link Recommender}, whose
     * multi-row results are written to {@code file}.
     *
     * @param file    the file the metric's rows go to.
     * @param columns the column labels for those rows.
     * @param metric  the function computing the rows.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addMultiMetric(File file, List<String> columns, Function<Recommender, List<List<Object>>> metric) {
        modelMetrics.add(new FunctionMultiModelMetric(file, columns, metric));
        return this;
    }

    /**
     * Add a model metric, computed by applying {@code metric} to a {@link Recommender}, whose
     * values become extra columns of the aggregate output table.
     *
     * @param columns the column labels.
     * @param metric  the function computing the column values.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addMetric(List<String> columns, Function<Recommender, List<Object>> metric) {
        modelMetrics.add(new FunctionModelMetric(columns, metric));
        return this;
    }

    /**
     * Write an extra prediction channel to the prediction table, labelled with the symbol name.
     *
     * @param channelSym the channel symbol.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym) {
        return addWritePredictionChannel(channelSym, null);
    }

    /**
     * Write an extra prediction channel to the prediction table.
     *
     * @param channelSym the channel symbol.
     * @param label      the column label, or {@code null} to use the symbol's name.
     * @return this task, for chaining.
     * @throws NullPointerException if {@code channelSym} is null.
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym, @Nullable String label) {
        Preconditions.checkNotNull(channelSym, "channel is null");
        String column = label != null ? label : channelSym.getName();
        Pair<Symbol, String> entry = Pair.of(channelSym, column);
        predictChannels.add(entry);
        return this;
    }

    /**
     * Set the aggregate output file.
     *
     * @param file the CSV file, or {@code null} to keep the aggregate results in memory only.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setOutput(File file) {
        outputFile = file;
        return this;
    }

    /**
     * Set the aggregate output file by name.
     *
     * @param fn the file name.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setOutput(String fn) {
        return setOutput(new File(fn));
    }

    /**
     * Set the per-user output file.
     *
     * @param file the CSV file, or {@code null} to disable per-user output.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setUserOutput(File file) {
        userOutputFile = file;
        return this;
    }

    /**
     * Set the per-user output file by name.
     *
     * @param fn the file name.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setUserOutput(String fn) {
        return setUserOutput(new File(fn));
    }

    /**
     * Set the per-prediction output file.
     *
     * @param file the CSV file, or {@code null} to disable prediction output.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setPredictOutput(File file) {
        predictOutputFile = file;
        return this;
    }

    /**
     * Set the per-prediction output file by name.
     *
     * @param fn the file name.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setPredictOutput(String fn) {
        return setPredictOutput(new File(fn));
    }

    /**
     * Control job grouping. When {@code true}, each data set's jobs form a separate group, and
     * each group finishes before the next one starts. When {@code false} (the default), all jobs
     * form one group.
     *
     * @param iso whether to isolate data sets.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setIsolate(boolean iso) {
        isolate = iso;
        return this;
    }

    List<TTDataSet> dataSources() {
        return dataSets;
    }

    List<AlgorithmInstance> getAlgorithms() {
        return algorithms;
    }

    List<TestUserMetric> getMetrics() {
        return metrics;
    }

    List<Pair<Symbol, String>> getPredictionChannels() {
        return predictChannels;
    }

    File getOutput() {
        return outputFile;
    }

    File getPredictOutput() {
        return predictOutputFile;
    }

    /**
     * Get the number of recommendations passed to each job.
     *
     * @return the recommendation count.
     */
    public int getNumRecs() {
        return numRecs;
    }

    /**
     * Set the number of recommendations passed to each job.
     *
     * @param numRecs the recommendation count.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask setNumRecs(int numRecs) {
        this.numRecs = numRecs;
        return this;
    }

    /**
     * Run the evaluation. This builds the jobs, lays out the tables, opens the outputs, runs every
     * job group, and then finishes the metrics and closes the outputs.
     *
     * @return the aggregate results table accumulated in memory.
     * @throws TaskExecutionException if a job fails or an I/O error occurs.
     */
    @Override
    public Table perform() throws TaskExecutionException {
        List<List<TrainTestEvalJob>> jobGroups = makeJobGroups();
        setupTableLayouts();
        logger.info("Starting evaluation");
        Closer closer = Closer.create();
        try {
            try {
                prepareEval(closer);
                try {
                    runEvaluations(jobGroups);
                } finally {
                    // Finish the metrics and release the writers even if a job failed.
                    cleanUp();
                }
            } catch (Throwable th) {
                // Closer records th so that exceptions thrown while closing do not mask it.
                throw closer.rethrow(th, TaskExecutionException.class);
            } finally {
                closer.close();
            }
        } catch (IOException e) {
            throw new TaskExecutionException("I/O error", e);
        }
        return outputInMemory.build();
    }

    /**
     * Run the job groups one after another on a shared fixed-size thread pool. The pool is
     * always shut down afterwards, so its non-daemon threads do not outlive the evaluation.
     */
    private void runEvaluations(List<List<TrainTestEvalJob>> jobGroups) throws TaskExecutionException {
        int nthreads = getProject().getConfig().getThreadCount();
        logger.info("Running evaluator with {} threads", nthreads);
        ExecutorService exec = Executors.newFixedThreadPool(nthreads);
        try {
            for (List<TrainTestEvalJob> group : jobGroups) {
                TaskGroupRunner runner = TaskGroupRunner.create(exec);
                runner.submitAll(group);
                try {
                    runner.waitForAll();
                } catch (ExecutionException e) {
                    Throwable cause = e.getCause();
                    // Report the job's underlying failure, but keep the wrapper if it has no cause.
                    if (cause instanceof TrainTestJobException && cause.getCause() != null) {
                        cause = cause.getCause();
                    }
                    throw new TaskExecutionException(cause);
                } catch (TrainTestJobException e) {
                    throw new TaskExecutionException(e);
                }
            }
        } finally {
            exec.shutdown();
        }
    }

    /**
     * Build the job groups. With isolation there is one group per data set. Without it, a single
     * group holds every job.
     */
    List<List<TrainTestEvalJob>> makeJobGroups() {
        List<List<TrainTestEvalJob>> jobGroups = Lists.newArrayList();
        for (TTDataSet dataset : dataSets) {
            jobGroups.add(makeJobs(dataset));
        }
        if (!isolate) {
            List<TrainTestEvalJob> allJobs = Lists.newArrayList(Iterables.concat(jobGroups));
            return Collections.singletonList(allJobs);
        }
        return jobGroups;
    }

    /**
     * Build one job per configured algorithm for a data set. The jobs share one preference
     * snapshot provider, and their output suppliers are wrapped to prefix each row with the
     * algorithm and data set columns.
     */
    private List<TrainTestEvalJob> makeJobs(TTDataSet data) {
        List<TrainTestEvalJob> jobs = Lists.newArrayListWithCapacity(algorithms.size());
        Provider<PreferenceSnapshot> snap = SharedPreferenceSnapshot.provider(data);
        for (AlgorithmInstance algo : algorithms) {
            Function<TableWriter, TableWriter> prefix = prefixFunction(algo, data);
            TrainTestEvalJob job = new TrainTestEvalJob(
                    algo, metrics, modelMetrics, predictChannels, data, snap,
                    Suppliers.compose(prefix, outputTableSupplier()),
                    Suppliers.compose(prefix, userTableSupplier()),
                    Suppliers.compose(prefix, predictTableSupplier()),
                    numRecs);
            jobs.add(job);
        }
        return jobs;
    }

    /**
     * Compute the master layout (common prefix columns) and the aggregate, per-user, and
     * per-prediction layouts derived from it.
     */
    private void setupTableLayouts() {
        TableLayoutBuilder master = new TableLayoutBuilder();
        layoutCommonColumns(master);
        masterLayout = master.build();
        commonColumnCount = master.getColumnCount();
        outputLayout = layoutAggregateOutput(master);
        userLayout = layoutUserTable(master);
        predictLayout = layoutPredictionTable(master);
        predictMetrics = metrics;
    }

    /**
     * Get the layout of the common prefix columns shared by all output tables.
     *
     * @return the master layout, or {@code null} before {@link #perform()} has laid out the tables.
     */
    public TableLayout getMasterLayout() {
        return masterLayout;
    }

    /**
     * Add the common columns: "Algorithm", then each distinct data set attribute, then each
     * distinct algorithm attribute. Each attribute's column index is recorded for prefixTable().
     */
    private void layoutCommonColumns(TableLayoutBuilder master) {
        master.addColumn("Algorithm");
        dataColumns = new HashMap<String, Integer>();
        for (TTDataSet ds : dataSets) {
            for (String attr : ds.getAttributes().keySet()) {
                if (!dataColumns.containsKey(attr)) {
                    dataColumns.put(attr, master.getColumnCount());
                    master.addColumn(attr);
                }
            }
        }
        algoColumns = new HashMap<String, Integer>();
        for (AlgorithmInstance algo : algorithms) {
            for (String attr : algo.getAttributes().keySet()) {
                if (!algoColumns.containsKey(attr)) {
                    algoColumns.put(attr, master.getColumnCount());
                    master.addColumn(attr);
                }
            }
        }
    }

    /**
     * Build the aggregate layout: the common columns, the build/test times, the model metric
     * columns, and the test-user metric columns (metrics with null labels add none).
     */
    private TableLayout layoutAggregateOutput(TableLayoutBuilder master) {
        TableLayoutBuilder output = master.clone();
        output.addColumn("BuildTime");
        output.addColumn("TestTime");
        for (ModelMetric modelMetric : modelMetrics) {
            for (String c : modelMetric.getColumnLabels()) {
                output.addColumn(c);
            }
        }
        for (TestUserMetric testUserMetric : metrics) {
            List<String> columnLabels = testUserMetric.getColumnLabels();
            if (columnLabels != null) {
                for (String c : columnLabels) {
                    output.addColumn(c);
                }
            }
        }
        return output.build();
    }

    /**
     * Build the per-user layout: the common columns, "User", and each metric's per-user columns
     * (metrics with null labels add none).
     */
    private TableLayout layoutUserTable(TableLayoutBuilder master) {
        TableLayoutBuilder perUser = master.clone();
        perUser.addColumn("User");
        for (TestUserMetric ev : metrics) {
            List<String> userColumnLabels = ev.getUserColumnLabels();
            if (userColumnLabels != null) {
                for (String c : userColumnLabels) {
                    perUser.addColumn(c);
                }
            }
        }
        return perUser.build();
    }

    /**
     * Build the per-prediction layout: the common columns, user/item/rating/prediction, and one
     * column per extra prediction channel.
     */
    private TableLayout layoutPredictionTable(TableLayoutBuilder master) {
        TableLayoutBuilder eachPred = master.clone();
        eachPred.addColumn("User");
        eachPred.addColumn("Item");
        eachPred.addColumn("Rating");
        eachPred.addColumn("Prediction");
        for (Pair<Symbol, String> pair : predictChannels) {
            eachPred.addColumn(pair.getRight());
        }
        return eachPred.build();
    }

    /**
     * Open the output writers and start every metric. Every opened file is registered with
     * {@code closer}, which the caller closes.
     */
    private void prepareEval(Closer closer) throws IOException {
        logger.debug("Opening evaluation outputs");
        List<TableWriter> tableWriters = new ArrayList<TableWriter>();
        outputInMemory = new TableBuilder(outputLayout);
        tableWriters.add(outputInMemory);
        if (outputFile != null) {
            tableWriters.add(closer.register(CSVWriter.open(outputFile, outputLayout)));
        }
        // Aggregate rows go both to the in-memory table and, when configured, to the CSV file.
        output = new MultiplexedTableWriter(outputLayout, tableWriters);
        if (userOutputFile != null) {
            userOutput = closer.register(CSVWriter.open(userOutputFile, userLayout));
        }
        if (predictOutputFile != null) {
            predictOutput = closer.register(CSVWriter.open(predictOutputFile, predictLayout));
        }
        for (Metric metric : Iterables.concat(predictMetrics, modelMetrics)) {
            metric.startEvaluation(this);
        }
    }

    /**
     * Finish every metric and clear the output writers. The writers are cleared even if a
     * metric fails. The files themselves are closed by the caller's Closer.
     *
     * @throws IllegalStateException if no evaluation is running.
     */
    private void cleanUp() throws IOException {
        if (output == null) {
            throw new IllegalStateException("evaluation not running");
        }
        try {
            for (Metric metric : Iterables.concat(predictMetrics, modelMetrics)) {
                metric.finishEvaluation();
            }
            logger.info("Evaluation finished");
        } finally {
            output = null;
            userOutput = null;
            predictOutput = null;
        }
    }

    /**
     * Supplier for the aggregate output writer. It reads the current writer at call time.
     *
     * @return a supplier whose {@code get()} throws {@link IllegalStateException} when no
     *         evaluation is running.
     */
    @Nonnull
    Supplier<TableWriter> outputTableSupplier() {
        return new Supplier<TableWriter>() {
            @Override
            public TableWriter get() {
                Preconditions.checkState(output != null, "evaluation not running");
                return output;
            }
        };
    }

    /**
     * Supplier for the per-prediction writer. It reads the current writer at call time.
     *
     * @return a supplier that yields {@code null} when prediction output is disabled or no
     *         evaluation is running.
     */
    @Nonnull
    Supplier<TableWriter> predictTableSupplier() {
        return new Supplier<TableWriter>() {
            @Override
            public TableWriter get() {
                return predictOutput;
            }
        };
    }

    /**
     * Supplier for the per-user writer. It reads the current writer at call time.
     *
     * @return a supplier that yields {@code null} when per-user output is disabled or no
     *         evaluation is running.
     */
    @Nonnull
    Supplier<TableWriter> userTableSupplier() {
        return new Supplier<TableWriter>() {
            @Override
            public TableWriter get() {
                return userOutput;
            }
        };
    }

    /**
     * Create a function that wraps writers with {@link #prefixTable}, using the given algorithm
     * and data set.
     *
     * @param algorithm the algorithm whose name and attributes fill the prefix.
     * @param dataSet   the data set whose attributes fill the prefix.
     * @return the wrapping function. It maps {@code null} to {@code null}.
     */
    public Function<TableWriter, TableWriter> prefixFunction(final AlgorithmInstance algorithm, final TTDataSet dataSet) {
        return new Function<TableWriter, TableWriter>() {
            @Override
            public TableWriter apply(TableWriter base) {
                return prefixTable(base, algorithm, dataSet);
            }
        };
    }

    /**
     * Wrap a writer so that every row starts with the common columns: the algorithm name, then
     * the data set and algorithm attribute values at their laid-out positions. Common columns
     * not set by this algorithm or data set are left null.
     *
     * @param base      the writer to wrap, or {@code null}.
     * @param algorithm the algorithm.
     * @param dataSet   the data set.
     * @return the prefixed writer, or {@code null} if {@code base} is null.
     */
    public TableWriter prefixTable(TableWriter base, AlgorithmInstance algorithm, TTDataSet dataSet) {
        if (base == null) {
            return null;
        }
        Object[] prefix = new Object[commonColumnCount];
        prefix[0] = algorithm.getName();
        for (Map.Entry<String, Object> attr : dataSet.getAttributes().entrySet()) {
            int idx = dataColumns.get(attr.getKey());
            prefix[idx] = attr.getValue();
        }
        for (Map.Entry<String, Object> attr : algorithm.getAttributes().entrySet()) {
            int idx = algoColumns.get(attr.getKey());
            prefix[idx] = attr.getValue();
        }
        return TableWriters.prefixed(base, prefix);
    }
}

