/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.actor.runner;

import akka.actor.ActorRef;
import akka.actor.ActorSelection;
import akka.actor.ActorSystem;
import akka.actor.Address;
import akka.actor.AddressFromURIString;
import akka.actor.PoisonPill;
import akka.actor.Props;
import akka.cluster.Cluster;
import akka.contrib.pattern.ClusterClient;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.routing.RoundRobinPool;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.util.HashSet;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.actor.core.ClusterListener;
import org.deeplearning4j.scaleout.actor.core.ModelSaver;
import org.deeplearning4j.scaleout.actor.core.actor.BatchActor;
import org.deeplearning4j.scaleout.actor.core.actor.MasterActor;
import org.deeplearning4j.scaleout.actor.core.actor.ModelSavingActor;
import org.deeplearning4j.scaleout.actor.core.actor.WorkerActor;
import org.deeplearning4j.scaleout.actor.util.ActorRefUtils;
import org.deeplearning4j.scaleout.aggregator.INDArrayAggregator;
import org.deeplearning4j.scaleout.aggregator.JobAggregator;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.api.workrouter.WorkRouter;
import org.deeplearning4j.scaleout.job.JobIterator;
import org.deeplearning4j.scaleout.messages.MoreWorkMessage;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.perform.WorkerPerformerFactory;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.scaleout.workrouter.IterativeReduceWorkRouter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.ExecutionContext;
import scala.concurrent.duration.Duration;

/**
 * Bootstraps a DeepLearning4j job on an Akka cluster.
 *
 * <p>Depending on {@code type}, {@link #setup(Configuration)} either starts a master node (state
 * tracker, batch / master / model-saving actors, REST API) or a worker node that joins the
 * master's cluster and starts a pool of {@link WorkerActor}s.
 */
public class DeepLearning4jDistributed implements DeepLearningConfigurable, Serializable {
    private static final long serialVersionUID = -4385335922485305364L;
    /** Node type value that selects the master code path; any other value starts a worker. */
    private static final String MASTER_TYPE = "master";
    /** Configuration key holding the master's Akka address. */
    private static final String MASTER_URL = "org.deeplearning4j.scaleout.masterurl";
    /** Configuration key holding the absolute actor path of the master singleton manager. */
    private static final String MASTER_PATH = "org.deeplearning4j.scaleout.masterpath";
    /** Configuration key holding the state tracker connection string. */
    private static final String STATE_TRACKER_CONNECTION_STRING = "org.deeplearning4j.scaleout.statetracker.connectionstring";
    private transient ActorSystem system;
    private ActorRef mediator;
    private static final Logger log = LoggerFactory.getLogger(DeepLearning4jDistributed.class);
    private static String systemName = "ClusterSystem";
    private String type = "master";
    private Address masterAddress;
    private JobIterator iter;
    protected ActorRef masterActor;
    protected ModelSaver modelSaver;
    private transient ScheduledExecutorService exec;
    private transient StateTracker stateTracker;
    private int stateTrackerPort = -1;
    private String masterHost;
    private transient WorkRouter workRouter;

    /**
     * @param type node type; {@code "master"} starts a master, anything else a worker
     * @param iter job iterator (required on the master)
     */
    public DeepLearning4jDistributed(String type, JobIterator iter) {
        this.type = type;
        this.iter = iter;
    }

    /**
     * Creates a master node that uses the given, already constructed state tracker.
     *
     * @param iter job iterator feeding the master
     * @param stateTracker state tracker to reuse instead of creating a Hazelcast one
     */
    public DeepLearning4jDistributed(JobIterator iter, StateTracker stateTracker) {
        this(MASTER_TYPE, iter);
        this.stateTracker = stateTracker;
    }

    /**
     * Creates a master node for the given job iterator.
     *
     * @param iter job iterator feeding the master
     */
    public DeepLearning4jDistributed(JobIterator iter) {
        this(MASTER_TYPE, iter);
    }

    /**
     * Creates a node and records a master address parsed from a URI.
     *
     * <p>The URI scheme becomes the Akka protocol, the user-info part the actor system name, and
     * host/port the remote location (e.g. {@code akka.tcp://ClusterSystem@host:2552}).
     *
     * @param type node type; {@code "master"} starts a master, anything else a worker
     * @param address master address URI
     */
    public DeepLearning4jDistributed(String type, String address) {
        this.type = type;
        URI u = URI.create(address);
        this.masterAddress = Address.apply(u.getScheme(), u.getUserInfo(), u.getHost(), u.getPort());
    }

