/*
 * Decompiled with CFR 0.152.
 */
package org.openml.weka.experiment;

import com.thoughtworks.xstream.XStream;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.openml.apiconnector.algorithms.Conversion;
import org.openml.apiconnector.algorithms.SciMark;
import org.openml.apiconnector.algorithms.TaskInformation;
import org.openml.apiconnector.io.OpenmlConnector;
import org.openml.apiconnector.models.MetricScore;
import org.openml.apiconnector.xml.EvaluationScore;
import org.openml.apiconnector.xml.Flow;
import org.openml.apiconnector.xml.Run;
import org.openml.apiconnector.xml.Task;
import org.openml.apiconnector.xml.UploadRun;
import org.openml.apiconnector.xstream.XstreamXmlMapping;
import org.openml.weka.algorithm.OptimizationTrace;
import org.openml.weka.algorithm.WekaAlgorithm;
import org.openml.weka.algorithm.WekaConfig;
import org.openml.weka.experiment.OpenmlSplitEvaluator;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.Prediction;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.Utils;
import weka.core.Version;
import weka.experiment.InstancesResultListener;

/**
 * Collects the results of running a Weka classifier on an OpenML task and uploads them as an
 * OpenML run.
 *
 * <p>Results are grouped per (task id, classifier class name and revision, option string). For
 * each group, per-fold prediction batches and user-defined measures are accumulated, plus timings
 * and model files from a full-dataset build when one was requested. Once the group is complete
 * (see {@code OpenmlExecutedTask#complete()}), the run is uploaded via
 * {@link OpenmlConnector#runUpload} and its id is recorded in {@link #getRunIds()}. Errors are
 * uploaded as a run carrying the error message, at most once per group.
 *
 * <p>No synchronization is performed. NOTE(review): callers presumably drive this listener from a
 * single thread; confirm before sharing an instance.
 */
