/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.compression;

import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.jcublas.compression.CudaThreshold;
import org.nd4j.linalg.primitives.Pair;

public class CudaFlexibleThreshold
extends CudaThreshold {
    public CudaFlexibleThreshold() {
        this.threshold = 0.1f;
    }

    @Override
    public String getDescriptor() {
        return "FTHRESHOLD";
    }

    @Override
    public void configure(Object ... vars) {
        super.configure(vars);
    }

    @Override
    public DataBuffer compress(DataBuffer buffer) {
        INDArray temp = Nd4j.createArrayFromShapeBuffer((DataBuffer)buffer, (Pair)Nd4j.getShapeInfoProvider().createShapeInformation(new int[]{1, (int)buffer.length()}));
        double max = temp.amaxNumber().doubleValue();
        int cntAbs = temp.scan(Conditions.absGreaterThanOrEqual((Number)(max - max * (double)this.threshold))).intValue();
        long originalLength = buffer.length() * (long)Nd4j.sizeOfDataType((DataBuffer.Type)buffer.dataType());
        int compressedLength = cntAbs + 3;
        IntPointer pointer = new IntPointer((long)compressedLength);
        pointer.put(0L, cntAbs);
        pointer.put(1L, (int)buffer.length());
        pointer.put(2L, Float.floatToIntBits(this.threshold));
        CompressionDescriptor descriptor = new CompressionDescriptor();
        descriptor.setCompressedLength((long)(compressedLength * 4));
        descriptor.setOriginalLength(originalLength);
        descriptor.setOriginalElementSize((long)Nd4j.sizeOfDataType((DataBuffer.Type)buffer.dataType()));
        descriptor.setNumberOfElements(buffer.length());
        descriptor.setCompressionAlgorithm(this.getDescriptor());
        descriptor.setCompressionType(this.getCompressionType());
        CompressedDataBuffer cbuff = new CompressedDataBuffer((Pointer)pointer, descriptor);
        Nd4j.getNDArrayFactory().convertDataEx(CudaFlexibleThreshold.getBufferTypeEx((DataBuffer)buffer), buffer.addressPointer(), DataBuffer.TypeEx.FTHRESHOLD, (Pointer)pointer, buffer.length());
        Nd4j.getAffinityManager().tagLocation(buffer, AffinityManager.Location.HOST);
        return cbuff;
    }
}

