/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.writable;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import javax.annotation.Nonnull;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.util.MathUtils;
import org.datavec.api.util.ndarray.DataInputWrapperStream;
import org.datavec.api.util.ndarray.DataOutputWrapperStream;
import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.WritableType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Writable that wraps a (possibly null) {@link INDArray}.
 *
 * <p>Serialized form written by {@link #write(DataOutput)}: a single header byte,
 * {@link #NDARRAY_SER_VERSION_HEADER_NULL} when the array is null, otherwise
 * {@link #NDARRAY_SER_VERSION_HEADER} followed by the {@code Nd4j.write} encoding of the array.
 *
 * <p>The hash code is cached and invalidated by {@link #set(INDArray)} and
 * {@link #readFields(DataInput)}. Mutating the wrapped array in place after
 * {@link #hashCode()} has been called is not detected, so the cached value may become stale.
 */
public class NDArrayWritable
extends ArrayWritable
implements WritableComparable {
    /** Header byte marking a serialized null array (no payload follows). */
    public static final byte NDARRAY_SER_VERSION_HEADER_NULL = 0;
    /** Header byte marking a serialized non-null array (Nd4j payload follows). */
    public static final byte NDARRAY_SER_VERSION_HEADER = 1;
    private INDArray array = null;
    /** Lazily computed hash; null means "not computed for the current array". */
    private Integer hash = null;

    /** Creates a writable holding no array (null). */
    public NDArrayWritable() {
    }

    /**
     * Creates a writable wrapping the given array.
     *
     * @param array the array to wrap; may be null
     */
    public NDArrayWritable(INDArray array) {
        this.set(array);
    }

    /**
     * Reads the header byte and, if present, the array payload, replacing the current contents.
     *
     * @param in source of the serialized form produced by {@link #write(DataOutput)}
     * @throws IOException on read failure
     * @throws IllegalStateException if the header byte is neither 0 nor 1
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        DataInputStream dis = new DataInputStream(new DataInputWrapperStream(in));
        byte header = dis.readByte();
        if (header != NDARRAY_SER_VERSION_HEADER && header != NDARRAY_SER_VERSION_HEADER_NULL) {
            throw new IllegalStateException("Unexpected NDArrayWritable version header - stream corrupt?");
        }
        if (header == NDARRAY_SER_VERSION_HEADER_NULL) {
            this.array = null;
            // The cached hash must be reset here too; otherwise it would still describe
            // the previously held array.
            this.hash = null;
            return;
        }
        this.array = Nd4j.read(dis);
        this.hash = null;
    }

    /**
     * Writes this writable's type index as a short.
     *
     * @param out destination
     * @throws IOException on write failure
     */
    @Override
    public void writeType(DataOutput out) throws IOException {
        out.writeShort(WritableType.NDArray.typeIdx());
    }

    /** @return {@link WritableType#NDArray} */
    @Override
    public WritableType getType() {
        return WritableType.NDArray;
    }

    /**
     * Writes the header byte and, for a non-null array, the Nd4j-encoded array.
     *
     * @param out destination
     * @throws IOException on write failure
     */
    @Override
    public void write(DataOutput out) throws IOException {
        if (this.array == null) {
            out.writeByte(NDARRAY_SER_VERSION_HEADER_NULL);
            return;
        }
        // Views are duplicated before writing so that only the view's own data is serialized.
        INDArray toWrite = this.array.isView() ? this.array.dup() : this.array;
        out.writeByte(NDARRAY_SER_VERSION_HEADER);
        Nd4j.write(toWrite, new DataOutputStream(new DataOutputWrapperStream(out)));
    }

    /**
     * Replaces the wrapped array and invalidates the cached hash.
     *
     * @param array the new array; may be null
     */
    public void set(INDArray array) {
        this.array = array;
        this.hash = null;
    }

    /** @return the wrapped array (not copied); may be null */
    public INDArray get() {
        return this.array;
    }

    /**
     * Two instances are equal when both hold null, or both hold arrays that compare equal
     * via {@code equalsWithEps} with an epsilon of 0.0.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof NDArrayWritable)) {
            return false;
        }
        INDArray otherArray = ((NDArrayWritable) o).get();
        if (this.array == null || otherArray == null) {
            // Equal only if both are null.
            return this.array == otherArray;
        }
        return this.array.equalsWithEps((Object) otherArray, 0.0);
    }

    /**
     * Hash over the array shape and every element visited in 'c' (row-major) index order.
     * Elements are combined in order, so permutations and repeated values do not cancel out.
     * The result is cached until the array is replaced.
     */
    @Override
    public int hashCode() {
        if (this.hash != null) {
            return this.hash;
        }
        if (this.array == null) {
            this.hash = 0;
            return this.hash;
        }
        int h = Arrays.hashCode(this.array.shape());
        NdIndexIterator iter = new NdIndexIterator('c', this.array.shape());
        while (iter.hasNext()) {
            h = 31 * h + MathUtils.hashCode(this.array.getDouble(iter.next()));
        }
        this.hash = h;
        return h;
    }

    /**
     * Orders instances in this sequence:
     * <ol>
     *   <li>null arrays first;</li>
     *   <li>then by rank;</li>
     *   <li>then by length;</li>
     *   <li>then by each dimension size;</li>
     *   <li>finally element-wise in 'c' index order using {@link Double#compare}.</li>
     * </ol>
     *
     * <p>NOTE(review): {@code Double.compare} treats -0.0/0.0 and NaN differently from
     * {@code equalsWithEps}, so this ordering may not be fully consistent with {@link #equals}.
     *
     * @param o another {@code NDArrayWritable}
     * @throws ClassCastException if {@code o} is not an {@code NDArrayWritable}
     */
    @Override
    public int compareTo(@Nonnull Object o) {
        NDArrayWritable other = (NDArrayWritable) o;
        if (this.array == null) {
            return other.array == null ? 0 : -1;
        }
        if (other.array == null) {
            return 1;
        }
        if (this.array.rank() != other.array.rank()) {
            return Integer.compare(this.array.rank(), other.array.rank());
        }
        if (this.array.length() != other.array.length()) {
            return Long.compare(this.array.length(), other.array.length());
        }
        for (int dim = 0; dim < this.array.rank(); ++dim) {
            int bySize = Integer.compare(this.array.size(dim), other.array.size(dim));
            if (bySize != 0) {
                return bySize;
            }
        }
        NdIndexIterator iter = new NdIndexIterator('c', this.array.shape());
        while (iter.hasNext()) {
            int[] pos = iter.next();
            int byValue = Double.compare(this.array.getDouble(pos), other.array.getDouble(pos));
            if (byValue != 0) {
                return byValue;
            }
        }
        return 0;
    }

    /** @return the array's string form, or {@code "null"} when no array is held */
    @Override
    public String toString() {
        return String.valueOf(this.array);
    }

    /**
     * Length of the array's backing data buffer.
     *
     * <p>NOTE(review): this reads {@code array.data()} directly. For a view this is presumably
     * the shared buffer's length rather than the view's length — confirm.
     */
    @Override
    public long length() {
        return this.array.data().length();
    }

    /**
     * Reads element {@code i} of the backing data buffer as a double.
     * See the NOTE(review) on {@link #length()} about views.
     */
    @Override
    public double getDouble(long i) {
        return this.array.data().getDouble(i);
    }

    /** Reads element {@code i} of the backing data buffer as a float. */
    @Override
    public float getFloat(long i) {
        return this.array.data().getFloat(i);
    }

    /** Reads element {@code i} of the backing data buffer as an int. */
    @Override
    public int getInt(long i) {
        return this.array.data().getInt(i);
    }

    /**
     * Reads element {@code i} of the backing data buffer as a double and truncates it to a long.
     * Values with magnitude beyond 2^53 lose precision through the double conversion.
     */
    @Override
    public long getLong(long i) {
        return (long) this.array.data().getDouble(i);
    }
}

