/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.Observer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelInference {
    private static final Logger log = LoggerFactory.getLogger(ParallelInference.class);
    // Source model handed to every InferenceWorker created in init().
    private Model model;
    // Forwarded to ObservablesProvider in init(). Builder.build() never assigns it,
    // so it keeps the field default of 0 when the instance comes from the Builder.
    private long nanos;
    // Number of InferenceWorker threads started by init().
    private int workers;
    // Forwarded to ObservablesProvider: once the current batched observable's counter
    // reaches this value, the provider starts a new one.
    private int batchLimit;
    // SEQUENTIAL queues one observable per request; any other mode routes requests
    // through the ObservablesProvider (see output(INDArray...)).
    private InferenceMode inferenceMode;
    // Capacity of the bounded 'observables' queue.
    private int queueLimit;
    // Queue shared between callers (producers) and the InferenceWorker threads (consumers).
    private BlockingQueue<InferenceObservable> observables;
    // Monitor held by workers while they copy parameters from the source model into a replica.
    private final Object locker = new Object();
    // Worker threads, indexed by worker id.
    private InferenceWorker[] zoo;
    // Only created by init() when inferenceMode == BATCHED; null otherwise.
    private ObservablesProvider provider;
    // Default worker count: the number of devices reported by the ND4J affinity manager
    // at class-initialization time.
    public static final int DEFAULT_NUM_WORKERS = Nd4j.getAffinityManager().getNumberOfDevices();
    // Default batch limit.
    public static final int DEFAULT_BATCH_LIMIT = 32;
    // Default inference mode.
    public static final InferenceMode DEFAULT_INFERENCE_MODE = InferenceMode.BATCHED;
    // Default capacity of the request queue.
    public static final int DEFAULT_QUEUE_LIMIT = 64;

    /**
     * Creates an unconfigured instance. No fields are populated and no workers are
     * started here; {@link Builder#build()} assigns the configuration fields and then
     * calls {@link #init()}.
     */
    protected ParallelInference() {
    }

    /**
     * Creates the bounded request queue, starts {@code workers} daemon inference
     * threads spread round-robin over the available devices, and, in BATCHED mode,
     * creates the {@link ObservablesProvider} that groups requests.
     *
     * <p>The first worker whose device equals the calling thread's device is flagged
     * as "root"; at most one worker receives that flag.
     */
    protected void init() {
        observables = new LinkedBlockingQueue<InferenceObservable>(queueLimit);

        final int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
        final int callerDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();

        zoo = new InferenceWorker[workers];
        boolean rootAssigned = false;
        for (int idx = 0; idx < workers; idx++) {
            final int device = idx % deviceCount;

            // Only the first worker that lands on the caller's device becomes root.
            final boolean root = !rootAssigned && device == callerDevice;
            if (root) {
                rootAssigned = true;
            }

            InferenceWorker worker = new InferenceWorker(idx, model, observables, root);
            zoo[idx] = worker;
            Nd4j.getAffinityManager().attachThreadToDevice(worker, Integer.valueOf(device));
            worker.setDaemon(true);
            worker.start();
        }

        if (inferenceMode == InferenceMode.BATCHED) {
            log.info("Initializing ObservablesProvider...");
            provider = new ObservablesProvider(nanos, batchLimit, observables);
        }
    }

    /**
     * Returns the request counter of the worker at the given index.
     *
     * @param workerIdx index into the worker array
     * @return the value reported by that worker's {@code getCounterValue()}
     */
    protected long getWorkerCounter(int workerIdx) {
        InferenceWorker worker = zoo[workerIdx];
        return worker.getCounterValue();
    }

    /**
     * Wraps the given values in an INDArray via {@code Nd4j.create} and runs inference on it.
     *
     * @param input raw feature values
     * @return the output produced for the wrapped array
     */
    public INDArray output(double[] input) {
        INDArray features = Nd4j.create(input);
        return output(features);
    }

    /**
     * Wraps the given values in an INDArray via {@code Nd4j.create} and runs inference on it.
     *
     * @param input raw feature values
     * @return the output produced for the wrapped array
     */
    public INDArray output(float[] input) {
        INDArray features = Nd4j.create(input);
        return output(features);
    }

    /**
     * Runs inference on a single input array.
     *
     * @param input the input array
     * @return the first element of the array returned by {@link #output(INDArray...)}
     */
    public INDArray output(INDArray input) {
        // Pass an explicit array so the varargs overload is selected rather than this method.
        INDArray[] results = output(new INDArray[]{input});
        return results[0];
    }

    /**
     * Runs inference on the feature matrix of the given data set.
     *
     * @param dataSet data set whose {@code getFeatureMatrix()} result is used as input
     * @return the output produced for that feature matrix
     */
    public INDArray output(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        return output(features);
    }

    /**
     * Submits the inputs for inference and blocks until the observer reports completion.
     *
     * <p>In SEQUENTIAL mode the request is wrapped in its own observable and put on the
     * shared queue (blocking while the queue is full). In any other mode the request is
     * handed to the {@link ObservablesProvider}.
     *
     * @param input input arrays for one request
     * @return the output stored on the observable once the observer is done
     * @throws RuntimeException wrapping any exception raised while queueing or waiting;
     *         if the cause is an {@link InterruptedException}, the calling thread's
     *         interrupt status is restored before the exception is thrown
     */
    public INDArray[] output(INDArray ... input) {
        InferenceObservable observable;
        BasicInferenceObserver observer = new BasicInferenceObserver();
        if (this.inferenceMode == InferenceMode.SEQUENTIAL) {
            observable = new BasicInferenceObservable(input);
            observable.addObserver(observer);
            try {
                this.observables.put(observable);
            }
            catch (InterruptedException e) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        } else {
            observable = this.provider.setInput(observer, input);
        }
        try {
            observer.waitTillDone();
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Wrapping would otherwise hide the interrupt from the caller.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
        return observable.getOutput();
    }

    /**
     * Groups individual inference requests into shared {@link BatchedInferenceObservable}s
     * and publishes each newly created observable on the target queue.
     *
     * <p>All grouping decisions are made while holding a private monitor, so concurrent
     * callers of {@link #setInput} are serialized.
     */
    protected static class ObservablesProvider {
        private final BlockingQueue<InferenceObservable> targetQueue;
        // Stored from the constructor; not read anywhere in this class.
        private final long nanos;
        private final int batchLimit;
        private volatile BatchedInferenceObservable currentObservable;
        private final Object locker = new Object();

        /**
         * @param nanos      stored but currently unused by this class
         * @param batchLimit counter value at which a new batched observable is started
         * @param queue      queue on which newly created observables are published
         * @throws NullPointerException if {@code queue} is null
         */
        protected ObservablesProvider(long nanos, int batchLimit, @NonNull BlockingQueue<InferenceObservable> queue) {
            if (queue == null) {
                throw new NullPointerException("queue");
            }
            this.targetQueue = queue;
            this.nanos = nanos;
            this.batchLimit = batchLimit;
        }

        /**
         * Adds the input and observer to the current batched observable, first starting a
         * new one if there is none, the current one has reached {@code batchLimit}, or it
         * reports itself as locked. A newly started observable is put on the target queue
         * (blocking while the queue is full) after the input has been attached.
         *
         * @param observer observer to register on the observable that receives the input
         * @param input    input arrays for one request
         * @return the observable the input was attached to
         * @throws NullPointerException if {@code observer} is null
         * @throws RuntimeException wrapping an {@link InterruptedException} if the thread is
         *         interrupted while queueing; the interrupt status is restored first
         */
        protected InferenceObservable setInput(@NonNull Observer observer, INDArray ... input) {
            if (observer == null) {
                throw new NullPointerException("observer");
            }
            synchronized (this.locker) {
                boolean isNew = false;
                if (this.currentObservable == null || this.currentObservable.getCounter() >= this.batchLimit || this.currentObservable.isLocked()) {
                    isNew = true;
                    this.currentObservable = new BatchedInferenceObservable();
                }
                // NOTE(review): nothing visible here prevents a consumer from locking the
                // observable between the isLocked() check above and the calls below -
                // confirm BatchedInferenceObservable handles that case.
                this.currentObservable.setInput(input);
                this.currentObservable.addObserver(observer);
                if (isNew) {
                    try {
                        this.targetQueue.put(this.currentObservable);
                    }
                    catch (InterruptedException e) {
                        // Preserve the interrupt for callers further up the stack.
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                }
                return this.currentObservable;
            }
        }
    }

    /**
     * Daemon thread that takes requests from the shared queue and runs them through its
     * own model instance.
     *
     * <p>The "root" worker uses the source model directly; every other worker builds a
     * replica from the source model's JSON configuration and copies the parameters into
     * it while holding the enclosing instance's {@code locker}.
     */
    private class InferenceWorker
    extends Thread {
        private final BlockingQueue<InferenceObservable> inputQueue;
        private final AtomicBoolean shouldWork = new AtomicBoolean(true);
        private final Model protoModel;
        private Model replicatedModel;
        private final AtomicLong counter = new AtomicLong(0L);
        private final boolean rootDevice;

        /**
         * @param id         worker index, used only for the thread name
         * @param model      source model to use directly (root) or replicate (non-root)
         * @param inputQueue queue to take requests from
         * @param rootDevice whether this worker uses the source model directly
         * @throws NullPointerException if {@code model} or {@code inputQueue} is null
         */
        private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue<InferenceObservable> inputQueue, boolean rootDevice) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            if (inputQueue == null) {
                throw new NullPointerException("inputQueue");
            }
            this.inputQueue = inputQueue;
            this.protoModel = model;
            this.rootDevice = rootDevice;
            this.setDaemon(true);
            this.setName("InferenceThread-" + id);
        }

        /** Returns the number of requests this worker has taken from the queue. */
        protected long getCounterValue() {
            return this.counter.get();
        }

        /**
         * Prepares the model, then serves requests until {@link #shutdown()} is called or
         * the thread is interrupted. Any other exception terminates the worker and is
         * rethrown wrapped in a {@link RuntimeException}.
         */
        @Override
        public void run() {
            try {
                this.replicatedModel = this.prepareModel();
                while (this.shouldWork.get()) {
                    // take() blocks until an element is available and never yields null.
                    InferenceObservable request = this.inputQueue.take();
                    this.counter.incrementAndGet();
                    if (this.replicatedModel instanceof ComputationGraph) {
                        INDArray[] outputs = ((ComputationGraph)this.replicatedModel).output(false, request.getInput());
                        request.setOutput(outputs);
                    } else if (this.replicatedModel instanceof MultiLayerNetwork) {
                        INDArray output = ((MultiLayerNetwork)this.replicatedModel).output(request.getInput()[0]);
                        request.setOutput(output);
                    }
                    // NOTE(review): for any other model type the request is consumed without
                    // an output being set, and an exception above ends the worker without
                    // touching the request - callers waiting on it may block. Confirm the
                    // intended handling.
                }
            }
            catch (InterruptedException e) {
                // Interrupt is the stop signal (see shutdown()); keep the flag and exit.
                Thread.currentThread().interrupt();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Chooses the model this worker runs: the source model itself for the root
         * worker, a freshly initialized replica with copied parameters otherwise, or
         * {@code null} when the source model is neither a ComputationGraph nor a
         * MultiLayerNetwork.
         */
        private Model prepareModel() {
            boolean isGraph = this.protoModel instanceof ComputationGraph;
            boolean isMultiLayer = this.protoModel instanceof MultiLayerNetwork;
            if (!isGraph && !isMultiLayer) {
                return null;
            }
            if (this.rootDevice) {
                return this.protoModel;
            }
            Model replica;
            if (isGraph) {
                replica = new ComputationGraph(ComputationGraphConfiguration.fromJson(((ComputationGraph)this.protoModel).getConfiguration().toJson()));
            } else {
                replica = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(((MultiLayerNetwork)this.protoModel).getLayerWiseConfigurations().toJson()));
            }
            replica.init();
            // Parameter copy is serialized across workers via the enclosing instance's lock.
            synchronized (ParallelInference.this.locker) {
                replica.setParams(this.protoModel.params().unsafeDuplication(true));
                Nd4j.getExecutioner().commit();
            }
            return replica;
        }

        /**
         * Asks the worker to stop and waits for the thread to terminate. The worker is
         * interrupted so that it wakes up even when blocked waiting on an empty queue.
         * If the calling thread is itself interrupted while waiting, its interrupt status
         * is restored and this method returns early.
         */
        protected void shutdown() {
            this.shouldWork.set(false);
            this.interrupt();
            if (Thread.currentThread() == this) {
                // A thread cannot join itself; the run loop will exit on its own.
                return;
            }
            try {
                this.join();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Fluent builder for {@link ParallelInference}. Unset options fall back to the
     * {@code DEFAULT_*} constants of the enclosing class.
     */
    public static class Builder {
        private Model model;
        private int workers = DEFAULT_NUM_WORKERS;
        private int batchLimit = DEFAULT_BATCH_LIMIT;
        private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE;
        private int queueLimit = DEFAULT_QUEUE_LIMIT;

        /**
         * @param model model to run inference with
         * @throws NullPointerException if {@code model} is null
         */
        public Builder(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            this.model = model;
        }

        /**
         * Sets the inference mode.
         *
         * @throws NullPointerException if {@code inferenceMode} is null
         */
        public Builder inferenceMode(@NonNull InferenceMode inferenceMode) {
            if (inferenceMode == null) {
                throw new NullPointerException("inferenceMode");
            }
            this.inferenceMode = inferenceMode;
            return this;
        }

        /**
         * Sets the number of worker threads.
         *
         * @throws IllegalStateException if {@code workers} is less than 1
         */
        public Builder workers(int workers) {
            this.workers = requirePositive(workers, "Workers should be positive value");
            return this;
        }

        /**
         * Sets the batch limit.
         *
         * @throws IllegalStateException if {@code limit} is less than 1
         */
        public Builder batchLimit(int limit) {
            this.batchLimit = requirePositive(limit, "Batch limit should be positive value");
            return this;
        }

        /**
         * Sets the capacity of the request queue.
         *
         * @throws IllegalStateException if {@code limit} is less than 1
         */
        public Builder queueLimit(int limit) {
            this.queueLimit = requirePositive(limit, "Queue limit should be positive value");
            return this;
        }

        /**
         * Creates a {@link ParallelInference}, copies this builder's settings into it and
         * calls its {@code init()} before returning it.
         */
        public ParallelInference build() {
            ParallelInference result = new ParallelInference();
            result.model = this.model;
            result.workers = this.workers;
            result.inferenceMode = this.inferenceMode;
            result.batchLimit = this.batchLimit;
            result.queueLimit = this.queueLimit;
            result.init();
            return result;
        }

        /** Returns {@code value} if it is at least 1, otherwise throws with {@code message}. */
        private static int requirePositive(int value, String message) {
            if (value < 1) {
                throw new IllegalStateException(message);
            }
            return value;
        }
    }
}

