/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism.trainer;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import lombok.NonNull;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Worker thread used for data-parallel training.
 *
 * <p>Each instance owns one model: a freshly built replica of {@code originalModel} (same
 * configuration, copied parameters and updater state), or the original model itself when this is
 * worker 0. Batches are handed over through a blocking queue (capacity 1 by default) via
 * {@link #feedDataSet} / {@link #feedMultiDataSet}. {@link #run()} polls that queue, fits the batch
 * and decrements the pending-batch counter.
 *
 * <p>When the ND4J backend does not support cross-device access, parameters and updater state are
 * moved to host memory on averaging iterations. Presumably this lets the averaging code (likely in
 * {@link ParallelWrapper}) read them; confirm against the caller.
 */
public class DefaultTrainer
extends Thread
implements Trainer {
    private static final Logger log = LoggerFactory.getLogger(DefaultTrainer.class);
    /** Model trained by this worker; equals {@link #originalModel} on the root worker. */
    protected Model replicatedModel;
    /** Inbound {@code DataSet} batches. */
    protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue;
    /** Inbound {@link MultiDataSet} batches, used when {@link #useMDS} is set. */
    protected LinkedBlockingQueue<MultiDataSet> queueMDS;
    /** Number of batches fed but not yet fitted. */
    protected AtomicInteger running;
    /** Set by {@link #updateModel}; not read within this class. */
    protected AtomicBoolean shouldUpdate;
    /** Requests the training loop in {@link #run()} to exit. */
    protected AtomicBoolean shouldStop;
    /**
     * Failure raised on the trainer thread, rethrown to callers of {@link #isRunning()} /
     * {@link #waitTillRunning()}. Volatile because it is written and read on different threads.
     */
    protected volatile Exception thrownException;
    /** Selects the {@link MultiDataSet} training loop instead of the {@code DataSet} one. */
    protected volatile boolean useMDS;
    /** Worker id attached to cloned routing listeners. */
    protected final String uuid = UUID.randomUUID().toString();
    /** True when this worker trains directly on {@link #originalModel} (forced for thread 0). */
    protected boolean onRootModel;
    /** Latest ETL time reported by a feed call, passed to the model before each fit. */
    protected volatile AtomicLong lastEtlTime;
    /** Debug mode: when set, a constant dummy batch is trained on instead of polling the queue. */
    protected AtomicBoolean nullMode;
    /** Lazily created dummy batch used in {@link #nullMode}. */
    protected org.nd4j.linalg.dataset.api.DataSet nullDataSet;
    /** Set by the trainer thread when {@link #run()} finishes; awaited by {@link #shutdown()}. */
    protected AtomicBoolean isStopped;
    protected ParallelWrapper parallelWrapper;
    /** Training workspace mode applied to the worker's model configuration. */
    protected WorkspaceMode workspaceMode;
    /** Host-sync cadence in fitted batches; 0 means every batch. */
    protected int averagingFrequency;
    protected int threadId;
    /** The user's model from which replicas are built. */
    protected Model originalModel;

    /**
     * Queues a {@link MultiDataSet} for this trainer, blocking while the queue is full.
     *
     * @param dataSet batch to train on; must not be null
     * @param etlTime ETL time to record as the latest value, passed to the model on the next fit
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void feedMultiDataSet(@NonNull MultiDataSet dataSet, long etlTime) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        this.setupIfNeccessary();
        try {
            this.queueMDS.put(dataSet);
            this.running.incrementAndGet();
        } catch (InterruptedException e) {
            // The batch was not queued. Restore the interrupt flag so the caller can observe it
            // instead of the interruption being silently lost.
            Thread.currentThread().interrupt();
        }
        // setupIfNeccessary() guarantees lastEtlTime is non-null here.
        this.lastEtlTime.set(etlTime);
    }

    /**
     * Queues a {@code DataSet} for this trainer, blocking while the queue is full.
     *
     * <p>A {@code null} batch is not queued; instead it switches the trainer into the debugging-only
     * {@link #nullMode}, where a constant dummy batch is trained on.
     *
     * @param dataSet batch to train on, or null to enable null mode
     * @param etlTime ETL time to record as the latest value, passed to the model on the next fit
     */
    @Override
    public void feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime) {
        this.setupIfNeccessary();
        if (dataSet != null) {
            try {
                this.queue.put(dataSet);
                this.running.incrementAndGet();
            } catch (InterruptedException e) {
                // The batch was not queued; preserve the interrupt for the caller.
                Thread.currentThread().interrupt();
            }
        } else {
            if (this.nullMode == null) {
                this.nullMode = new AtomicBoolean(false);
            }
            this.nullMode.set(true);
        }
        // setupIfNeccessary() guarantees lastEtlTime is non-null here.
        this.lastEtlTime.set(etlTime);
    }

    /** Returns the model this worker trains (null until {@link #run()} has initialised it). */
    @Override
    public Model getModel() {
        return this.replicatedModel;
    }

    /**
     * Copies parameters and updater state from {@code model} into this worker's model.
     *
     * <p>Only acts when the worker's model is a {@link MultiLayerNetwork} or
     * {@link ComputationGraph}; {@code model} is cast to the same type. Also sets
     * {@link #shouldUpdate}.
     *
     * @param model source of parameters and updater state; must not be null
     * @throws NullPointerException if {@code model} is null
     */
    @Override
    public void updateModel(@NonNull Model model) {
        if (model == null) {
            throw new NullPointerException("model");
        }
        this.shouldUpdate.set(true);
        if (this.replicatedModel instanceof MultiLayerNetwork) {
            MultiLayerNetwork replica = (MultiLayerNetwork) this.replicatedModel;
            replica.setParams(model.params().dup());
            INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
            if (sourceState != null) {
                Updater replicaUpdater = replica.getUpdater();
                INDArray stateCopy = sourceState.dup();
                Nd4j.getExecutioner().commit();
                replicaUpdater.setStateViewArray((Layer) replica, stateCopy, false);
            }
        } else if (this.replicatedModel instanceof ComputationGraph) {
            ComputationGraph replica = (ComputationGraph) this.replicatedModel;
            replica.setParams(model.params().dup());
            INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
            if (sourceState != null) {
                INDArray stateCopy = sourceState.dup();
                Nd4j.getExecutioner().commit();
                ComputationGraphUpdater replicaUpdater = replica.getUpdater();
                replicaUpdater.setStateViewArray(stateCopy);
            }
        }
        Nd4j.getExecutioner().commit();
    }

    /**
     * Lazily creates any coordination fields that are still null.
     *
     * <p>Needed because the no-arg constructor leaves them unset. The queues get capacity 1.
     */
    protected void setupIfNeccessary() {
        if (this.queue == null) {
            this.queue = new LinkedBlockingQueue<>(1);
        }
        if (this.queueMDS == null) {
            this.queueMDS = new LinkedBlockingQueue<>(1);
        }
        if (this.running == null) {
            this.running = new AtomicInteger(0);
        }
        if (this.shouldStop == null) {
            this.shouldStop = new AtomicBoolean(false);
        }
        if (this.shouldUpdate == null) {
            this.shouldUpdate = new AtomicBoolean(false);
        }
        if (this.isStopped == null) {
            this.isStopped = new AtomicBoolean(false);
        }
        if (this.lastEtlTime == null) {
            this.lastEtlTime = new AtomicLong(0L);
        }
    }

    /**
     * Reports whether this trainer has no pending batches.
     *
     * <p>Note: despite the name, this returns {@code true} when the pending-batch counter is zero,
     * i.e. when the trainer is idle.
     *
     * @return true if every fed batch has been fitted
     * @throws RuntimeException wrapping the failure if the trainer thread has failed
     */
    @Override
    public boolean isRunning() {
        if (this.thrownException != null) {
            throw new RuntimeException(this.thrownException);
        }
        return this.running.get() == 0;
    }

    /**
     * Asks the training loop to exit and spins until {@link #run()} reports it has stopped.
     *
     * <p>Both flags are then reset so the trainer can be reused. Blocks indefinitely if
     * {@link #run()} is never executed.
     */
    @Override
    public void shutdown() {
        this.shouldStop.set(true);
        while (!this.isStopped.get()) {
            LockSupport.parkNanos(1000L);
        }
        this.shouldStop.set(false);
        this.isStopped.set(false);
    }

    /**
     * Fits this worker's model on one {@code DataSet}, first passing it the latest ETL time.
     * Does nothing if the model is neither a {@link MultiLayerNetwork} nor a
     * {@link ComputationGraph}.
     */
    protected void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        if (this.replicatedModel instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) this.replicatedModel;
            network.setLastEtlTime(this.currentEtlTime());
            network.fit(dataSet);
        } else if (this.replicatedModel instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) this.replicatedModel;
            graph.setLastEtlTime(this.currentEtlTime());
            graph.fit(dataSet);
        }
    }

    /**
     * Fits this worker's model on one {@link MultiDataSet}.
     *
     * <p>The model must be a {@link ComputationGraph}; otherwise a {@link ClassCastException} is
     * thrown.
     */
    protected void fit(MultiDataSet dataSet) {
        ComputationGraph graph = (ComputationGraph) this.replicatedModel;
        graph.setLastEtlTime(this.currentEtlTime());
        graph.fit(dataSet);
    }

    /** Returns the latest recorded ETL time, creating the holder (at 0) if it is still null. */
    private long currentEtlTime() {
        if (this.lastEtlTime == null) {
            this.lastEtlTime = new AtomicLong(0L);
        }
        return this.lastEtlTime.get();
    }

    /**
     * Installs listeners on this worker's model.
     *
     * <p>The listeners are the wrapper's listeners, with routing listeners cloned and tagged with
     * this worker's {@link #uuid}. Subclasses may override this hook; it runs once in
     * {@link #run()}, after the model is created and before training starts.
     */
    protected void postInit() {
        ArrayList<IterationListener> oldListeners = new ArrayList<>();
        ArrayList<IterationListener> replicatedListeners = new ArrayList<>();
        if (this.parallelWrapper.getListeners() != null) {
            oldListeners.addAll(this.parallelWrapper.getListeners());
        }
        this.configureListeners(this.uuid, oldListeners, replicatedListeners);
        this.replicatedModel.setListeners(replicatedListeners);
    }

    /**
     * Trainer thread body: builds this worker's model, then runs the training loop until
     * {@link #shouldStop} is set.
     *
     * <p>Any exception is stored in {@link #thrownException} and rethrown wrapped in a
     * {@link RuntimeException}. Workspaces of this thread are destroyed and {@link #isStopped} is
     * set on exit.
     */
    @Override
    public void run() {
        this.setupIfNeccessary();
        AtomicInteger iterationsCounter = new AtomicInteger(0);
        // Worker 0 always trains directly on the user's model rather than on a copy.
        if (this.threadId == 0) {
            this.onRootModel = true;
        }
        try {
            if (this.originalModel instanceof MultiLayerNetwork) {
                this.initMultiLayerNetwork((MultiLayerNetwork) this.originalModel);
            } else if (this.originalModel instanceof ComputationGraph) {
                this.initComputationGraph((ComputationGraph) this.originalModel);
            }
            if (this.replicatedModel == null) {
                log.error("replicatedModel is NULL at worker_{}", this.threadId);
            }
            this.postInit();
            if (this.useMDS) {
                this.trainOnMultiDataSets(iterationsCounter);
            } else {
                this.trainOnDataSets(iterationsCounter);
            }
        } catch (Exception e) {
            this.thrownException = e;
            throw new RuntimeException(e);
        } finally {
            try {
                log.debug("Terminating all workspaces for trainer_{}", this.threadId);
                Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
            } finally {
                // Always release shutdown(), even if workspace cleanup fails; otherwise it spins forever.
                this.isStopped.set(true);
            }
        }
    }

    /** Sets {@link #replicatedModel} for a {@link MultiLayerNetwork}: the original on the root worker, else a copy. */
    private void initMultiLayerNetwork(MultiLayerNetwork original) {
        if (this.onRootModel) {
            this.replicatedModel = original;
            if (!original.isInitCalled()) {
                original.init();
            }
            original.getLayerWiseConfigurations().setTrainingWorkspaceMode(this.workspaceMode);
            return;
        }
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(original.getLayerWiseConfigurations().toJson());
        conf.setTrainingWorkspaceMode(this.workspaceMode);
        MultiLayerNetwork replica = new MultiLayerNetwork(conf);
        this.replicatedModel = replica;
        replica.init();
        // Copy params and updater state under the original's monitor, in case it is pre-trained.
        synchronized (original) {
            replica.setParams(original.params().unsafeDuplication(true));
            Updater updaterReplica = replica.getUpdater();
            Updater updaterOriginal = original.getUpdater();
            if (updaterOriginal != null && updaterOriginal.getStateViewArray() != null) {
                updaterReplica.setStateViewArray((Layer) replica, updaterOriginal.getStateViewArray().unsafeDuplication(true), false);
            }
            Nd4j.getExecutioner().commit();
        }
    }

    /** Sets {@link #replicatedModel} for a {@link ComputationGraph}: the original on the root worker, else a copy. */
    private void initComputationGraph(ComputationGraph original) {
        if (this.onRootModel) {
            this.replicatedModel = original;
            original.init();
            original.getConfiguration().setTrainingWorkspaceMode(this.workspaceMode);
            return;
        }
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(original.getConfiguration().toJson());
        conf.setTrainingWorkspaceMode(this.workspaceMode);
        ComputationGraph replica = new ComputationGraph(conf);
        this.replicatedModel = replica;
        replica.init();
        // Copy params and updater state under the original's monitor, in case it is pre-trained.
        synchronized (original) {
            replica.setParams(original.params().unsafeDuplication(true));
            ComputationGraphUpdater updaterReplica = replica.getUpdater();
            ComputationGraphUpdater updaterOriginal = original.getUpdater();
            if (updaterOriginal != null && updaterOriginal.getStateViewArray() != null) {
                updaterReplica.setStateViewArray(updaterOriginal.getStateViewArray().unsafeDuplication(true));
            }
            Nd4j.getExecutioner().commit();
        }
    }

    /** Training loop over {@code DataSet} batches (or the dummy batch in null mode). */
    private void trainOnDataSets(AtomicInteger iterationsCounter) throws InterruptedException {
        while (!this.shouldStop.get()) {
            org.nd4j.linalg.dataset.api.DataSet dataSet;
            if (this.nullMode == null || !this.nullMode.get()) {
                dataSet = this.queue.poll(10L, TimeUnit.MILLISECONDS);
            } else {
                // Debugging-only branch: trains on a constant 64x784 / 64x10 batch of zeros-initialised arrays.
                if (this.nullDataSet == null) {
                    this.nullDataSet = new DataSet(Nd4j.create(64, 784), Nd4j.create(64, 10));
                }
                dataSet = this.nullDataSet;
            }
            if (dataSet == null) {
                continue;
            }
            this.fit(dataSet);
            this.syncToHostIfRequired(iterationsCounter);
            this.running.decrementAndGet();
        }
    }

    /** Training loop over {@link MultiDataSet} batches. */
    private void trainOnMultiDataSets(AtomicInteger iterationsCounter) throws InterruptedException {
        while (!this.shouldStop.get()) {
            MultiDataSet dataSet = this.queueMDS.poll(10L, TimeUnit.MILLISECONDS);
            if (dataSet == null) {
                continue;
            }
            this.fit(dataSet);
            this.syncToHostIfRequired(iterationsCounter);
            this.running.decrementAndGet();
        }
    }

    /**
     * On backends without cross-device access, moves params and updater state to host memory on
     * averaging iterations.
     *
     * <p>The iteration counter is advanced only when cross-device access is unsupported and
     * {@link #averagingFrequency} is non-zero; {@link #averagingRequired()} is consulted last. This
     * ordering relies on short-circuit evaluation and must be preserved.
     */
    private void syncToHostIfRequired(AtomicInteger iterationsCounter) {
        if (Nd4j.getAffinityManager().isCrossDeviceAccessSupported()) {
            return;
        }
        if (this.averagingFrequency != 0 && iterationsCounter.incrementAndGet() % this.averagingFrequency != 0) {
            return;
        }
        if (!this.averagingRequired()) {
            return;
        }
        Nd4j.getExecutioner().commit();
        Nd4j.getAffinityManager().ensureLocation(this.replicatedModel.params(), AffinityManager.Location.HOST);
        INDArray updaterState;
        if (this.replicatedModel instanceof MultiLayerNetwork) {
            updaterState = ((MultiLayerNetwork) this.replicatedModel).getUpdater().getStateViewArray();
        } else {
            updaterState = ((ComputationGraph) this.replicatedModel).getUpdater().getStateViewArray();
        }
        if (updaterState != null) {
            Nd4j.getAffinityManager().ensureLocation(updaterState, AffinityManager.Location.HOST);
        }
    }

    /**
     * Spins until every fed batch has been fitted.
     *
     * @throws RuntimeException wrapping the failure if the trainer thread has failed while waiting
     */
    @Override
    public void waitTillRunning() {
        while (this.running.get() != 0) {
            if (this.thrownException != null) {
                throw new RuntimeException(this.thrownException);
            }
            LockSupport.parkNanos(1000L);
        }
    }

    /** Whether host sync for averaging is needed; always true here, subclasses may override. */
    @Override
    public boolean averagingRequired() {
        return true;
    }

    /** Clones routing listeners (they carry per-worker state); other listeners are shared as-is. */
    protected static IterationListener cloneListener(IterationListener original) {
        if (original instanceof RoutingIterationListener) {
            return ((RoutingIterationListener) original).clone();
        }
        return original;
    }

    /**
     * Adds each listener from {@code oldListeners} (cloned if it is a routing listener) to
     * {@code replicatedListeners}, skipping duplicates.
     *
     * <p>Cloned routing listeners keep the original's session id and storage router (falling back
     * to the wrapper's router), and get {@code workerUUID} as worker id.
     */
    protected void configureListeners(String workerUUID, Collection<IterationListener> oldListeners, Collection<IterationListener> replicatedListeners) {
        for (IterationListener listener : oldListeners) {
            IterationListener l = DefaultTrainer.cloneListener(listener);
            if (l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                rl.setSessionID(((RoutingIterationListener) listener).getSessionID());
                rl.setWorkerID(workerUUID);
                StatsStorageRouter currentRouter = ((RoutingIterationListener) listener).getStorageRouter();
                if (currentRouter != null) {
                    rl.setStorageRouter(currentRouter);
                } else {
                    rl.setStorageRouter(this.parallelWrapper.getStorageRouter());
                }
            }
            if (!replicatedListeners.contains(l)) {
                replicatedListeners.add(l);
            }
        }
    }

    private static LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> $default$queue() {
        return new LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet>(1);
    }

    private static LinkedBlockingQueue<MultiDataSet> $default$queueMDS() {
        return new LinkedBlockingQueue<MultiDataSet>(1);
    }

    private static AtomicInteger $default$running() {
        return new AtomicInteger(0);
    }

    private static AtomicBoolean $default$shouldUpdate() {
        return new AtomicBoolean(false);
    }

    private static AtomicBoolean $default$shouldStop() {
        return new AtomicBoolean(false);
    }

    private static boolean $default$useMDS() {
        return false;
    }

    private static boolean $default$onRootModel() {
        return false;
    }

    private static AtomicLong $default$lastEtlTime() {
        return new AtomicLong(0L);
    }

    private static AtomicBoolean $default$nullMode() {
        return new AtomicBoolean(false);
    }

    private static AtomicBoolean $default$isStopped() {
        return new AtomicBoolean(false);
    }

    public static DefaultTrainerBuilder builder() {
        return new DefaultTrainerBuilder();
    }

    /** Creates a trainer with all fields unset; coordination fields are created by {@link #setupIfNeccessary()}. */
    public DefaultTrainer() {
    }

    @ConstructorProperties(value={"replicatedModel", "queue", "queueMDS", "running", "shouldUpdate", "shouldStop", "thrownException", "useMDS", "onRootModel", "lastEtlTime", "nullMode", "nullDataSet", "isStopped", "parallelWrapper", "workspaceMode", "averagingFrequency", "threadId", "originalModel"})
    public DefaultTrainer(Model replicatedModel, LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue, LinkedBlockingQueue<MultiDataSet> queueMDS, AtomicInteger running, AtomicBoolean shouldUpdate, AtomicBoolean shouldStop, Exception thrownException, boolean useMDS, boolean onRootModel, AtomicLong lastEtlTime, AtomicBoolean nullMode, org.nd4j.linalg.dataset.api.DataSet nullDataSet, AtomicBoolean isStopped, ParallelWrapper parallelWrapper, WorkspaceMode workspaceMode, int averagingFrequency, int threadId, Model originalModel) {
        this.replicatedModel = replicatedModel;
        this.queue = queue;
        this.queueMDS = queueMDS;
        this.running = running;
        this.shouldUpdate = shouldUpdate;
        this.shouldStop = shouldStop;
        this.thrownException = thrownException;
        this.useMDS = useMDS;
        this.onRootModel = onRootModel;
        this.lastEtlTime = lastEtlTime;
        this.nullMode = nullMode;
        this.nullDataSet = nullDataSet;
        this.isStopped = isStopped;
        this.parallelWrapper = parallelWrapper;
        this.workspaceMode = workspaceMode;
        this.averagingFrequency = averagingFrequency;
        this.threadId = threadId;
        this.originalModel = originalModel;
    }

    /** Builder for {@link DefaultTrainer}; fields not set explicitly fall back to the {@code $default$} factories. */
    public static class DefaultTrainerBuilder {
        private Model replicatedModel;
        private boolean queue$set;
        private LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue;
        private boolean queueMDS$set;
        private LinkedBlockingQueue<MultiDataSet> queueMDS;
        private boolean running$set;
        private AtomicInteger running;
        private boolean shouldUpdate$set;
        private AtomicBoolean shouldUpdate;
        private boolean shouldStop$set;
        private AtomicBoolean shouldStop;
        private Exception thrownException;
        private boolean useMDS$set;
        private boolean useMDS;
        private boolean onRootModel$set;
        private boolean onRootModel;
        private boolean lastEtlTime$set;
        private AtomicLong lastEtlTime;
        private boolean nullMode$set;
        private AtomicBoolean nullMode;
        private org.nd4j.linalg.dataset.api.DataSet nullDataSet;
        private boolean isStopped$set;
        private AtomicBoolean isStopped;
        private ParallelWrapper parallelWrapper;
        private WorkspaceMode workspaceMode;
        private int averagingFrequency;
        private int threadId;
        private Model originalModel;

        public DefaultTrainerBuilder replicatedModel(Model replicatedModel) {
            this.replicatedModel = replicatedModel;
            return this;
        }

        public DefaultTrainerBuilder queue(LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue) {
            this.queue = queue;
            this.queue$set = true;
            return this;
        }

        public DefaultTrainerBuilder queueMDS(LinkedBlockingQueue<MultiDataSet> queueMDS) {
            this.queueMDS = queueMDS;
            this.queueMDS$set = true;
            return this;
        }

        public DefaultTrainerBuilder running(AtomicInteger running) {
            this.running = running;
            this.running$set = true;
            return this;
        }

        public DefaultTrainerBuilder shouldUpdate(AtomicBoolean shouldUpdate) {
            this.shouldUpdate = shouldUpdate;
            this.shouldUpdate$set = true;
            return this;
        }

        public DefaultTrainerBuilder shouldStop(AtomicBoolean shouldStop) {
            this.shouldStop = shouldStop;
            this.shouldStop$set = true;
            return this;
        }

        public DefaultTrainerBuilder thrownException(Exception thrownException) {
            this.thrownException = thrownException;
            return this;
        }

        public DefaultTrainerBuilder useMDS(boolean useMDS) {
            this.useMDS = useMDS;
            this.useMDS$set = true;
            return this;
        }

        public DefaultTrainerBuilder onRootModel(boolean onRootModel) {
            this.onRootModel = onRootModel;
            this.onRootModel$set = true;
            return this;
        }

        public DefaultTrainerBuilder lastEtlTime(AtomicLong lastEtlTime) {
            this.lastEtlTime = lastEtlTime;
            this.lastEtlTime$set = true;
            return this;
        }

        public DefaultTrainerBuilder nullMode(AtomicBoolean nullMode) {
            this.nullMode = nullMode;
            this.nullMode$set = true;
            return this;
        }

        public DefaultTrainerBuilder nullDataSet(org.nd4j.linalg.dataset.api.DataSet nullDataSet) {
            this.nullDataSet = nullDataSet;
            return this;
        }

        public DefaultTrainerBuilder isStopped(AtomicBoolean isStopped) {
            this.isStopped = isStopped;
            this.isStopped$set = true;
            return this;
        }

        public DefaultTrainerBuilder parallelWrapper(ParallelWrapper parallelWrapper) {
            this.parallelWrapper = parallelWrapper;
            return this;
        }

        public DefaultTrainerBuilder workspaceMode(WorkspaceMode workspaceMode) {
            this.workspaceMode = workspaceMode;
            return this;
        }

        public DefaultTrainerBuilder averagingFrequency(int averagingFrequency) {
            this.averagingFrequency = averagingFrequency;
            return this;
        }

        public DefaultTrainerBuilder threadId(int threadId) {
            this.threadId = threadId;
            return this;
        }

        public DefaultTrainerBuilder originalModel(Model originalModel) {
            this.originalModel = originalModel;
            return this;
        }

        public DefaultTrainer build() {
            return new DefaultTrainer(
                    this.replicatedModel,
                    this.queue$set ? this.queue : DefaultTrainer.$default$queue(),
                    this.queueMDS$set ? this.queueMDS : DefaultTrainer.$default$queueMDS(),
                    this.running$set ? this.running : DefaultTrainer.$default$running(),
                    this.shouldUpdate$set ? this.shouldUpdate : DefaultTrainer.$default$shouldUpdate(),
                    this.shouldStop$set ? this.shouldStop : DefaultTrainer.$default$shouldStop(),
                    this.thrownException,
                    this.useMDS$set ? this.useMDS : DefaultTrainer.$default$useMDS(),
                    this.onRootModel$set ? this.onRootModel : DefaultTrainer.$default$onRootModel(),
                    this.lastEtlTime$set ? this.lastEtlTime : DefaultTrainer.$default$lastEtlTime(),
                    this.nullMode$set ? this.nullMode : DefaultTrainer.$default$nullMode(),
                    this.nullDataSet,
                    this.isStopped$set ? this.isStopped : DefaultTrainer.$default$isStopped(),
                    this.parallelWrapper,
                    this.workspaceMode,
                    this.averagingFrequency,
                    this.threadId,
                    this.originalModel);
        }

        public String toString() {
            return "DefaultTrainer.DefaultTrainerBuilder(replicatedModel=" + this.replicatedModel + ", queue=" + this.queue + ", queueMDS=" + this.queueMDS + ", running=" + this.running + ", shouldUpdate=" + this.shouldUpdate + ", shouldStop=" + this.shouldStop + ", thrownException=" + this.thrownException + ", useMDS=" + this.useMDS + ", onRootModel=" + this.onRootModel + ", lastEtlTime=" + this.lastEtlTime + ", nullMode=" + this.nullMode + ", nullDataSet=" + this.nullDataSet + ", isStopped=" + this.isStopped + ", parallelWrapper=" + this.parallelWrapper + ", workspaceMode=" + this.workspaceMode + ", averagingFrequency=" + this.averagingFrequency + ", threadId=" + this.threadId + ", originalModel=" + this.originalModel + ")";
        }
    }
}

