/*
 * Decompiled with CFR 0.152.
 */
package org.grouplens.lenskit.eval.data.crossfold;

import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.Long2IntMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongCollection;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongListIterator;
import it.unimi.dsi.fastutil.longs.LongLists;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.cursors.Cursors;
import org.grouplens.lenskit.data.dao.UserDAO;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.data.CSVDataSourceBuilder;
import org.grouplens.lenskit.eval.data.DataSource;
import org.grouplens.lenskit.eval.data.crossfold.FractionPartition;
import org.grouplens.lenskit.eval.data.crossfold.Holdout;
import org.grouplens.lenskit.eval.data.crossfold.HoldoutNPartition;
import org.grouplens.lenskit.eval.data.crossfold.Order;
import org.grouplens.lenskit.eval.data.crossfold.PartitionAlgorithm;
import org.grouplens.lenskit.eval.data.crossfold.RandomOrder;
import org.grouplens.lenskit.eval.data.crossfold.RetainNPartition;
import org.grouplens.lenskit.eval.data.traintest.GenericTTDataBuilder;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.util.io.UpToDateChecker;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluation task that splits a {@link DataSource} into several train/test
 * partitions and writes each partition to a pair of CSV files.
 *
 * <p>Two splitting strategies are implemented:
 * <ul>
 * <li><b>by users</b> (the default): users are shuffled and distributed
 * round-robin over the partitions; for the users of a partition, the configured
 * {@link Holdout} decides which of their ratings go to the test file;</li>
 * <li><b>by ratings</b>: all ratings are shuffled and distributed round-robin
 * over the test files of the partitions.</li>
 * </ul>
 *
 * <p>Both strategies draw their randomness from the project's random number
 * generator ({@code getProject().getRandom()}).
 */
public class CrossfoldTask extends AbstractTask<List<TTDataSet>> {
    private static final Logger logger = LoggerFactory.getLogger(CrossfoldTask.class);

    /** Data source to split; must be set before the task is used. */
    private DataSource source;
    /** Number of train/test partitions to produce. */
    private int partitionCount = 5;
    /** Explicit train file pattern, or {@code null} to use the default. */
    private String trainFilePattern;
    /** Explicit test file pattern, or {@code null} to use the default. */
    private String testFilePattern;
    /** Ordering handed to the {@link Holdout} used for user-based splits. */
    private Order<Rating> order = new RandomOrder<Rating>();
    /** Partition algorithm handed to the {@link Holdout} used for user-based splits. */
    private PartitionAlgorithm<Rating> partition = new HoldoutNPartition<Rating>(10);
    /** Task-level flag that forces regeneration of the output files. */
    private boolean isForced;
    /** {@code true} to split by users, {@code false} to split by ratings. */
    private boolean splitUsers = true;
    /** Cache setting passed on to the data sources built for the output files. */
    private boolean cacheOutput = true;

    /**
     * Creates an unnamed crossfold task; {@link #getName()} then falls back to
     * the name of the data source.
     */
    public CrossfoldTask() {
        super(null);
    }

    /**
     * Creates a crossfold task with an explicit name.
     *
     * @param n the task name
     */
    public CrossfoldTask(String n) {
        super(n);
    }

    /**
     * Sets the number of partitions to generate.
     *
     * @param partition the partition count
     * @return this task, for chaining
     */
    public CrossfoldTask setPartitions(int partition) {
        this.partitionCount = partition;
        return this;
    }

    /**
     * Sets the pattern for train file names. The pattern is expanded with
     * {@link String#format(String, Object...)}, receiving the partition index
     * as its only argument.
     *
     * @param pat the train file pattern
     * @return this task, for chaining
     */
    public CrossfoldTask setTrain(String pat) {
        this.trainFilePattern = pat;
        return this;
    }

    /**
     * Sets the pattern for test file names. The pattern is expanded with
     * {@link String#format(String, Object...)}, receiving the partition index
     * as its only argument.
     *
     * @param pat the test file pattern
     * @return this task, for chaining
     */
    public CrossfoldTask setTest(String pat) {
        this.testFilePattern = pat;
        return this;
    }

    /**
     * Sets the ordering used by the holdout when splitting a user's ratings.
     *
     * @param o the rating order
     * @return this task, for chaining
     */
    public CrossfoldTask setOrder(Order<Rating> o) {
        this.order = o;
        return this;
    }

    /**
     * Uses a {@link HoldoutNPartition} built from {@code n} as the partition
     * algorithm.
     *
     * @param n the count passed to {@link HoldoutNPartition}
     * @return this task, for chaining
     */
    public CrossfoldTask setHoldout(int n) {
        this.partition = new HoldoutNPartition<Rating>(n);
        return this;
    }

