/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.messages.aggregations;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.agrona.concurrent.UnsafeBuffer;
import org.apache.commons.lang3.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for aggregations assembled from per-shard chunks.
 *
 * <p>Each instance holds its own {@link #payload} for {@link #shardIndex} and collects the
 * payloads of other shards through {@link #accumulateAggregation(VoidAggregation)}.
 * {@link #getMissingChunks()} reports how many of the {@link #aggregationWidth} chunks are still
 * outstanding, and {@link #getAccumulatedResult()} combines the collected chunks.
 *
 * <p>The chunk bookkeeping ({@link #chunks}, {@link #chunksCounter}) is {@code transient}. It is
 * rebuilt by a private {@code readObject} hook after Java deserialization.
 */
public abstract class BaseAggregation
extends BaseVoidMessage
implements VoidAggregation,
Serializable {
    private static final Logger log = LoggerFactory.getLogger(BaseAggregation.class);
    /** Aggregation kind; {@link #accumulateAggregation} only merges aggregations of equal type. */
    protected short aggregationType = (short)-1;
    /** Total number of chunks expected before the aggregation is complete. */
    protected short aggregationWidth;
    protected int numberOfElements;
    /** Shard that produced {@link #payload}; also the key of the own chunk in {@link #chunks}. */
    protected short shardIndex;
    /** This instance's own contribution. */
    protected INDArray payload;
    /** Number of distinct chunks collected so far; starts at 1 to count the own payload. */
    protected transient AtomicInteger chunksCounter = new AtomicInteger(1);
    /** Collected chunks keyed by shard index. */
    protected transient Map<Short, INDArray> chunks = new ConcurrentHashMap<Short, INDArray>();

    protected BaseAggregation() {
    }

    /**
     * Creates an aggregation for the given task and shard.
     *
     * @param taskId           identifier of the task this aggregation belongs to
     * @param aggregationWidth total number of chunks expected
     * @param shardIndex       index of the shard producing this instance's payload
     */
    protected BaseAggregation(long taskId, short aggregationWidth, short shardIndex) {
        this();
        this.aggregationWidth = aggregationWidth;
        this.taskId = taskId;
        this.shardIndex = shardIndex;
    }

    /**
     * Re-keys this instance's own chunk under a new shard index.
     *
     * <p>The chunk stored under the previous index is removed. If a payload is present, it is
     * stored under the new index.
     *
     * @param shardIndex the new shard index; a no-op if it equals the current one
     */
    @Override
    public void setShardIndex(short shardIndex) {
        if (shardIndex == this.shardIndex) {
            return;
        }
        this.chunks.remove(this.shardIndex);
        // ConcurrentHashMap rejects null values. Without this guard a payload-less instance would
        // lose its old chunk and then throw before the index is updated.
        if (this.payload != null) {
            this.chunks.put(shardIndex, this.payload);
        }
        this.shardIndex = shardIndex;
    }

    /**
     * Stores {@code array} as the chunk for this instance's current shard index.
     *
     * @param array the chunk to store (must be non-null for the default ConcurrentHashMap)
     */
    protected void addToChunks(INDArray array) {
        this.chunks.put(this.shardIndex, array);
    }

    /**
     * Merges another shard's payload into this aggregation.
     *
     * <p>Aggregations from this instance's own shard are ignored. A chunk from a shard seen
     * before replaces the stored chunk without incrementing the counter.
     *
     * @param aggregation the aggregation to merge
     * @throws NullPointerException       if {@code aggregation}, or the payload of an aggregation
     *                                    from a different shard, is {@code null}
     * @throws ND4JIllegalStateException if the aggregation types differ
     */
    @Override
    public void accumulateAggregation(@NonNull VoidAggregation aggregation) {
        if (aggregation == null) {
            throw new NullPointerException("aggregation");
        }
        if (aggregation.getAggregationType() != this.getAggregationType()) {
            throw new ND4JIllegalStateException("Trying to aggregate different aggregations!");
        }
        short otherShard = aggregation.getShardIndex();
        if (this.getShardIndex() == otherShard) {
            return;
        }
        INDArray otherPayload = aggregation.getPayload();
        // Fail before touching any state: otherwise the counter would be bumped and the
        // subsequent map insertion would throw, leaving the counter out of sync with the map.
        if (otherPayload == null) {
            throw new NullPointerException("aggregation payload");
        }
        // put() returns the previous mapping atomically, so concurrent accumulations of the same
        // shard cannot both observe "absent" and double-count.
        if (this.chunks.put(otherShard, otherPayload) == null) {
            this.chunksCounter.incrementAndGet();
        }
    }

    /**
     * Combines the collected chunks.
     *
     * @return for a width of 1, the chunk stored under shard index 0 (may be {@code null});
     *         otherwise all chunks horizontally stacked in ascending shard-index order
     */
    @Override
    public INDArray getAccumulatedResult() {
        if (this.aggregationWidth == 1) {
            return this.chunks.get((short)0);
        }
        // ConcurrentHashMap iteration order is unspecified. Copying into a TreeMap makes the
        // stacking order follow the shard index.
        return Nd4j.hstack(new TreeMap<Short, INDArray>(this.chunks).values());
    }

    /**
     * Returns how many chunks are still outstanding.
     *
     * @return {@code aggregationWidth} minus the number of chunks collected so far
     */
    @Override
    public int getMissingChunks() {
        return this.aggregationWidth - this.chunksCounter.get();
    }

    /**
     * Returns the message type code for aggregations.
     *
     * @return always 21 (NOTE(review): presumably the registry id for aggregation messages; confirm)
     */
    @Override
    public int getMessageType() {
        return 21;
    }

    /**
     * Serializes this aggregation using Java serialization; transient chunk state is excluded.
     *
     * @return the serialized bytes
     */
    @Override
    public byte[] asBytes() {
        return SerializationUtils.serialize((Serializable)this);
    }

    /**
     * Wraps {@link #asBytes()} in an Agrona buffer.
     *
     * @return a buffer over the serialized bytes
     */
    @Override
    public UnsafeBuffer asUnsafeBuffer() {
        return new UnsafeBuffer(this.asBytes());
    }

    /**
     * Returns the target id of this message.
     *
     * @return always -1 (NOTE(review): presumably "no specific target"; confirm with the transport)
     */
    @Override
    public short getTargetId() {
        return -1;
    }

    @Override
    public short getAggregationType() {
        return this.aggregationType;
    }

    public void setAggregationType(short aggregationType) {
        this.aggregationType = aggregationType;
    }

    public short getAggregationWidth() {
        return this.aggregationWidth;
    }

    public void setAggregationWidth(short aggregationWidth) {
        this.aggregationWidth = aggregationWidth;
    }

    public int getNumberOfElements() {
        return this.numberOfElements;
    }

    public void setNumberOfElements(int numberOfElements) {
        this.numberOfElements = numberOfElements;
    }

    @Override
    public short getShardIndex() {
        return this.shardIndex;
    }

    @Override
    public INDArray getPayload() {
        return this.payload;
    }

    public void setPayload(INDArray payload) {
        this.payload = payload;
    }

    /** Returns the live chunk counter (not a copy). */
    public AtomicInteger getChunksCounter() {
        return this.chunksCounter;
    }

    /** Returns the live chunk map (not a copy). */
    public Map<Short, INDArray> getChunks() {
        return this.chunks;
    }

    /**
     * Restores the transient chunk bookkeeping after Java deserialization.
     *
     * <p>Field initializers do not run for a deserialized {@link Serializable} class. Without this
     * hook, {@link #chunks} and {@link #chunksCounter} would be {@code null}, and
     * {@link #getMissingChunks()} or {@link #accumulateAggregation} would throw.
     *
     * <p>The counter is reset to 1, and the own payload is re-registered under
     * {@link #shardIndex}. This is the same mapping {@code setShardIndex} maintains. Being private,
     * this method does not affect the class's default serialVersionUID.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.chunksCounter = new AtomicInteger(1);
        this.chunks = new ConcurrentHashMap<Short, INDArray>();
        if (this.payload != null) {
            this.chunks.put(this.shardIndex, this.payload);
        }
    }
}

