/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.ui.listener;

import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.eclipse.collections.impl.list.mutable.primitive.FloatArrayList;
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ArbiterStatusListener
implements StatusListener {
    private static final Logger log = LoggerFactory.getLogger(ArbiterStatusListener.class);
    public static final int MAX_SCORE_VS_ITER_PTS = 1024;
    private final String sessionId;
    private final StatsStorageRouter statsStorage;
    private String ocJson;
    private long startTime = 0L;
    private Map<Integer, Integer> candidateScoreVsIterSubsampleFreq = new ConcurrentHashMap<Integer, Integer>();
    private Map<Integer, Pair<IntArrayList, FloatArrayList>> candidateScoreVsIter = new ConcurrentHashMap<Integer, Pair<IntArrayList, FloatArrayList>>();
    private Map<Integer, ModelInfoPersistable> lastModelInfoPersistable = new ConcurrentHashMap<Integer, ModelInfoPersistable>();

    public ArbiterStatusListener(@NonNull StatsStorageRouter statsStorage) {
        this(UUID.randomUUID().toString(), statsStorage);
        if (statsStorage == null) {
            throw new NullPointerException("statsStorage");
        }
    }

    public ArbiterStatusListener(@NonNull String sessionId, @NonNull StatsStorageRouter statsStorage) {
        if (sessionId == null) {
            throw new NullPointerException("sessionId");
        }
        if (statsStorage == null) {
            throw new NullPointerException("statsStorage");
        }
        this.sessionId = sessionId;
        this.statsStorage = statsStorage;
    }

    public void onInitialization(IOptimizationRunner r) {
        GlobalConfigPersistable p = this.getNewStatusPersistable(r);
        this.statsStorage.putStaticInfo((Persistable)p);
    }

    public void onShutdown(IOptimizationRunner runner) {
    }

    public void onRunnerStatusChange(IOptimizationRunner r) {
        GlobalConfigPersistable p = this.getNewStatusPersistable(r);
        this.statsStorage.putStaticInfo((Persistable)p);
    }

    public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result) {
        ModelInfoPersistable p = this.lastModelInfoPersistable.get(candidateInfo.getIndex());
        if (p == null) {
            p = ((ModelInfoPersistable.Builder)((ModelInfoPersistable.Builder)new ModelInfoPersistable.Builder().timestamp(candidateInfo.getCreatedTime())).sessionId(this.sessionId)).workerId(String.valueOf(candidateInfo.getIndex())).modelIdx(candidateInfo.getIndex()).score(candidateInfo.getScore()).status(candidateInfo.getCandidateStatus()).exceptionStackTrace(candidateInfo.getExceptionStackTrace()).build();
            this.lastModelInfoPersistable.put(candidateInfo.getIndex(), p);
        }
        if (p.getScore() == null) {
            p.setScore(candidateInfo.getScore());
        }
        if (result != null && p.getExceptionStackTrace() == null && result.getCandidateInfo().getExceptionStackTrace() != null) {
            p.setExceptionStackTrace(result.getCandidateInfo().getExceptionStackTrace());
        }
        p.setStatus(candidateInfo.getCandidateStatus());
        this.statsStorage.putUpdate((Persistable)p);
    }

    public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
        int i;
        int totalNumUpdates;
        String modelConfigJson;
        int numLayers;
        long numParams;
        double score;
        if (candidate instanceof MultiLayerNetwork) {
            MultiLayerNetwork m = (MultiLayerNetwork)candidate;
            score = m.score();
            numParams = m.numParams();
            numLayers = m.getnLayers();
            modelConfigJson = m.getLayerWiseConfigurations().toJson();
            totalNumUpdates = m.getLayerWiseConfigurations().getIterationCount();
        } else if (candidate instanceof ComputationGraph) {
            ComputationGraph cg = (ComputationGraph)candidate;
            score = cg.score();
            numParams = cg.numParams();
            numLayers = cg.getNumLayers();
            modelConfigJson = cg.getConfiguration().toJson();
            totalNumUpdates = cg.getConfiguration().getIterationCount();
        } else {
            score = 0.0;
            numParams = 0L;
            numLayers = 0;
            totalNumUpdates = 0;
            modelConfigJson = "";
        }
        int idx = candidateInfo.getIndex();
        Pair pair = this.candidateScoreVsIter.computeIfAbsent(idx, k -> new Pair((Object)new IntArrayList(), (Object)new FloatArrayList()));
        IntArrayList iter = (IntArrayList)pair.getFirst();
        FloatArrayList scores = (FloatArrayList)pair.getSecond();
        int subsamplingFreq = this.candidateScoreVsIterSubsampleFreq.computeIfAbsent(idx, k -> 1);
        if (iteration / subsamplingFreq > 1024) {
            this.candidateScoreVsIterSubsampleFreq.put(idx, subsamplingFreq *= 2);
            IntArrayList newIter = new IntArrayList();
            FloatArrayList newScores = new FloatArrayList();
            for (i = 0; i < iter.size(); ++i) {
                int it = iter.get(i);
                if (it % subsamplingFreq != 0) continue;
                newIter.add(it);
                newScores.add(scores.get(i));
            }
            iter = newIter;
            scores = newScores;
            this.candidateScoreVsIter.put(idx, (Pair<IntArrayList, FloatArrayList>)new Pair((Object)iter, (Object)scores));
        }
        if (iteration % subsamplingFreq == 0) {
            iter.add(iteration);
            scores.add((float)score);
        }
        int[] iters = iter.toArray();
        float[] fScores = new float[iters.length];
        for (i = 0; i < iters.length; ++i) {
            fScores[i] = scores.get(i);
        }
        ModelInfoPersistable p = ((ModelInfoPersistable.Builder)((ModelInfoPersistable.Builder)new ModelInfoPersistable.Builder().timestamp(candidateInfo.getCreatedTime())).sessionId(this.sessionId)).workerId(String.valueOf(candidateInfo.getIndex())).modelIdx(candidateInfo.getIndex()).score(candidateInfo.getScore()).status(candidateInfo.getCandidateStatus()).scoreVsIter(iters, fScores).lastUpdateTime(System.currentTimeMillis()).numParameters(numParams).numLayers(numLayers).totalNumUpdates(totalNumUpdates).paramSpaceValues(candidateInfo.getFlatParams()).modelConfigJson(modelConfigJson).exceptionStackTrace(candidateInfo.getExceptionStackTrace()).build();
        this.lastModelInfoPersistable.put(candidateInfo.getIndex(), p);
        this.statsStorage.putUpdate((Persistable)p);
    }

    private GlobalConfigPersistable getNewStatusPersistable(IOptimizationRunner r) {
        this.ocJson = JsonMapper.asJson(r.getConfiguration());
        GlobalConfigPersistable p = ((GlobalConfigPersistable.Builder)((GlobalConfigPersistable.Builder)new GlobalConfigPersistable.Builder().sessionId(this.sessionId)).timestamp(System.currentTimeMillis())).optimizationConfigJson(this.ocJson).candidateCounts(r.numCandidatesQueued(), r.numCandidatesCompleted(), r.numCandidatesFailed(), r.numCandidatesTotal()).optimizationRunner(r.getClass().getSimpleName()).build();
        return p;
    }
}