    /**
     * Uses a {@link FractionPartition} built from {@code f} as the partition
     * algorithm.
     *
     * @param f the fraction passed to {@link FractionPartition}
     * @return this task, for chaining
     * @deprecated does exactly what {@link #setHoldoutFraction(double)} does;
     *             use that method instead.
     */
    @Deprecated
    public CrossfoldTask setHoldout(double f) {
        this.partition = new FractionPartition<Rating>(f);
        return this;
    }

    /**
     * Uses a {@link RetainNPartition} built from {@code n} as the partition
     * algorithm.
     *
     * @param n the count passed to {@link RetainNPartition}
     * @return this task, for chaining
     */
    public CrossfoldTask setRetain(int n) {
        this.partition = new RetainNPartition<Rating>(n);
        return this;
    }

    /**
     * Uses a {@link FractionPartition} built from {@code f} as the partition
     * algorithm.
     *
     * @param f the fraction passed to {@link FractionPartition}
     * @return this task, for chaining
     */
    public CrossfoldTask setHoldoutFraction(double f) {
        this.partition = new FractionPartition<Rating>(f);
        return this;
    }

    /**
     * Sets the data source to split.
     *
     * @param source the input data source
     * @return this task, for chaining
     */
    public CrossfoldTask setSource(DataSource source) {
        this.source = source;
        return this;
    }

    /**
     * Sets the task-level force flag (see {@link #getForce()}).
     *
     * @param force {@code true} to always regenerate the output files
     * @return this task, for chaining
     */
    public CrossfoldTask setForce(boolean force) {
        this.isForced = force;
        return this;
    }

    /**
     * Chooses between splitting by users and splitting by ratings.
     *
     * <p>NOTE(review): unlike the other setters this one returns {@code void};
     * it is left that way to keep the signature stable.
     *
     * @param splitUsers {@code true} to split by users, {@code false} to split
     *                   by ratings
     */
    public void setSplitUsers(boolean splitUsers) {
        this.splitUsers = splitUsers;
    }

    /**
     * Sets the cache flag passed to the CSV data sources built by
     * {@link #getTTFiles()}.
     *
     * @param on the cache setting
     * @return this task, for chaining
     */
    public CrossfoldTask setCache(boolean on) {
        this.cacheOutput = on;
        return this;
    }

    /**
     * Returns the task name, falling back to the data source's name when no
     * explicit name was given.
     *
     * @return the effective task name
     */
    @Override
    public String getName() {
        String name = super.getName();
        return name != null ? name : this.source.getName();
    }

    /**
     * Returns the train file pattern: the explicitly configured one, or
     * {@code <dataDir>/<name>-crossfold/train.%d.csv} by default.
     *
     * @return the train file pattern
     */
    public String getTrainPattern() {
        return this.trainFilePattern != null
                ? this.trainFilePattern
                : this.defaultPattern("train.%d.csv");
    }

    /**
     * Returns the test file pattern: the explicitly configured one, or
     * {@code <dataDir>/<name>-crossfold/test.%d.csv} by default.
     *
     * @return the test file pattern
     */
    public String getTestPattern() {
        return this.testFilePattern != null
                ? this.testFilePattern
                : this.defaultPattern("test.%d.csv");
    }

    /**
     * Builds the default output pattern for the given file name inside the
     * {@code <name>-crossfold} directory below the project's data directory
     * ({@code "."} when the project configures none).
     */
    private String defaultPattern(String fileName) {
        String dir = this.getProject().getConfig().getDataDir();
        if (dir == null) {
            dir = ".";
        }
        return dir + File.separator + this.getName() + "-crossfold" + File.separator + fileName;
    }

    /**
     * @return the data source being split
     */
    public DataSource getSource() {
        return this.source;
    }

    /**
     * @return the number of partitions to generate
     */
    public int getPartitionCount() {
        return this.partitionCount;
    }

    /**
     * Creates a new {@link Holdout} from the currently configured order and
     * partition algorithm.
     *
     * @return a fresh holdout
     */
    public Holdout getHoldout() {
        return new Holdout(this.order, this.partition);
    }

    /**
     * @return {@code true} if either this task's force flag or the project
     *         configuration's force setting is on
     */
    public boolean getForce() {
        return this.isForced || this.getProject().getConfig().force();
    }

    /**
     * @return {@code true} when splitting by users, {@code false} when
     *         splitting by ratings
     */
    public boolean getSplitUsers() {
        return this.splitUsers;
    }