    /** Creates a master node with no iterator; one must be supplied before {@link #setup}. */
    public DeepLearning4jDistributed() {
    }

    /**
     * Starts the master-side actors and joins (or forms) the cluster.
     *
     * <p>Instantiates the {@link WorkRouter} named by configuration (default
     * {@link IterativeReduceWorkRouter}), starts the batch actor and the master cluster singleton,
     * and writes the master url and path back into {@code c}.
     *
     * @param joinAddress address to join, or {@code null} to join this node's own address
     * @param c job configuration; updated with the master url and path
     * @param iter job iterator handed to the batch actor
     * @param stateTracker shared state tracker
     * @return the address that was joined
     * @throws RuntimeException if the work router cannot be instantiated
     */
    public Address startBackend(Address joinAddress, Configuration c, JobIterator iter, StateTracker stateTracker) {
        ActorRefUtils.addShutDownForSystem(this.system);
        this.system.actorOf(Props.create(ClusterListener.class, new Object[0]));
        try {
            Class<?> routerClazz = Class.forName(c.get("org.deeplearning4j.scaleout.api.workrouter.workrouter", IterativeReduceWorkRouter.class.getName()));
            Constructor<?> constructor = routerClazz.getConstructor(StateTracker.class);
            this.workRouter = (WorkRouter) constructor.newInstance(stateTracker);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        this.workRouter.setup(c);
        ActorRef batchActor = this.system.actorOf(Props.create(BatchActor.class, new Object[]{iter, stateTracker, c, this.workRouter}), "batch");
        log.info("Started batch actor");
        Props masterProps = Props.create(MasterActor.class, new Object[]{c, batchActor, stateTracker, this.workRouter});
        Address realJoinAddress = joinAddress == null ? Cluster.get(this.system).selfAddress() : joinAddress;
        c.set(MASTER_URL, realJoinAddress.toString());
        if (this.exec == null) {
            this.exec = Executors.newScheduledThreadPool(2);
        }
        Cluster cluster = Cluster.get(this.system);
        cluster.join(realJoinAddress);
        // Publish the cluster state once, shortly after joining.
        this.exec.schedule(new Runnable() {

            @Override
            public void run() {
                Cluster.get(DeepLearning4jDistributed.this.system).publishCurrentClusterState();
            }
        }, 10L, TimeUnit.SECONDS);
        this.masterActor = this.system.actorOf(ClusterSingletonManager.defaultProps(masterProps, "master", (Object) PoisonPill.getInstance(), "master"));
        log.info("Started master with address {}", realJoinAddress);
        c.set(MASTER_PATH, ActorRefUtils.absPath(this.masterActor, this.system));
        log.info("Set master abs path {}", c.get(MASTER_PATH));
        return realJoinAddress;
    }

    /**
     * Creates the actor system and starts this node as either a master or a worker.
     *
     * @param conf job configuration; on the master it is populated with the master url, master
     *     path and state tracker connection string and registered with ZooKeeper; on a worker it
     *     must already contain the master url and connection string
     * @throws IllegalStateException if this is a master and no job iterator was supplied
     * @throws RuntimeException wrapping failures while starting the state tracker or the node
     */
    public void setup(Configuration conf) {
        this.system = ActorSystem.create(systemName);
        ActorRefUtils.addShutDownForSystem(this.system);
        this.mediator = DistributedPubSubExtension.get(this.system).mediator();
        if (MASTER_TYPE.equals(this.type)) {
            this.setupMaster(conf);
            this.stateTracker.startRestApi();
        } else {
            this.setupWorker(conf);
            if (this.stateTracker instanceof HazelCastStateTracker) {
                log.info("Not starting drop wizard; worker state detected");
            }
        }
    }

    /** Starts the state tracker, master backend and model saver, then publishes the config. */
    private void setupMaster(Configuration conf) {
        if (this.iter == null) {
            throw new IllegalStateException("Unable to initialize no dataset to iterate");
        }
        log.info("Starting master");
        try {
            if (this.stateTracker == null) {
                this.stateTracker = this.stateTrackerPort > 0 ? new HazelCastStateTracker(this.stateTrackerPort) : new HazelCastStateTracker();
            }
            if (this.stateTracker.jobAggregator() == null) {
                Class<?> clazz = Class.forName(conf.get("org.deeplearning4j.scaleout.aggregator", INDArrayAggregator.class.getName()));
                JobAggregator agg = (JobAggregator) clazz.newInstance();
                this.stateTracker.setJobAggregator(agg);
            }
            log.info("Started state tracker with connection string {}", this.stateTracker.connectionString());
            this.masterAddress = this.startBackend(null, conf, this.iter, this.stateTracker);
        }
        catch (Exception e) {
            throw propagate(e);
        }
        log.info("Starting Save saver");
        if (this.modelSaver == null) {
            this.system.actorOf(Props.create(ModelSavingActor.class, new Object[]{"model-saver", this.stateTracker}));
        } else {
            this.system.actorOf(Props.create(ModelSavingActor.class, new Object[]{this.modelSaver, this.stateTracker}));
        }
        conf.set(MASTER_URL, this.getMasterAddress().toString());
        conf.set(MASTER_PATH, ActorRefUtils.absPath(this.masterActor, this.system));
        conf.set(STATE_TRACKER_CONNECTION_STRING, this.stateTracker.connectionString());
        ActorRefUtils.registerConfWithZooKeeper(conf, this.system);
        this.scheduleClusterMemberLogging();
    }

    /** Joins the master's cluster, connects to its state tracker and starts the worker pool. */
    private void setupWorker(Configuration conf) {
        log.info("Starting worker node");
        Address a = AddressFromURIString.parse(conf.get(MASTER_URL));
        Configuration c = new Configuration(conf);
        Cluster cluster = Cluster.get(this.system);
        cluster.join(a);
        try {
            // Check before get(): Option.get() on an empty host would throw instead.
            if (a.host().isEmpty() || a.host().get() == null) {
                throw new IllegalArgumentException("No host set for worker");
            }
            String connectionString = conf.get(STATE_TRACKER_CONNECTION_STRING);
            if (connectionString == null) {
                throw new IllegalStateException("No state tracker connection string configured under " + STATE_TRACKER_CONNECTION_STRING);
            }
            // A wildcard bind address is unreachable from a worker; substitute the master host.
            if (connectionString.contains("0.0.0.0")) {
                if (this.masterHost == null) {
                    throw new IllegalStateException("No master host specified and host discovery was lost due to improper setup on the master (related to hostname resolution) Please run the following command on your host: sudo hostname YOUR_HOST_NAME. This will make your hostname resolution work correctly on master.");
                }
                connectionString = connectionString.replace("0.0.0.0", this.masterHost);
            }
            log.info("Creating state tracker with connection string {}", connectionString);
            if (this.stateTracker == null) {
                this.stateTracker = new HazelCastStateTracker(connectionString);
            }
        }
        catch (Exception e) {
            throw propagate(e);
        }
        this.startWorker(c);
        this.scheduleClusterMemberLogging();
        log.info("Setup worker nodes");
    }

    /** Logs the current cluster members once a minute until the actor system terminates. */
    private void scheduleClusterMemberLogging() {
        this.system.scheduler().schedule(Duration.create(1L, TimeUnit.MINUTES), Duration.create(1L, TimeUnit.MINUTES), new Runnable() {

            @Override
            public void run() {
                if (DeepLearning4jDistributed.this.system.isTerminated()) {
                    return;
                }
                try {
                    log.info("Current cluster members {}", Cluster.get(DeepLearning4jDistributed.this.system).readView().members());
                }
                catch (Exception e) {
                    log.warn("Tried reading cluster members during shutdown");
                }
            }
        }, (ExecutionContext) this.system.dispatcher());
    }

    /**
     * Wraps {@code e} in a {@link RuntimeException}, restoring the interrupt flag only when the
     * failure actually was an interruption.
     */
    private static RuntimeException propagate(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new RuntimeException(e);
    }

    /**
     * Starts a round-robin pool of worker actors (one per available processor) connected to the
     * master at the configured master url.
     *
     * @param conf configuration containing the master url and the worker performer factory class
     * @throws RuntimeException if the state tracker reports at most one worker, or if the worker
     *     performer cannot be created
     */
    public void startWorker(Configuration conf) {
        Address contactAddress = AddressFromURIString.parse(conf.get(MASTER_URL));
        this.system.actorOf(Props.create(ClusterListener.class, new Object[0]));
        log.info("Attempting to join node {}", contactAddress);
        log.info("Starting workers");
        HashSet<ActorSelection> initialContacts = new HashSet<ActorSelection>();
        initialContacts.add(this.system.actorSelection(contactAddress + "/user/"));
        RoundRobinPool pool = new RoundRobinPool(Runtime.getRuntime().availableProcessors());
        ActorRef clusterClient = this.system.actorOf(ClusterClient.defaultProps(initialContacts), "clusterClient");
        try {
            String host = (String) contactAddress.host().get();
            log.info("Connecting  to host {}", host);
            int workers = this.stateTracker.numWorkers();
            if (workers <= 1) {
                throw new IllegalStateException("Did not properly connect to cluster");
            }
            log.info("Joining cluster of size {}", workers);
            Class<?> factoryClazz = Class.forName(conf.get("org.deeplearning4j.scaleout.perform.workerperformer"));
            WorkerPerformerFactory factory = (WorkerPerformerFactory) factoryClazz.newInstance();
            WorkerPerformer performer = factory.create(conf);
            Props p = pool.props(WorkerActor.propsFor(conf, this.stateTracker, performer));
            this.system.actorOf(p, "worker");
            Cluster cluster = Cluster.get(this.system);
            cluster.join(contactAddress);
            log.info("Worker joining cluster of {}", this.stateTracker.workers().size());
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Publishes the initial work message to the master, blocks until the state tracker reports
     * completion (polling every 15 seconds), then shuts this node down.
     *
     * <p>If the waiting thread is interrupted, the interrupt flag is restored and the method stops
     * waiting and shuts down instead of spinning.
     */
    public void train() {
        log.info("Publishing to results for training");
        log.info("Started pipeline");
        this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, MoreWorkMessage.getInstance()), this.mediator);
        log.info("Published results");
        while (!this.stateTracker.isDone()) {
            log.info("State tracker not done...blocking");
            try {
                Thread.sleep(15000L);
            }
            catch (InterruptedException e) {
                // With the flag restored, a retried sleep would throw immediately and spin.
                Thread.currentThread().interrupt();
                log.warn("Interrupted while waiting for training to finish; shutting down");
                break;
            }
        }
        this.shutdown();
    }

    /** @return the master address, set by the URI constructor or by a master's {@link #setup} */
    public Address getMasterAddress() {
        return this.masterAddress;
    }

    /** @return the state tracker, or {@code null} before one is supplied or created */
    public StateTracker getStateTracker() {
        return this.stateTracker;
    }

    /** @param stateTracker state tracker to use instead of creating one in {@link #setup} */
    public void setStateTracker(StateTracker stateTracker) {
        this.stateTracker = stateTracker;
    }

    /**
     * Best-effort shutdown of the helper executor, the actor system and the state tracker.
     * Failures are logged rather than propagated so later steps still run.
     */
    public void shutdown() {
        if (this.exec != null) {
            this.exec.shutdownNow();
        }
        try {
            if (this.system != null) {
                this.system.shutdown();
            }
        }
        catch (Exception e) {
            log.warn("Failed to shut down actor system", e);
        }
        try {
            if (this.stateTracker != null) {
                this.stateTracker.shutdown();
            }
        }
        catch (Exception e) {
            log.warn("Failed to shut down state tracker", e);
        }
    }

    /** @return the model saver handed to the model-saving actor, or {@code null} */
    public ModelSaver getModelSaver() {
        return this.modelSaver;
    }

    /** @param modelSaver model saver to hand to the model-saving actor on the master */
    public void setModelSaver(ModelSaver modelSaver) {
        this.modelSaver = modelSaver;
    }

    /** @return port for the master's Hazelcast state tracker; values &lt;= 0 use the default */
    public int getStateTrackerPort() {
        return this.stateTrackerPort;
    }

    /** @param stateTrackerPort port for the master's Hazelcast state tracker (&gt; 0 to apply) */
    public void setStateTrackerPort(int stateTrackerPort) {
        this.stateTrackerPort = stateTrackerPort;
    }

    /** @return host that replaces {@code 0.0.0.0} in a worker's state tracker connection string */
    public String getMasterHost() {
        return this.masterHost;
    }

    /** @param masterHost host that replaces {@code 0.0.0.0} in the connection string on workers */
    public void setMasterHost(String masterHost) {
        this.masterHost = masterHost;
    }
}

