/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.openimaj.io.FileUtils;
import org.openimaj.io.IOUtils;
import org.openimaj.io.WriteableBinary;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.experiments.sinabill.BilinearExperiment;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.util.pair.Pair;

/**
 * Online "regret" experiment for the bilinear sparse online learner.
 * <p>
 * Every item of the Bill MATLAB dataset is streamed through a
 * {@link BilinearSparseOnlineLearner}. Before the learner is trained on an item,
 * provided {@code learner.getW()} is non-null, two losses are computed on that item:
 * the current learner's loss from a {@link RootMeanSumLossEvaluator}, and the loss
 * recorded for the same item index in a previous batch experiment log. Both losses and
 * their difference (the regret) are logged. The run stops at the first item index with
 * no batch loss recorded. After each training step the learner is written to
 * {@code FOLD_ROOT(-1)/learner_<item>}.
 */
public class RegretExperiment extends BilinearExperiment {
    /** Batch experiment log, resolved against {@code DATA_ROOT()}, from which per-item batch losses are read. */
    private static final String BATCH_EXPERIMENT = "batchStreamLossExperiments/batch_1366231606223/experiment.log";
    /** Text that precedes the item index on a batch-log line. */
    private static final String ITEM_MARKER = "New Item Seen: ";
    /** Text that precedes a loss value on a batch-log line. */
    private static final String LOSS_MARKER = "Loss:";
    /**
     * Fold id used both for the data generator (together with {@code Mode.ALL}) and for the
     * output directory; presumably means "no fold split" — NOTE(review): confirm in BilinearExperiment.
     */
    private static final int ALL_FOLDS = -1;

    /**
     * Runs the regret experiment: streams every item, logs loss/batch loss/regret for each
     * item once the learner has weights, trains on the item and saves the learner.
     *
     * @throws Exception if the batch log, the MATLAB data or the learner output cannot be read/written
     */
    @Override
    public void performExperiment() throws Exception {
        final Map<Integer, Double> batchLosses = this.loadBatchLoss();
        final BilinearLearnerParameters params = buildParameters();
        // NOTE(review): meaning of 98 / true is defined by BillMatlabFileDataGenerator — not visible here.
        final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(new File(this.MATLAB_DATA()), 98, true);
        this.prepareExperimentLog(params);
        final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
        bmfdg.setFold(ALL_FOLDS, BillMatlabFileDataGenerator.Mode.ALL);
        final RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
        eval.setLearner(learner);

        int j = 0;
        Pair<Matrix> next;
        while ((next = bmfdg.generate()) != null) {
            // The evaluator works on a list of items; wrap the single current item.
            final ArrayList<Pair<Matrix>> asList = new ArrayList<Pair<Matrix>>();
            asList.add(next);
            // Regret is only measured once the learner has weights to evaluate with.
            if (learner.getW() != null) {
                if (!batchLosses.containsKey(j)) {
                    this.logger.debug(String.format("...No batch result found for: %d, done", j));
                    break;
                }
                this.logger.debug("...Calculating regret for item" + j);
                final double loss = ((BilinearEvaluator) eval).evaluate(asList);
                this.logger.debug(String.format("... loss: %f", loss));
                final double batchloss = batchLosses.get(j);
                this.logger.debug(String.format("... batch loss: %f", batchloss));
                this.logger.debug(String.format("... regret: %f", loss - batchloss));
            }
            learner.process(next.firstObject(), next.secondObject());
            this.logger.debug(String.format("... loss (post addition): %f", ((BilinearEvaluator) eval).evaluate(asList)));
            this.logger.debug(String.format("Saving learner, Fold %d, Item %d", ALL_FOLDS, j));
            final File learnerOut = new File(this.FOLD_ROOT(ALL_FOLDS), String.format("learner_%d", j));
            // Cast kept so the WriteableBinary overload of writeBinary is selected.
            IOUtils.writeBinary(learnerOut, (WriteableBinary) learner);
            ++j;
        }
    }

    /** Builds the fixed hyperparameter set used by this experiment. */
    private static BilinearLearnerParameters buildParameters() {
        final BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put("eta0u", 0.02);
        params.put("eta0w", 0.02);
        params.put("lambda", 0.001);
        params.put("biconvex_tol", 0.01);
        params.put("biconvex_maxiter", 10);
        params.put("bias", true);
        params.put("biaseta0", 0.5);
        params.put("winitstrat", new SingleValueInitStrat(0.1));
        params.put("uinitstrat", new SparseZerosInitStrategy());
        return params;
    }

    /**
     * Reads per-item losses from the batch experiment log.
     * <p>
     * Lines containing {@value #ITEM_MARKER} set the current item index. Each line containing
     * {@value #LOSS_MARKER} records its loss under the most recently seen index, or index 0
     * if no item line has been seen yet. A later loss for the same index overwrites an
     * earlier one.
     *
     * @return map from item index to batch loss
     * @throws IOException if the log cannot be read or contains an unparsable index/loss value
     */
    private Map<Integer, Double> loadBatchLoss() throws IOException {
        final String[] batchExperimentLines = FileUtils.readlines(new File(this.DATA_ROOT(), BATCH_EXPERIMENT));
        final Map<Integer, Double> ret = new HashMap<Integer, Double>();
        int seenItems = 0;
        for (final String line : batchExperimentLines) {
            try {
                if (line.contains(ITEM_MARKER)) {
                    seenItems = Integer.parseInt(valueAfter(line, ITEM_MARKER));
                }
                if (line.contains(LOSS_MARKER)) {
                    ret.put(seenItems, Double.parseDouble(valueAfter(line, LOSS_MARKER)));
                }
            } catch (final NumberFormatException e) {
                throw new IOException("Could not parse batch experiment log line: " + line, e);
            }
        }
        return ret;
    }

    /**
     * Returns the first whitespace-delimited token after the first occurrence of {@code marker}.
     * Unlike splitting on ':', this is unaffected by colons earlier on the line (e.g. a
     * timestamp prefix). Callers must ensure {@code line} contains {@code marker}.
     */
    private static String valueAfter(String line, String marker) {
        final String rest = line.substring(line.indexOf(marker) + marker.length()).trim();
        return rest.split("\\s+")[0];
    }

    /** Command-line entry point: runs the regret experiment. */
    public static void main(String[] args) throws Exception {
        final RegretExperiment exp = new RegretExperiment();
        exp.performExperiment();
    }
}

