/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.module.histogram;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.stats.api.SummaryType;
import org.deeplearning4j.ui.views.html.histogram.Histogram;
import org.deeplearning4j.ui.weights.beans.CompactModelAndGradient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import play.libs.Json;
import play.mvc.Result;
import play.mvc.Results;
import play.twirl.api.Content;

/**
 * UI module serving the histogram ("weights") page and the JSON data behind it.
 *
 * <p>Sessions are discovered from attached {@link StatsStorage} instances, and from storage
 * events, when they carry data under the {@code "StatsListener"} type ID. Routes:
 * <ul>
 *   <li>{@code /weights} - the HTML page;</li>
 *   <li>{@code /weights/listSessions} - JSON array of known session IDs;</li>
 *   <li>{@code /weights/updated/:sid} - a timestamp polled by the page;</li>
 *   <li>{@code /weights/data/:sid} - histograms, scores and mean-magnitude history for a session.</li>
 * </ul>
 */
public class HistogramModule
implements UIModule {
    private static final Logger log = LoggerFactory.getLogger(HistogramModule.class);

    /** Type ID under which StatsListener data is stored and routed to this module. */
    private static final String TYPE_ID = "StatsListener";

    /**
     * Session ID -> storage holding that session, in discovery order.
     * This is a synchronized wrapper: any iteration over its views must hold the map's monitor.
     */
    private final Map<String, StatsStorage> knownSessionIDs =
            Collections.synchronizedMap(new LinkedHashMap<String, StatsStorage>());

    @Override
    public List<String> getCallbackTypeIDs() {
        return Collections.singletonList(TYPE_ID);
    }

    @Override
    public List<Route> getRoutes() {
        Route r = new Route("/weights", HttpMethod.GET, FunctionType.Supplier, () -> Results.ok((Content) Histogram.apply()));
        Route r2 = new Route("/weights/listSessions", HttpMethod.GET, FunctionType.Supplier, () -> this.listSessions());
        Route r3 = new Route("/weights/updated/:sid", HttpMethod.GET, FunctionType.Function, this::getLastUpdateTime);
        Route r4 = new Route("/weights/data/:sid", HttpMethod.GET, FunctionType.Function, this::processRequest);
        return Arrays.asList(r, r2, r3, r4);
    }

    /**
     * Registers the storage of every session mentioned in {@code events}.
     * Sessions that are already known keep their existing storage.
     */
    @Override
    public void reportStorageEvents(Collection<StatsStorageEvent> events) {
        log.trace("Received events: {}", events);
        for (StatsStorageEvent sse : events) {
            // Atomic on a synchronized map (unlike the containsKey-then-put pair).
            this.knownSessionIDs.putIfAbsent(sse.getSessionID(), sse.getStatsStorage());
        }
    }

    /** Registers every session in {@code statsStorage} that has at least one "StatsListener" entry. */
    @Override
    public void onAttach(StatsStorage statsStorage) {
        for (String sessionID : statsStorage.listSessionIDs()) {
            for (String typeID : statsStorage.listTypeIDsForSession(sessionID)) {
                if (TYPE_ID.equals(typeID)) {
                    this.knownSessionIDs.put(sessionID, statsStorage);
                    break;
                }
            }
        }
    }

    /** Forgets every session ID listed by the detached storage. */
    @Override
    public void onDetach(StatsStorage statsStorage) {
        for (String sessionID : statsStorage.listSessionIDs()) {
            this.knownSessionIDs.remove(sessionID);
        }
    }

    /**
     * Returns the known session IDs as a JSON array, in discovery order.
     * The key set is copied while holding the map's lock. Iterating a synchronized map's view
     * without the lock can throw ConcurrentModificationException if storage events arrive
     * concurrently.
     */
    private Result listSessions() {
        List<String> sessionIds;
        synchronized (this.knownSessionIDs) {
            sessionIds = new ArrayList<String>(this.knownSessionIDs.keySet());
        }
        return Results.ok(Json.toJson(sessionIds));
    }

    /**
     * Handles {@code /weights/updated/:sid}.
     *
     * <p>NOTE(review): this ignores {@code sessionID} and returns the current wall-clock time,
     * not the session's last update time. Presumably this makes the page always treat data as
     * fresh; confirm against the front-end polling logic before changing it.
     */
    private Result getLastUpdateTime(String sessionID) {
        return Results.ok(Json.toJson((Object) System.currentTimeMillis()));
    }

    /**
     * Builds the JSON payload for {@code /weights/data/:sid}.
     *
     * <p>Only the first worker reported for the session is used. Responds with 404 for an
     * unknown session. Responds with an empty JSON object when there is nothing to show yet:
     * no workers, no initialization report, or no {@link StatsReport} updates.
     *
     * @param sessionId session to render
     * @return a serialized {@link CompactModelAndGradient}, an empty JSON object, or 404
     */
    private Result processRequest(String sessionId) {
        StatsStorage ss = this.knownSessionIDs.get(sessionId);
        if (ss == null) {
            return Results.notFound("Unknown session ID: " + sessionId);
        }
        List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
        if (workerIDs == null || workerIDs.isEmpty()) {
            // Known session but no worker data yet; get(0) would throw.
            return emptyJsonResult();
        }
        String workerID = workerIDs.get(0);

        Persistable staticInfo = ss.getStaticInfo(sessionId, TYPE_ID, workerID);
        if (!(staticInfo instanceof StatsInitializationReport)) {
            // Covers both "no init report yet" (null) and an unexpected static-info type.
            return emptyJsonResult();
        }
        List<String> layerNames = extractLayerNames(((StatsInitializationReport) staticInfo).getModelParamNames());

        List<Persistable> stored = ss.getAllUpdatesAfter(sessionId, TYPE_ID, workerID, 0L);
        // Sort a private copy: the storage's list may be shared or unmodifiable.
        List<Persistable> updates = stored == null ? new ArrayList<Persistable>() : new ArrayList<Persistable>(stored);
        Collections.sort(updates, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));

        List<Double> scores = new ArrayList<Double>(updates.size());
        List<Map<String, List<Double>>> paramMagnitudes = newMagnitudeHistory(layerNames.size());
        List<Map<String, List<Double>>> updateMagnitudes = newMagnitudeHistory(layerNames.size());
        StatsReport last = null;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport)) {
                log.debug("Encountered unexpected type: {}", (Object) p);
                continue;
            }
            StatsReport report = (StatsReport) p;
            scores.add(report.getScore());
            if (report.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
                this.updateMeanMagnitudeMaps(report.getMeanMagnitudes(StatsType.Parameters), layerNames, paramMagnitudes);
            }
            if (report.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
                this.updateMeanMagnitudeMaps(report.getMeanMagnitudes(StatsType.Updates), layerNames, updateMagnitudes);
            }
            last = report;
        }
        if (last == null) {
            // No StatsReport posted yet; dereferencing `last` below would NPE.
            return emptyJsonResult();
        }

        CompactModelAndGradient g = new CompactModelAndGradient();
        g.setGradients(this.getHistogram(last.getHistograms(StatsType.Updates)));
        g.setParameters(this.getHistogram(last.getHistograms(StatsType.Parameters)));
        // `last` contributed the final element of `scores`, so this is the latest score.
        g.setScore(last.getScore());
        g.setScores(scores);
        g.setUpdateMagnitudes(updateMagnitudes);
        g.setParamMagnitudes(paramMagnitudes);
        g.setLastUpdateTime(last.getTimeStamp());
        return Results.ok(Json.toJson(g));
    }

    /** 200 response carrying an empty JSON object ({@code {}}). */
    private static Result emptyJsonResult() {
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }

    /**
     * Returns the prefix of {@code key} before its first underscore, or the whole key if it has
     * none. For example, "0_W" gives "0".
     */
    private static String layerNameOf(String key) {
        int underscore = key.indexOf('_');
        return underscore < 0 ? key : key.substring(0, underscore);
    }

    /**
     * Derives distinct layer names from parameter names, in first-seen order.
     * A {@code null} array yields an empty list.
     */
    private static List<String> extractLayerNames(String[] paramNames) {
        LinkedHashSet<String> layerNames = new LinkedHashSet<String>();
        if (paramNames != null) {
            for (String paramName : paramNames) {
                layerNames.add(layerNameOf(paramName));
            }
        }
        return new ArrayList<String>(layerNames);
    }

    /** Creates one empty map (param key -> mean-magnitude history) per layer. */
    private static List<Map<String, List<Double>>> newMagnitudeHistory(int numLayers) {
        List<Map<String, List<Double>>> history = new ArrayList<Map<String, List<Double>>>(numLayers);
        for (int i = 0; i < numLayers; ++i) {
            history.add(new HashMap<String, List<Double>>());
        }
        return history;
    }

    /**
     * Appends each mean magnitude in {@code current} to its key's history.
     * The history is stored in the map of the layer that the key's prefix names.
     * Keys whose layer is not in {@code layerNames} are skipped. Previously they caused
     * {@code history.get(-1)} to throw.
     */
    private void updateMeanMagnitudeMaps(Map<String, Double> current, List<String> layerNames, List<Map<String, List<Double>>> history) {
        if (current == null) {
            return;
        }
        for (Map.Entry<String, Double> entry : current.entrySet()) {
            String key = entry.getKey();
            int idx = layerNames.indexOf(layerNameOf(key));
            if (idx < 0) {
                log.debug("Ignoring mean magnitude for unknown layer: {}", key);
                continue;
            }
            history.get(idx).computeIfAbsent(key, k -> new ArrayList<Double>()).add(entry.getValue());
        }
    }

    /**
     * Converts histograms into name -> (bin centre -> count) maps, preserving the input order.
     * Names that start with a digit get a "param_" prefix. A {@code null} input yields an empty
     * map.
     */
    private Map<String, Map> getHistogram(Map<String, org.deeplearning4j.ui.stats.api.Histogram> histograms) {
        LinkedHashMap<String, Map> ret = new LinkedHashMap<String, Map>();
        if (histograms == null) {
            return ret;
        }
        for (Map.Entry<String, org.deeplearning4j.ui.stats.api.Histogram> e : histograms.entrySet()) {
            String name = e.getKey();
            org.deeplearning4j.ui.stats.api.Histogram h = e.getValue();
            String newName = !name.isEmpty() && Character.isDigit(name.charAt(0)) ? "param_" + name : name;
            double min = h.getMin();
            double max = h.getMax();
            int n = h.getNBins();
            double step = (max - min) / (double) n;
            int[] counts = h.getBinCounts();
            LinkedHashMap<Double, Integer> bins = new LinkedHashMap<Double, Integer>();
            for (int i = 0; i < n; ++i) {
                // Key is the bin centre. NOTE(review): if min == max every centre is equal,
                // so later counts overwrite earlier ones.
                double binCentre = min + (double) i * step + step / 2.0;
                bins.put(binCentre, counts[i]);
            }
            ret.put(newName, bins);
        }
        return ret;
    }
}