    /**
     * Runs the crossfold. Unless forced, the output files are first checked
     * against the source's modification time and reused when up to date;
     * otherwise they are (re)written.
     *
     * @return the train/test data sets, one per partition
     * @throws TaskExecutionException if writing the files fails with an
     *                                {@link IOException}
     */
    @Override
    public List<TTDataSet> perform() throws TaskExecutionException {
        if (!this.getForce()) {
            UpToDateChecker check = new UpToDateChecker();
            check.addInput(this.source.lastModified());
            for (File trainFile : this.getFiles(this.getTrainPattern())) {
                check.addOutput(trainFile);
            }
            for (File testFile : this.getFiles(this.getTestPattern())) {
                check.addOutput(testFile);
            }
            if (check.isUpToDate()) {
                logger.info("crossfold {} up to date", this.getName());
                return this.getTTFiles();
            }
        }
        try {
            this.createTTFiles();
        } catch (IOException ex) {
            throw new TaskExecutionException("Error writing data sets", ex);
        }
        return this.getTTFiles();
    }

    /**
     * Expands a file pattern once per partition.
     *
     * @param pattern a {@link String#format(String, Object...)} pattern that
     *                receives the zero-based partition index
     * @return one file per partition, indexed by partition number
     */
    protected File[] getFiles(String pattern) {
        File[] files = new File[this.partitionCount];
        for (int i = 0; i < files.length; ++i) {
            files[i] = new File(String.format(pattern, i));
        }
        return files;
    }

    /**
     * Opens one CSV writer per train and test file, writes the partitions with
     * the configured strategy, and closes all writers again.
     *
     * @throws IOException if opening, writing, or closing a file fails
     */
    protected void createTTFiles() throws IOException {
        File[] trainFiles = this.getFiles(this.getTrainPattern());
        File[] testFiles = this.getFiles(this.getTestPattern());
        TableWriter[] trainWriters = new TableWriter[this.partitionCount];
        TableWriter[] testWriters = new TableWriter[this.partitionCount];
        // The closer tracks every writer opened so far, so that all of them are
        // closed even if opening a later one or writing fails.
        Closer closer = Closer.create();
        try {
            for (int i = 0; i < this.partitionCount; ++i) {
                trainWriters[i] = closer.register(CSVWriter.open(trainFiles[i], null));
                testWriters[i] = closer.register(CSVWriter.open(testFiles[i], null));
            }
            if (this.getSplitUsers()) {
                this.writeTTFilesByUsers(trainWriters, testWriters);
            } else {
                this.writeTTFilesByRatings(trainWriters, testWriters);
            }
        } catch (Throwable th) {
            // NOTE(review): Closer.rethrow(Throwable) passes IOException,
            // RuntimeException and Error through unchanged and wraps anything
            // else in a RuntimeException. If TaskExecutionException is a checked
            // exception, failures from the write methods therefore surface as
            // RuntimeException here - confirm whether
            // rethrow(th, TaskExecutionException.class) is wanted instead.
            throw closer.rethrow(th);
        } finally {
            closer.close();
        }
    }

    /**
     * Writes the partitions by splitting users. Each user is assigned to one
     * fold. For that fold, the holdout splits the user's ratings at an index
     * {@code p}: ratings before {@code p} go to the fold's train file, the rest
     * to its test file. For every other fold all of the user's ratings go to
     * the train file.
     *
     * @param trainWriters train file writers, indexed by fold
     * @param testWriters  test file writers, indexed by fold
     * @throws TaskExecutionException if writing a rating fails
     */
    protected void writeTTFilesByUsers(TableWriter[] trainWriters, TableWriter[] testWriters) throws TaskExecutionException {
        logger.info("splitting data source {} to {} partitions by users", this.getName(), this.partitionCount);
        Long2IntMap splits = this.splitUsers(this.source.getUserDAO());
        Cursor<? extends UserHistory<?>> historyCursor = this.source.getUserEventDAO().streamEventsByUser();
        Holdout mode = this.getHoldout();
        try {
            for (UserHistory<?> history : historyCursor) {
                int foldNum = splits.get(history.getUserId());
                // Copy into a fresh list that is handed to the holdout.
                List<Rating> ratings = new ArrayList<Rating>(history.filter(Rating.class));
                int p = mode.partition(ratings, this.getProject().getRandom());
                int n = ratings.size();
                for (int f = 0; f < this.partitionCount; ++f) {
                    if (f == foldNum) {
                        for (int j = 0; j < p; ++j) {
                            this.writeRating(trainWriters[f], ratings.get(j));
                        }
                        for (int j = p; j < n; ++j) {
                            this.writeRating(testWriters[f], ratings.get(j));
                        }
                    } else {
                        for (Rating rating : ratings) {
                            this.writeRating(trainWriters[f], rating);
                        }
                    }
                }
            }
        } catch (IOException e) {
            throw new TaskExecutionException("Error writing to the train test files", e);
        } finally {
            historyCursor.close();
        }
    }

