/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.messages.intercom;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage;
import org.nd4j.parameterserver.distributed.messages.DistributedMessage;
import org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Message that asks a parameter-server shard to allocate and initialize its word-vector arrays.
 *
 * <p>When processed, the shard creates a random {@code syn0} block, optionally {@code syn1} and
 * {@code syn1Neg} blocks (created via {@code Nd4j.create}), and a sigmoid lookup table. It stores
 * them under the corresponding {@link WordVectorStorage} keys and then sends an
 * {@link InitializationAggregation} through the transport. Processing is skipped entirely when
 * {@code SYN_0} is already present in storage.
 */
public class DistributedInitializationMessage
extends BaseVoidMessage
implements DistributedMessage {
    private static final Logger log = LoggerFactory.getLogger(DistributedInitializationMessage.class);

    /** Number of entries in the precomputed sigmoid lookup table. */
    private static final int EXP_TABLE_SIZE = 100000;

    /** The lookup table samples the sigmoid over inputs in {@code [-MAX_EXP, MAX_EXP)}. */
    private static final double MAX_EXP = 6.0;

    /** Full embedding vector length; also used to scale the random {@code syn0} values. */
    protected int vectorLength;
    /** Number of rows allocated in each weight array. */
    protected int numWords;
    /** Base random seed; each shard actually uses {@code seed * (shardIndex + 1)}. */
    protected long seed;
    /** Whether the {@code SYN_1} array is allocated. */
    protected boolean useHs;
    /** Whether the {@code SYN_1_NEGATIVE} array is allocated. */
    protected boolean useNeg;
    /** Columns held by this shard; may be adjusted in place by {@link #processMessage()}. */
    protected int columnsPerShard;

    public DistributedInitializationMessage(int vectorLength, int numWords, long seed, boolean useHs, boolean useNeg, int columnsPerShard) {
        // NOTE(review): 4 is presumably this message's type code in BaseVoidMessage — confirm.
        super(4);
        this.vectorLength = vectorLength;
        this.numWords = numWords;
        this.seed = seed;
        this.useHs = useHs;
        this.useNeg = useNeg;
        this.columnsPerShard = columnsPerShard;
    }

    /**
     * Initializes this shard's word-vector arrays, unless {@code SYN_0} already exists in storage.
     *
     * <p>The random {@code syn0} block is drawn with the per-shard seed
     * {@code seed * (shardIndex + 1)}. That seed must be passed to {@code Nd4j.rand(int[], long)}
     * explicitly, because that overload treats its second argument as the seed. A constant there
     * would discard both {@link #seed} and the per-shard variation.
     *
     * <p>{@link #columnsPerShard} is overwritten with the resolved width for this shard.
     */
    @Override
    public void processMessage() {
        // Protection against double initialization: an existing SYN_0 means this shard is set up.
        if (this.storage.getArray(WordVectorStorage.SYN_0) != null) {
            return;
        }
        log.info("sI_{} is starting initialization...", this.transport.getShardIndex());

        // Distinct seed per shard; the global RNG is seeded with it as well.
        long shardSeed = this.seed * (long) (this.shardIndex + 1);
        Nd4j.getRandom().setSeed(shardSeed);

        this.columnsPerShard = this.resolveColumnsPerShard();
        int[] shardShape = new int[] {this.numWords, this.columnsPerShard};

        INDArray syn0 = Nd4j.rand(shardShape, shardSeed).subi(0.5).divi(this.vectorLength);
        INDArray syn1 = this.useHs ? Nd4j.create(shardShape, 'c') : null;
        INDArray syn1Neg = this.useNeg ? Nd4j.create(shardShape, 'c') : null;
        INDArray expTable = this.initExpTable(EXP_TABLE_SIZE);

        this.storage.setArray(WordVectorStorage.SYN_0, syn0);
        if (this.useHs) {
            this.storage.setArray(WordVectorStorage.SYN_1, syn1);
        }
        if (this.useNeg) {
            this.storage.setArray(WordVectorStorage.SYN_1_NEGATIVE, syn1Neg);
        }
        this.storage.setArray(WordVectorStorage.EXP_TABLE, expTable);

        InitializationAggregation ia = new InitializationAggregation(
                (short) this.voidConfiguration.getNumberOfShards(), this.transport.getShardIndex());
        ia.setOriginatorId(this.originatorId);
        this.transport.sendMessage(ia);
    }

    /**
     * Computes how many columns this shard should hold for the configured execution mode.
     *
     * <p>AVERAGING: every shard holds the full {@link #vectorLength}. SHARDED: the last shard
     * additionally receives {@code vectorLength % numberOfShards} columns. Otherwise the current
     * {@link #columnsPerShard} value is kept unchanged.
     *
     * @return the column count to allocate for this shard
     */
    private int resolveColumnsPerShard() {
        ExecutionMode mode = this.voidConfiguration.getExecutionMode();
        if (mode == ExecutionMode.AVERAGING) {
            return this.vectorLength;
        }
        if (mode == ExecutionMode.SHARDED) {
            int numberOfShards = this.voidConfiguration.getNumberOfShards();
            // Check "last shard" first so the modulo is only evaluated when it is needed.
            if (numberOfShards - 1 == this.shardIndex) {
                int modulo = this.vectorLength % numberOfShards;
                if (modulo != 0) {
                    int widened = this.columnsPerShard + modulo;
                    log.info("Got inequal split. using higher number of elements: {}", widened);
                    return widened;
                }
            }
        }
        return this.columnsPerShard;
    }

    /**
     * Builds a sigmoid lookup table.
     *
     * <p>Entry {@code i} holds {@code sigmoid(x)} with
     * {@code x = (2 * i / tableWidth - 1) * MAX_EXP}, so the table covers {@code [-MAX_EXP, MAX_EXP)}.
     *
     * @param tableWidth number of entries in the table
     * @return a 1-D array of {@code tableWidth} sigmoid values
     */
    protected INDArray initExpTable(int tableWidth) {
        double[] table = new double[tableWidth];
        for (int i = 0; i < table.length; ++i) {
            double tmp = FastMath.exp(((double) i / (double) table.length * 2.0 - 1.0) * MAX_EXP);
            table[i] = tmp / (tmp + 1.0);
        }
        return Nd4j.create(table);
    }

    public static DistributedInitializationMessageBuilder builder() {
        return new DistributedInitializationMessageBuilder();
    }

    public DistributedInitializationMessage() {
    }

    public int getVectorLength() {
        return this.vectorLength;
    }

    public int getNumWords() {
        return this.numWords;
    }

    public long getSeed() {
        return this.seed;
    }

    public boolean isUseHs() {
        return this.useHs;
    }

    public boolean isUseNeg() {
        return this.useNeg;
    }

    public int getColumnsPerShard() {
        return this.columnsPerShard;
    }

    public void setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
    }

    public void setNumWords(int numWords) {
        this.numWords = numWords;
    }

    public void setSeed(long seed) {
        this.seed = seed;
    }

    public void setUseHs(boolean useHs) {
        this.useHs = useHs;
    }

    public void setUseNeg(boolean useNeg) {
        this.useNeg = useNeg;
    }

    public void setColumnsPerShard(int columnsPerShard) {
        this.columnsPerShard = columnsPerShard;
    }

    /** Compares only the six initialization parameters declared by this class. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof DistributedInitializationMessage)) {
            return false;
        }
        DistributedInitializationMessage other = (DistributedInitializationMessage) o;
        return other.canEqual(this)
                && this.getVectorLength() == other.getVectorLength()
                && this.getNumWords() == other.getNumWords()
                && this.getSeed() == other.getSeed()
                && this.isUseHs() == other.isUseHs()
                && this.isUseNeg() == other.isUseNeg()
                && this.getColumnsPerShard() == other.getColumnsPerShard();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof DistributedInitializationMessage;
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + this.getVectorLength();
        result = result * PRIME + this.getNumWords();
        long seedValue = this.getSeed();
        result = result * PRIME + (int) (seedValue >>> 32 ^ seedValue);
        result = result * PRIME + (this.isUseHs() ? 79 : 97);
        result = result * PRIME + (this.isUseNeg() ? 79 : 97);
        result = result * PRIME + this.getColumnsPerShard();
        return result;
    }

    @Override
    public String toString() {
        return "DistributedInitializationMessage(vectorLength=" + this.getVectorLength() + ", numWords=" + this.getNumWords() + ", seed=" + this.getSeed() + ", useHs=" + this.isUseHs() + ", useNeg=" + this.isUseNeg() + ", columnsPerShard=" + this.getColumnsPerShard() + ")";
    }

    /** Fluent builder mirroring the six-argument constructor. */
    public static class DistributedInitializationMessageBuilder {
        private int vectorLength;
        private int numWords;
        private long seed;
        private boolean useHs;
        private boolean useNeg;
        private int columnsPerShard;

        DistributedInitializationMessageBuilder() {
        }

        public DistributedInitializationMessageBuilder vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        public DistributedInitializationMessageBuilder numWords(int numWords) {
            this.numWords = numWords;
            return this;
        }

        public DistributedInitializationMessageBuilder seed(long seed) {
            this.seed = seed;
            return this;
        }

        public DistributedInitializationMessageBuilder useHs(boolean useHs) {
            this.useHs = useHs;
            return this;
        }

        public DistributedInitializationMessageBuilder useNeg(boolean useNeg) {
            this.useNeg = useNeg;
            return this;
        }

        public DistributedInitializationMessageBuilder columnsPerShard(int columnsPerShard) {
            this.columnsPerShard = columnsPerShard;
            return this;
        }

        public DistributedInitializationMessage build() {
            return new DistributedInitializationMessage(this.vectorLength, this.numWords, this.seed, this.useHs, this.useNeg, this.columnsPerShard);
        }

        @Override
        public String toString() {
            return "DistributedInitializationMessage.DistributedInitializationMessageBuilder(vectorLength=" + this.vectorLength + ", numWords=" + this.numWords + ", seed=" + this.seed + ", useHs=" + this.useHs + ", useNeg=" + this.useNeg + ", columnsPerShard=" + this.columnsPerShard + ")";
        }
    }
}

