/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism.inference.observers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link InferenceObservable} that collects inputs submitted by several callers into one batch,
 * exposes them to the model as a single batched input, and hands each caller back its own part of
 * the output.
 *
 * <p>Locking protocol, as implemented below:
 * <ul>
 *   <li>{@link #isLocked()} takes a read hold on this batch whenever it reports the batch as still
 *       open.</li>
 *   <li>The following {@link #setInput(INDArray...)} on the same thread releases that hold.</li>
 *   <li>{@link #getInput()} takes the write lock, so it cannot run while such an add is pending.
 *       It also marks the batch as locked for good.</li>
 * </ul>
 */
public class BatchedInferenceObservable
extends BasicInferenceObservable
implements InferenceObservable {
    private static final Logger log = LoggerFactory.getLogger(BatchedInferenceObservable.class);
    /** Queries' inputs in submission order; normally index {@code i} matches position {@code i}. */
    private final List<INDArray[]> inputs = new ArrayList<INDArray[]>();
    /** Per-query outputs, populated by {@link #setOutput(INDArray...)}. */
    private final List<INDArray[]> outputs = new ArrayList<INDArray[]>();
    /** Number of queries added so far; its pre-increment value becomes each query's position. */
    private final AtomicInteger counter = new AtomicInteger(0);
    /** Index of the calling thread's query within this batch. */
    private final ThreadLocal<Integer> position = new ThreadLocal<Integer>();
    /** Serializes concurrent {@link #setInput(INDArray...)} calls. */
    private final Object locker = new Object();
    private final ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock();
    /** Set once {@link #getInput()} has consumed the batch; never reset. */
    private final AtomicBoolean isLocked = new AtomicBoolean(false);
    /** Whether a read hold taken by {@link #isLocked()} is believed to be outstanding. */
    private final AtomicBoolean isReadLocked = new AtomicBoolean(false);

    /** Creates an empty batch; the superclass is given an empty input array. */
    public BatchedInferenceObservable() {
        super(new INDArray[0]);
    }

    /**
     * Adds one query to this batch and stores its index as the calling thread's position.
     *
     * <p>If the calling thread holds a read lock on this batch (taken by {@link #isLocked()}), that
     * hold is released here so that {@link #getInput()} can proceed.
     *
     * @param input the query's input arrays
     */
    @Override
    public void setInput(INDArray ... input) {
        synchronized (locker) {
            inputs.add(input);
            position.set(counter.getAndIncrement());
            // Only release a hold this thread actually owns. The previous check relied on a flag
            // that was never reset, so a stale flag could make us unlock an unheld read lock,
            // which throws IllegalMonitorStateException.
            if (realLocker.getReadHoldCount() > 0) {
                realLocker.readLock().unlock();
                isReadLocked.set(realLocker.getReadLockCount() > 0);
            }
        }
    }

    /**
     * Returns the model input for this batch and marks the batch as consumed, so that
     * {@link #isLocked()} reports {@code true} from now on.
     *
     * <p>With more than one query, the i-th input arrays of all queries are combined with
     * {@code Nd4j.pile}. With a single query, that query's arrays are returned unchanged.
     *
     * @return the input arrays to feed to the model
     */
    @Override
    public INDArray[] getInput() {
        realLocker.writeLock().lock();
        try {
            isLocked.set(true);
            if (counter.get() <= 1) {
                return inputs.get(0);
            }
            INDArray[] batched = new INDArray[inputs.get(0).length];
            for (int i = 0; i < batched.length; ++i) {
                List<INDArray> examples = new ArrayList<INDArray>(inputs.size());
                for (INDArray[] query : inputs) {
                    examples.add(query[i]);
                }
                batched[i] = Nd4j.pile(examples);
            }
            return batched;
        } finally {
            // Always release, even if pile() throws; otherwise the batch deadlocks forever.
            realLocker.writeLock().unlock();
        }
    }

    /**
     * Stores the model output and notifies observers.
     *
     * <p>With more than one query, each output array is split with {@code Nd4j.tear} over every
     * dimension except 0. The split must yield exactly one piece per query (presumably slices along
     * dimension 0 — TODO confirm). The pieces are then regrouped so each query gets one array per
     * output. Splitting and validation happen before {@code outputs} is modified, so a mismatch
     * leaves this object's outputs untouched.
     *
     * @param output the model's output arrays
     * @throws ND4JIllegalStateException if a split does not produce one piece per query
     */
    @Override
    public void setOutput(INDArray ... output) {
        int queries = counter.get();
        if (queries > 1) {
            INDArray[][] perQuery = new INDArray[queries][output.length];
            for (int o = 0; o < output.length; ++o) {
                INDArray array = output[o];
                // Tear over dimensions 1..rank-1, i.e. everything except dimension 0.
                int[] dimensions = new int[array.rank() - 1];
                for (int d = 1; d < array.rank(); ++d) {
                    dimensions[d - 1] = d;
                }
                INDArray[] split = Nd4j.tear(array, dimensions);
                if (split.length != queries) {
                    throw new ND4JIllegalStateException("Number of splits [" + split.length + "] doesn't match number of queries [" + queries + "]");
                }
                for (int q = 0; q < queries; ++q) {
                    perQuery[q][o] = split[q];
                }
            }
            for (INDArray[] queryOutput : perQuery) {
                outputs.add(queryOutput);
            }
        } else {
            outputs.add(output);
        }
        this.setChanged();
        this.notifyObservers();
    }

    /** @return the live (mutable) list of per-query outputs */
    protected List<INDArray[]> getOutputs() {
        return this.outputs;
    }

    /** Overrides the number of queries considered part of this batch. */
    protected void setCounter(int value) {
        this.counter.set(value);
    }

    /** Sets the calling thread's query index within this batch. */
    public void setPosition(int pos) {
        this.position.set(pos);
    }

    /** @return the number of queries added to this batch */
    public int getCounter() {
        return this.counter.get();
    }

    /**
     * Reports whether this batch can no longer accept queries.
     *
     * <p>Returns {@code true} in two cases. Either the read lock could not be acquired (another
     * thread holds the write lock, i.e. {@link #getInput()} is running), or the input has already
     * been consumed. When it returns {@code false}, the calling thread keeps a read hold on this
     * batch. A following {@link #setInput(INDArray...)} on the same thread must release it.
     *
     * @return {@code true} if the batch is closed to new queries
     */
    public boolean isLocked() {
        if (!realLocker.readLock().tryLock()) {
            return true;
        }
        if (isLocked.get()) {
            // Batch already consumed: don't keep the read hold we just acquired (it used to leak).
            realLocker.readLock().unlock();
            return true;
        }
        isReadLocked.set(true);
        return false;
    }

    /**
     * Returns the output belonging to the calling thread's query.
     *
     * @return the calling thread's part of the output
     * @throws ND4JIllegalStateException if no position was set for the calling thread
     */
    @Override
    public INDArray[] getOutput() {
        Integer pos = position.get();
        if (pos == null) {
            throw new ND4JIllegalStateException("No query position is set for the current thread; call setInput() or setPosition() first");
        }
        return this.outputs.get(pos);
    }
}