public class TaskResultListener
extends InstancesResultListener {
    private static final long serialVersionUID = 7230120341L;
    /** Tags attached to every uploaded run, in addition to the user-configured ones. */
    private static final String[] DEFAULT_TAGS = new String[]{"weka", "weka_" + Version.VERSION};
    /** Partially collected runs, keyed by {@link #collectionKey}. */
    private final Map<String, OpenmlExecutedTask> currentlyCollecting;
    /** Keys for which an error run was already uploaded; prevents duplicate error uploads. */
    private final List<String> tasksWithErrors;
    private final OpenmlConnector apiconnector;
    /** {@link #DEFAULT_TAGS} followed by the tags from the configuration. */
    private final String[] all_tags;
    /** Ids of all runs uploaded so far, in upload order. */
    private final List<Integer> runIds;
    /** When true, the SciMark JVM benchmark result is not attached to uploaded runs. */
    boolean skipJvmBenchmark = false;

    /**
     * @param apiconnector connector used to resolve flows/class names and to upload runs
     * @param config       supplies extra run tags and the skip-JVM-benchmark flag
     */
    public TaskResultListener(OpenmlConnector apiconnector, WekaConfig config) {
        this.apiconnector = apiconnector;
        this.currentlyCollecting = new HashMap<>();
        this.tasksWithErrors = new ArrayList<>();
        this.all_tags = (String[])ArrayUtils.addAll((Object[])DEFAULT_TAGS, (Object[])config.getTags());
        this.skipJvmBenchmark = config.getSkipJvmBenchmark();
        this.runIds = new ArrayList<>();
    }

    /**
     * Records the model built on the full dataset (timings and model files) for the run
     * identified by (task, classifier, options), and uploads the run if it is now complete.
     */
    public void acceptFullModel(Task t, Instances sourceData, Classifier classifier, String options, Map<String, Object> splitEvaluatorResults, OpenmlSplitEvaluator tse) throws Exception {
        String key = collectionKey(t, classifier, options);
        OpenmlExecutedTask oet = this.collectorFor(key, t, sourceData, classifier, options, true);
        oet.modelFullDataset(splitEvaluatorResults, tse);
        this.uploadIfComplete(key, oet);
    }

    /**
     * Records one batch (fold/repeat/sample) of predictions, optional optimization trace and
     * user-defined measures, and uploads the run if all expected batches have arrived.
     *
     * @param wantFullModel only used when this is the first result for the group: whether the
     *                      run must also wait for {@link #acceptFullModel} before uploading
     */
    public void acceptResultsForSending(Task t, Instances sourceData, Integer repeat, Integer fold, Integer sample, Classifier classifier, String options, List<Integer> rowids, ArrayList<Prediction> predictions, Map<String, MetricScore> userMeasures, List<OptimizationTrace.Quadlet<String, Double, List<Map.Entry<String, Object>>, Boolean>> optimizationTrace, boolean wantFullModel) throws Exception {
        String key = collectionKey(t, classifier, options);
        OpenmlExecutedTask oet = this.collectorFor(key, t, sourceData, classifier, options, wantFullModel);
        oet.addBatchOfPredictions(fold, repeat, sample, rowids, predictions, optimizationTrace);
        oet.addUserDefinedMeasures(fold, repeat, sample, userMeasures);
        this.uploadIfComplete(key, oet);
    }

    /**
     * Uploads a run carrying {@code error_message} for (task, classifier, options). Only the first
     * error per group is uploaded; later errors for the same group are ignored.
     */
    public void acceptErrorResult(Task t, Instances sourceData, Classifier classifier, String error_message, String options) throws Exception {
        String key = collectionKey(t, classifier, options);
        if (this.tasksWithErrors.contains(key)) {
            return;
        }
        this.tasksWithErrors.add(key);
        int runId = this.sendTaskWithError(new OpenmlExecutedTask(t, classifier, sourceData, error_message, options, this.apiconnector, false, this.all_tags));
        this.runIds.add(runId);
    }

    /**
     * Builds the key grouping results that belong to one run: task id, classifier class name with
     * its revision ({@code "undefined"} if the classifier is not a {@link RevisionHandler}), and
     * the option string.
     */
    private static String collectionKey(Task t, Classifier classifier, String options) {
        String revision = classifier instanceof RevisionHandler ? ((RevisionHandler)classifier).getRevision() : "undefined";
        String implementationId = classifier.getClass().getName() + "(" + revision + ")";
        return t.getTask_id() + "_" + implementationId + "_" + options;
    }

    /**
     * Returns the collector registered under {@code key}, creating and registering a new one
     * (with the given {@code waitForFullModel} flag) if none exists yet.
     */
    private OpenmlExecutedTask collectorFor(String key, Task t, Instances sourceData, Classifier classifier, String options, boolean waitForFullModel) throws Exception {
        OpenmlExecutedTask oet = this.currentlyCollecting.get(key);
        if (oet == null) {
            oet = new OpenmlExecutedTask(t, classifier, sourceData, null, options, this.apiconnector, waitForFullModel, this.all_tags);
            this.currentlyCollecting.put(key, oet);
        }
        return oet;
    }

    /** If {@code oet} is complete, uploads it, stops tracking it and records the new run id. */
    private void uploadIfComplete(String key, OpenmlExecutedTask oet) throws Exception {
        if (!oet.complete()) {
            return;
        }
        int runId = this.sendTask(oet);
        this.currentlyCollecting.remove(key);
        this.runIds.add(runId);
    }

    /**
     * Uploads a completed run. The upload includes the run description (XML), the predictions
     * (ARFF), OS information, the SciMark score (unless skipped) and, when available, the model
     * files and the optimization trace.
     *
     * @return the run id assigned by the server
     */
    private int sendTask(OpenmlExecutedTask oet) throws Exception {
        Conversion.log("INFO", "Upload Run", "Starting send run process... ");
        XStream xstream = XstreamXmlMapping.getInstance();
        SciMark benchmarker = SciMark.getInstance();
        oet.getRun().addOutputEvaluation(new EvaluationScore("os_information", null, null, "['" + StringUtils.join((Object[])benchmarker.getOsInfo(), "', '") + "']"));
        if (!this.skipJvmBenchmark) {
            oet.getRun().addOutputEvaluation(new EvaluationScore("scimark_benchmark", benchmarker.getResult() + "", null, "[" + StringUtils.join((Object[])benchmarker.getStringArray(), ", ") + "]"));
        }
        File tmpPredictionsFile = Conversion.stringToTempFile(oet.getPredictions().toString(), "weka_generated_predictions", "arff");
        File tmpDescriptionFile = Conversion.stringToTempFile(xstream.toXML(oet.getRun()), "weka_generated_run", "xml");
        HashMap<String, File> output_files = new HashMap<>();
        output_files.put("predictions", tmpPredictionsFile);
        if (oet.serializedClassifier != null) {
            output_files.put("model_serialized", oet.serializedClassifier);
        }
        if (oet.humanReadableClassifier != null) {
            output_files.put("model_readable", oet.humanReadableClassifier);
        }
        if (oet.optimizationTrace != null) {
            output_files.put("trace", Conversion.stringToTempFile(oet.optimizationTrace.toString(), "optimization_trace", "arff"));
        }
        UploadRun ur = this.apiconnector.runUpload(tmpDescriptionFile, output_files);
        return ur.getRun_id();
    }

    /**
     * Uploads a run that only carries its description (including the error message), with no
     * additional output files.
     *
     * @return the run id assigned by the server
     */
    private int sendTaskWithError(OpenmlExecutedTask oet) throws Exception {
        Conversion.log("WARNING", "Upload Run", "Starting to upload run... (including error results) ");
        XStream xstream = XstreamXmlMapping.getInstance();
        // The description is XStream XML, so use the same extension as sendTask (was "arff").
        File tmpDescriptionFile = Conversion.stringToTempFile(xstream.toXML(oet.getRun()), "weka_generated_run", "xml");
        UploadRun ur = this.apiconnector.runUpload(tmpDescriptionFile, new HashMap<String, File>());
        return ur.getRun_id();
    }

    /** Returns the (live, mutable) list of ids of all runs uploaded so far, in upload order. */
    public List<Integer> getRunIds() {
        return this.runIds;
    }

    /**
     * Accumulates everything that belongs to one OpenML run: predictions per fold/repeat/sample,
     * user-defined measures, the optional optimization trace and, when requested, the model built
     * on the full dataset. Static because it needs no state from the enclosing listener.
     */
    private static class OpenmlExecutedTask {
        private final boolean isRegression;
        private final int task_id;
        private final Task task;
        /** Prediction rows in the schema requested by the task, plus a "correct" column. */
        private final Instances predictions;
        /** Source dataset; used to look up the true class value of each predicted row. */
        private final Instances inputData;
        private Instances optimizationTrace;
        private int nrOfResultBatches;
        private final int nrOfExpectedResultBatches;
        /** Class labels of the task; {@code null} for regression tasks. */
        private final String[] classnames;
        private final Run run;
        private final int implementation_id;
        private final boolean waitForFullModel;
        private boolean hasFullModel;
        private int repeats;
        private int samples;
        private File serializedClassifier = null;
        private File humanReadableClassifier = null;

        /**
         * Resolves the task's class names, expected batch count, prediction schema and the flow
         * id of {@code classifier}, and creates the {@link Run} description.
         *
         * @param error_message    error to attach to the run, or {@code null} for a normal run
         * @param waitForFullModel whether completion also requires {@link #modelFullDataset}
         */
        public OpenmlExecutedTask(Task t, Classifier classifier, Instances sourceData, String error_message, String options, OpenmlConnector apiconnector, boolean waitForFullModel, String[] tags) throws Exception {
            this.task = t;
            this.waitForFullModel = waitForFullModel;
            this.hasFullModel = false;
            this.isRegression = "Supervised Regression".equals(t.getTask_type());
            this.inputData = sourceData;
            this.optimizationTrace = null;
            this.classnames = this.isRegression ? null : TaskInformation.getClassNames(apiconnector, this.task);
            this.task_id = this.task.getTask_id();

            // Each count falls back to 1 when TaskInformation cannot provide it.
            // NOTE(review): presumably because the estimation procedure does not define it.
            this.repeats = 1;
            int folds = 1;
            this.samples = 1;
            try {
                this.repeats = TaskInformation.getNumberOfRepeats(t);
            }
            catch (Exception ignored) {
                // keep default of 1
            }
            try {
                folds = TaskInformation.getNumberOfFolds(t);
            }
            catch (Exception ignored) {
                // keep default of 1
            }
            try {
                this.samples = TaskInformation.getNumberOfSamples(t);
            }
            catch (Exception ignored) {
                // keep default of 1
            }
            this.nrOfExpectedResultBatches = this.repeats * folds * this.samples;
            this.nrOfResultBatches = 0;

            ArrayList<Attribute> attInfo = this.buildPredictionAttributes(t, apiconnector);
            attInfo.add(this.inputData.classAttribute().copy("correct"));
            this.predictions = new Instances("openml_task_" + t.getTask_id() + "_predictions", attInfo, 0);

            Flow find = WekaAlgorithm.serializeClassifier(classifier.getClass().getName(), tags);
            this.implementation_id = WekaAlgorithm.getImplementationId(find, classifier, apiconnector);
            Flow implementation = apiconnector.flowGet(this.implementation_id);
            String setup_string = classifier.getClass().getName();
            if (!options.equals("")) {
                setup_string = setup_string + " -- " + options;
            }
            String[] params = Utils.splitOptions(options);
            ArrayList<Run.Parameter_setting> list = WekaAlgorithm.getParameterSetting(params, implementation);
            this.run = new Run(t.getTask_id(), error_message, implementation.getId().intValue(), setup_string, list.toArray(new Run.Parameter_setting[list.size()]), tags);
        }

        /**
         * Builds the prediction attributes from the task's prediction features. The
         * "confidence.classname" feature expands to one numeric "confidence.&lt;class&gt;"
         * attribute per class. "prediction" is numeric for regression and nominal over the class
         * labels otherwise. Every other feature becomes a numeric attribute.
         */
        private ArrayList<Attribute> buildPredictionAttributes(Task t, OpenmlConnector apiconnector) throws Exception {
            ArrayList<Attribute> attInfo = new ArrayList<>();
            for (Task.Output.Predictions.Feature f : TaskInformation.getPredictions(t).getFeatures()) {
                String featureName = f.getName();
                if (featureName.equals("confidence.classname")) {
                    // Reuse the class names fetched in the constructor instead of a second lookup.
                    String[] confidenceClasses = this.classnames != null ? this.classnames : TaskInformation.getClassNames(apiconnector, t);
                    for (String s : confidenceClasses) {
                        attInfo.add(new Attribute("confidence." + s));
                    }
                } else if (featureName.equals("prediction")) {
                    if (this.isRegression) {
                        attInfo.add(new Attribute("prediction"));
                    } else {
                        ArrayList<String> values = new ArrayList<>(this.classnames.length);
                        for (String classname : this.classnames) {
                            values.add(classname);
                        }
                        attInfo.add(new Attribute(featureName, values));
                    }
                } else {
                    attInfo.add(new Attribute(featureName));
                }
            }
            return attInfo;
        }

        /**
         * Appends one prediction row per entry of {@code rowids} and merges the optional
         * optimization trace for this fold. Counts as one result batch towards completion.
         *
         * @param rowids           indices into the source dataset; element i pairs with
         *                         {@code batchPredictions.get(i)}
         * @param batchPredictions must have at least {@code rowids.size()} elements
         */
        public void addBatchOfPredictions(Integer fold, Integer repeat, Integer sample, List<Integer> rowids, ArrayList<Prediction> batchPredictions, List<OptimizationTrace.Quadlet<String, Double, List<Map.Entry<String, Object>>, Boolean>> optimizationTraceFold) {
            ++this.nrOfResultBatches;
            // Attribute positions do not change between rows; look them up once per batch.
            int rowIdIndex = this.predictions.attribute("row_id").index();
            int foldIndex = this.predictions.attribute("fold").index();
            int repeatIndex = this.predictions.attribute("repeat").index();
            int predictionIndex = this.predictions.attribute("prediction").index();
            int correctIndex = this.predictions.attribute("correct").index();
            Attribute sampleAttribute = this.predictions.attribute("sample");
            for (int i = 0; i < rowids.size(); ++i) {
                Prediction current = batchPredictions.get(i);
                int rowId = rowids.get(i).intValue();
                double[] values = new double[this.predictions.numAttributes()];
                values[rowIdIndex] = rowId;
                values[foldIndex] = fold.intValue();
                values[repeatIndex] = repeat.intValue();
                values[predictionIndex] = current.predicted();
                if (sampleAttribute != null) {
                    values[sampleAttribute.index()] = sample.intValue();
                }
                values[correctIndex] = this.inputData.instance(rowId).classValue();
                if (current instanceof NominalPrediction) {
                    double[] confidences = ((NominalPrediction)current).distribution();
                    for (int j = 0; j < confidences.length; ++j) {
                        values[this.predictions.attribute("confidence." + this.classnames[j]).index()] = confidences[j];
                    }
                }
                this.predictions.add((Instance)new DenseInstance(1.0, values));
            }
            if (optimizationTraceFold != null) {
                this.optimizationTrace = OptimizationTrace.addTraceToDataset(this.optimizationTrace, optimizationTraceFold, this.task_id, repeat, fold);
            }
        }

        /** Adds each user-defined measure as an output evaluation scoped to fold/repeat/sample. */
        public void addUserDefinedMeasures(Integer fold, Integer repeat, Integer sample, Map<String, MetricScore> userMeasures) throws Exception {
            for (Map.Entry<String, MetricScore> measure : userMeasures.entrySet()) {
                this.getRun().addOutputEvaluation(new EvaluationScore(measure.getKey(), measure.getValue().getScore() + "", null, repeat, fold, sample, null));
            }
        }

        /**
         * Marks the full-dataset model as present. If both CPU timings are available, records the
         * training, testing and total timings. Writes the human-readable and serialized models to
         * files; failures there are logged and the corresponding file is omitted.
         */
        public void modelFullDataset(Map<String, Object> splitEvaluatorResults, OpenmlSplitEvaluator tse) {
            Classifier classifierModel = tse.getClassifier();
            this.hasFullModel = true;
            String keyTraining = "UserCPU_Time_millis_training";
            String keyTesting = "UserCPU_Time_millis_testing";
            Object trainingValue = splitEvaluatorResults.get(keyTraining);
            Object testingValue = splitEvaluatorResults.get(keyTesting);
            // A missing key and a null value are both treated as "timing not available";
            // previously a null value would have thrown on unboxing.
            if (trainingValue != null && testingValue != null) {
                Double totalTimeTraining = (Double)trainingValue;
                Double totalTimeTesting = (Double)testingValue;
                Double totalTime = totalTimeTesting + totalTimeTraining;
                this.getRun().addOutputEvaluation(new EvaluationScore(keyTesting.toLowerCase(), "" + totalTimeTesting, null, null));
                this.getRun().addOutputEvaluation(new EvaluationScore(keyTraining.toLowerCase(), "" + totalTimeTraining, null, null));
                this.getRun().addOutputEvaluation(new EvaluationScore("usercpu_time_millis", "" + totalTime, null, null));
            }
            try {
                this.humanReadableClassifier = Conversion.stringToTempFile(classifierModel.toString(), "WekaModel_" + classifierModel.getClass().getName(), "model");
            }
            catch (IOException ioe) {
                Conversion.log("Warning", "Model", "Problem extracting human readible model. ");
            }
            try {
                this.serializedClassifier = WekaAlgorithm.classifierSerializedToFile(classifierModel, this.task_id);
            }
            catch (IOException ioe) {
                Conversion.log("Warning", "Model", "Problem extracting serializable model. ");
            }
        }

        public Run getRun() {
            return this.run;
        }

        public Instances getPredictions() {
            return this.predictions;
        }

        /**
         * True once the number of received batches equals repeats * folds * samples. When the run
         * was created with {@code waitForFullModel}, the full-dataset model must also be present.
         */
        public boolean complete() {
            boolean allFolds = this.nrOfResultBatches == this.nrOfExpectedResultBatches;
            return this.waitForFullModel ? allFolds && this.hasFullModel : allFolds;
        }
    }
}