    /**
     * Writes the partitions by splitting ratings. All ratings are loaded and
     * shuffled, and the {@code i}-th rating then goes to the test file of fold
     * {@code i % partitionCount} and to the train file of every other fold.
     *
     * <p>The shuffle uses the project's random number generator, the same
     * source of randomness the user-based split uses.
     *
     * @param trainWriters train file writers, indexed by fold
     * @param testWriters  test file writers, indexed by fold
     * @throws TaskExecutionException if writing a rating fails
     */
    protected void writeTTFilesByRatings(TableWriter[] trainWriters, TableWriter[] testWriters) throws TaskExecutionException {
        logger.info("splitting data source {} to {} partitions by ratings", this.getName(), this.partitionCount);
        List<Rating> ratings = Cursors.makeList(this.source.getEventDAO().streamEvents(Rating.class));
        // Shuffle with the project RNG rather than an ad-hoc default one.
        Collections.shuffle(ratings, this.getProject().getRandom());
        try {
            int n = ratings.size();
            for (int i = 0; i < n; ++i) {
                Rating rating = ratings.get(i);
                int foldNum = i % this.partitionCount;
                for (int f = 0; f < this.partitionCount; ++f) {
                    TableWriter target = (f == foldNum) ? testWriters[f] : trainWriters[f];
                    this.writeRating(target, rating);
                }
            }
        } catch (IOException e) {
            throw new TaskExecutionException("Error writing to the train test files", e);
        }
    }

    /**
     * Writes one rating as a row of four columns: user ID, item ID, preference
     * value, and timestamp. A rating without a preference is written with the
     * value {@code "NaN"}.
     *
     * @param writer the writer to append the row to
     * @param rating the rating to write
     * @throws IOException if the writer fails
     */
    protected void writeRating(TableWriter writer, Rating rating) throws IOException {
        Preference pref = rating.getPreference();
        String value = pref != null ? Double.toString(pref.getValue()) : "NaN";
        writer.writeRow(Lists.newArrayList(
                Long.toString(rating.getUserId()),
                Long.toString(rating.getItemId()),
                value,
                Long.toString(rating.getTimestamp())));
    }

    /**
     * Assigns every user to a fold: the user IDs are shuffled with the
     * project's random number generator and then dealt out round-robin.
     *
     * @param dao the DAO supplying the user IDs
     * @return a map from user ID to fold number in {@code [0, partitionCount)}
     */
    protected Long2IntMap splitUsers(UserDAO dao) {
        Long2IntOpenHashMap userMap = new Long2IntOpenHashMap();
        LongArrayList users = new LongArrayList(dao.getUserIds());
        LongLists.shuffle(users, this.getProject().getRandom());
        int userCount = users.size();
        for (int idx = 0; idx < userCount; ++idx) {
            userMap.put(users.getLong(idx), idx % this.partitionCount);
        }
        logger.info("Partitioned {} users", userMap.size());
        return userMap;
    }

    /**
     * Builds the train/test data sets that refer to the output files. This only
     * constructs the data set objects; it does not write any file.
     *
     * <p>Each data set is named {@code <name>.<partition>} and carries the
     * attributes {@code "DataSet"} (the task name) and {@code "Partition"} (the
     * partition index).
     *
     * @return the data sets, one per partition, in partition order
     */
    public List<TTDataSet> getTTFiles() {
        List<TTDataSet> dataSets = new ArrayList<TTDataSet>(this.partitionCount);
        File[] trainFiles = this.getFiles(this.getTrainPattern());
        File[] testFiles = this.getFiles(this.getTestPattern());
        for (int i = 0; i < this.partitionCount; ++i) {
            CSVDataSourceBuilder trainBuilder = new CSVDataSourceBuilder()
                    .setDomain(this.source.getPreferenceDomain())
                    .setCache(this.cacheOutput)
                    .setFile(trainFiles[i]);
            CSVDataSourceBuilder testBuilder = new CSVDataSourceBuilder()
                    .setDomain(this.source.getPreferenceDomain())
                    .setCache(this.cacheOutput)
                    .setFile(testFiles[i]);
            GenericTTDataBuilder ttBuilder = new GenericTTDataBuilder(this.getName() + "." + i);
            dataSets.add(ttBuilder
                    .setTest(testBuilder.build())
                    .setTrain(trainBuilder.build())
                    .setAttribute("DataSet", this.getName())
                    .setAttribute("Partition", i)
                    .build());
        }
        return dataSets;
    }

    /**
     * @return a short description of this task including its data source
     */
    @Override
    public String toString() {
        return String.format("{CXManager %s}", this.source);
    }
}

