/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.memory;

import java.beans.ConstructorProperties;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryType;
import org.deeplearning4j.nn.conf.memory.MemoryUseMode;
import org.nd4j.linalg.api.buffer.DataBuffer;

/**
 * Memory report for a single layer.
 *
 * <p>All sizes held by this class are element counts; {@link #getMemoryBytes} converts them to
 * bytes by multiplying with the bytes-per-element value obtained from
 * {@code getBytesPerElement(dataType)}. Training-time working memory and cache memory are stored
 * per {@link CacheMode}.
 *
 * <p>Instances are mutable (every field has a setter) and are usually created through
 * {@link Builder}.
 */
public class LayerMemoryReport extends MemoryReport {
    private String layerName;
    private Class<?> layerType;
    private InputType inputType;
    private InputType outputType;
    private long parameterSize;
    private long updaterStateSize;
    private long workingMemoryFixedInference;
    private long workingMemoryVariableInference;
    private Map<CacheMode, Long> workingMemoryFixedTrain;
    private Map<CacheMode, Long> workingMemoryVariableTrain;
    Map<CacheMode, Long> cacheModeMemFixed;
    Map<CacheMode, Long> cacheModeMemVariablePerEx;

    /**
     * Creates an empty report: all references are {@code null} and all sizes are zero.
     * Populate it through the setters before querying memory use.
     */
    public LayerMemoryReport() {
    }

    /**
     * Creates a report with every field given explicitly.
     */
    @ConstructorProperties(value={"layerName", "layerType", "inputType", "outputType", "parameterSize", "updaterStateSize", "workingMemoryFixedInference", "workingMemoryVariableInference", "workingMemoryFixedTrain", "workingMemoryVariableTrain", "cacheModeMemFixed", "cacheModeMemVariablePerEx"})
    public LayerMemoryReport(String layerName, Class<?> layerType, InputType inputType, InputType outputType, long parameterSize, long updaterStateSize, long workingMemoryFixedInference, long workingMemoryVariableInference, Map<CacheMode, Long> workingMemoryFixedTrain, Map<CacheMode, Long> workingMemoryVariableTrain, Map<CacheMode, Long> cacheModeMemFixed, Map<CacheMode, Long> cacheModeMemVariablePerEx) {
        this.layerName = layerName;
        this.layerType = layerType;
        this.inputType = inputType;
        this.outputType = outputType;
        this.parameterSize = parameterSize;
        this.updaterStateSize = updaterStateSize;
        this.workingMemoryFixedInference = workingMemoryFixedInference;
        this.workingMemoryVariableInference = workingMemoryVariableInference;
        this.workingMemoryFixedTrain = workingMemoryFixedTrain;
        this.workingMemoryVariableTrain = workingMemoryVariableTrain;
        this.cacheModeMemFixed = cacheModeMemFixed;
        this.cacheModeMemVariablePerEx = cacheModeMemVariablePerEx;
    }

    /**
     * Creates a report by copying every value from the given builder. The builder's maps are
     * shared with the report, not copied.
     *
     * @param b builder holding the values for the new report
     */
    protected LayerMemoryReport(Builder b) {
        this.layerName = b.layerName;
        this.layerType = b.layerType;
        this.inputType = b.inputType;
        this.outputType = b.outputType;
        this.parameterSize = b.parameterSize;
        this.updaterStateSize = b.updaterStateSize;
        this.workingMemoryFixedInference = b.workingMemoryFixedInference;
        this.workingMemoryVariableInference = b.workingMemoryVariableInference;
        this.workingMemoryFixedTrain = b.workingMemoryFixedTrain;
        this.workingMemoryVariableTrain = b.workingMemoryVariableTrain;
        this.cacheModeMemFixed = b.cacheModeMemFixed;
        this.cacheModeMemVariablePerEx = b.cacheModeMemVariablePerEx;
    }

    /**
     * @return the layer type this report was created for
     */
    @Override
    public Class<?> getReportClass() {
        return this.layerType;
    }

    /**
     * @return the layer name
     */
    @Override
    public String getName() {
        return this.layerName;
    }

    /**
     * Sums {@link #getMemoryBytes} over every {@link MemoryType} value.
     *
     * @param minibatchSize number of examples in the minibatch
     * @param memoryUseMode inference or training; must not be {@code null}
     * @param cacheMode     cache mode used to look up the training/cache maps; must not be {@code null}
     * @param dataType      data type used to determine bytes per element; must not be {@code null}
     * @return total number of bytes across all memory types
     * @throws NullPointerException if {@code memoryUseMode}, {@code cacheMode} or {@code dataType} is {@code null}
     */
    @Override
    public long getTotalMemoryBytes(int minibatchSize, @NonNull MemoryUseMode memoryUseMode, @NonNull CacheMode cacheMode, @NonNull DataBuffer.Type dataType) {
        if (memoryUseMode == null) {
            throw new NullPointerException("memoryUseMode");
        }
        if (cacheMode == null) {
            throw new NullPointerException("cacheMode");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType");
        }
        long totalBytes = 0L;
        for (MemoryType type : MemoryType.values()) {
            totalBytes += this.getMemoryBytes(type, minibatchSize, memoryUseMode, cacheMode, dataType);
        }
        return totalBytes;
    }

    /**
     * Computes the number of bytes used for one memory type.
     *
     * <p>Parameter gradients, activation gradients, updater state and both cache-memory types
     * are reported as zero when {@code memoryUseMode} is {@link MemoryUseMode#INFERENCE}.
     * Activations scale with the output type's elements per example; activation gradients scale
     * with the input type's elements per example.
     *
     * @param memoryType    the kind of memory to report
     * @param minibatchSize number of examples in the minibatch
     * @param memoryUseMode inference or training
     * @param cacheMode     key used to look up the per-cache-mode training and cache values
     * @param dataType      data type used to determine bytes per element
     * @return number of bytes for {@code memoryType}
     * @throws NullPointerException  if a required per-cache-mode map, or its entry for
     *                               {@code cacheMode}, is missing
     * @throws IllegalStateException if {@code memoryType} is not one of the handled values
     */
    @Override
    public long getMemoryBytes(MemoryType memoryType, int minibatchSize, MemoryUseMode memoryUseMode, CacheMode cacheMode, DataBuffer.Type dataType) {
        // Held as a long so that every product below is evaluated in 64-bit arithmetic.
        long bytesPerElement = this.getBytesPerElement(dataType);
        boolean inference = memoryUseMode == MemoryUseMode.INFERENCE;
        switch (memoryType) {
            case PARAMETERS:
                return this.parameterSize * bytesPerElement;
            case PARAMATER_GRADIENTS:
                return inference ? 0L : this.parameterSize * bytesPerElement;
            case ACTIVATIONS:
                // Widen before multiplying: batch size * elements per example * bytes per element
                // can exceed Integer.MAX_VALUE when computed in int arithmetic.
                return (long) minibatchSize * this.outputType.arrayElementsPerExample() * bytesPerElement;
            case ACTIVATION_GRADIENTS:
                if (inference) {
                    return 0L;
                }
                // Same widening as for ACTIVATIONS.
                return (long) minibatchSize * this.inputType.arrayElementsPerExample() * bytesPerElement;
            case UPDATER_STATE:
                return inference ? 0L : this.updaterStateSize * bytesPerElement;
            case WORKING_MEMORY_FIXED:
                if (inference) {
                    return this.workingMemoryFixedInference * bytesPerElement;
                }
                return valueFor(this.workingMemoryFixedTrain, cacheMode, "workingMemoryFixedTrain") * bytesPerElement;
            case WORKING_MEMORY_VARIABLE:
                if (inference) {
                    // NOTE(review): unlike the training branch, this is not multiplied by
                    // minibatchSize, although the builder parameter is named
                    // "variableInferencePerEx" - confirm whether that is intended.
                    return this.workingMemoryVariableInference * bytesPerElement;
                }
                return (long) minibatchSize * valueFor(this.workingMemoryVariableTrain, cacheMode, "workingMemoryVariableTrain") * bytesPerElement;
            case CACHED_MEMORY_FIXED:
                if (inference) {
                    return 0L;
                }
                return valueFor(this.cacheModeMemFixed, cacheMode, "cacheModeMemFixed") * bytesPerElement;
            case CACHED_MEMORY_VARIABLE:
                if (inference) {
                    return 0L;
                }
                return (long) minibatchSize * valueFor(this.cacheModeMemVariablePerEx, cacheMode, "cacheModeMemVariablePerEx") * bytesPerElement;
            default:
                throw new IllegalStateException("Unknown memory type: " + memoryType);
        }
    }

    /**
     * Looks up the value stored for a cache mode, failing with a descriptive message instead of
     * an anonymous unboxing failure when the map or the entry is absent.
     *
     * @param byCacheMode per-cache-mode values, possibly {@code null}
     * @param cacheMode   key to look up
     * @param fieldName   name of the field being read, used only in the error message
     * @return the stored value
     * @throws NullPointerException if {@code byCacheMode} is {@code null} or holds no value for {@code cacheMode}
     */
    private static long valueFor(Map<CacheMode, Long> byCacheMode, CacheMode cacheMode, String fieldName) {
        Long value = byCacheMode == null ? null : byCacheMode.get(cacheMode);
        if (value == null) {
            throw new NullPointerException("No " + fieldName + " value available for cache mode " + cacheMode);
        }
        return value;
    }

    /**
     * @return a short description containing the layer name and the simple name of the layer
     *         type; the type is rendered as {@code null} when it has not been set
     */
    @Override
    public String toString() {
        // layerType is null after the no-arg constructor, so guard the getSimpleName() call.
        String typeName = this.layerType == null ? null : this.layerType.getSimpleName();
        return "LayerMemoryReport(layerName=" + this.layerName + ",layerType=" + typeName + ")";
    }

    /** @return the layer name */
    public String getLayerName() {
        return this.layerName;
    }

    /** @return the layer type */
    public Class<?> getLayerType() {
        return this.layerType;
    }

    /** @return the input type; its elements per example are used for activation gradients */
    public InputType getInputType() {
        return this.inputType;
    }

    /** @return the output type; its elements per example are used for activations */
    public InputType getOutputType() {
        return this.outputType;
    }

    /** @return the parameter size, in elements */
    public long getParameterSize() {
        return this.parameterSize;
    }

    /** @return the updater state size, in elements */
    public long getUpdaterStateSize() {
        return this.updaterStateSize;
    }

    /** @return the fixed working memory used for inference, in elements */
    public long getWorkingMemoryFixedInference() {
        return this.workingMemoryFixedInference;
    }

    /** @return the variable working memory used for inference, in elements */
    public long getWorkingMemoryVariableInference() {
        return this.workingMemoryVariableInference;
    }

    /** @return the fixed training working memory per cache mode, in elements */
    public Map<CacheMode, Long> getWorkingMemoryFixedTrain() {
        return this.workingMemoryFixedTrain;
    }

    /** @return the per-example training working memory per cache mode, in elements */
    public Map<CacheMode, Long> getWorkingMemoryVariableTrain() {
        return this.workingMemoryVariableTrain;
    }

    /** @return the fixed cache memory per cache mode, in elements */
    public Map<CacheMode, Long> getCacheModeMemFixed() {
        return this.cacheModeMemFixed;
    }

    /** @return the per-example cache memory per cache mode, in elements */
    public Map<CacheMode, Long> getCacheModeMemVariablePerEx() {
        return this.cacheModeMemVariablePerEx;
    }

    /** @param layerName the layer name */
    public void setLayerName(String layerName) {
        this.layerName = layerName;
    }

    /** @param layerType the layer type */
    public void setLayerType(Class<?> layerType) {
        this.layerType = layerType;
    }

    /** @param inputType the input type */
    public void setInputType(InputType inputType) {
        this.inputType = inputType;
    }

    /** @param outputType the output type */
    public void setOutputType(InputType outputType) {
        this.outputType = outputType;
    }

    /** @param parameterSize the parameter size, in elements */
    public void setParameterSize(long parameterSize) {
        this.parameterSize = parameterSize;
    }

    /** @param updaterStateSize the updater state size, in elements */
    public void setUpdaterStateSize(long updaterStateSize) {
        this.updaterStateSize = updaterStateSize;
    }

    /** @param workingMemoryFixedInference the fixed inference working memory, in elements */
    public void setWorkingMemoryFixedInference(long workingMemoryFixedInference) {
        this.workingMemoryFixedInference = workingMemoryFixedInference;
    }

    /** @param workingMemoryVariableInference the variable inference working memory, in elements */
    public void setWorkingMemoryVariableInference(long workingMemoryVariableInference) {
        this.workingMemoryVariableInference = workingMemoryVariableInference;
    }

    /** @param workingMemoryFixedTrain the fixed training working memory per cache mode */
    public void setWorkingMemoryFixedTrain(Map<CacheMode, Long> workingMemoryFixedTrain) {
        this.workingMemoryFixedTrain = workingMemoryFixedTrain;
    }

    /** @param workingMemoryVariableTrain the per-example training working memory per cache mode */
    public void setWorkingMemoryVariableTrain(Map<CacheMode, Long> workingMemoryVariableTrain) {
        this.workingMemoryVariableTrain = workingMemoryVariableTrain;
    }

    /** @param cacheModeMemFixed the fixed cache memory per cache mode */
    public void setCacheModeMemFixed(Map<CacheMode, Long> cacheModeMemFixed) {
        this.cacheModeMemFixed = cacheModeMemFixed;
    }

    /** @param cacheModeMemVariablePerEx the per-example cache memory per cache mode */
    public void setCacheModeMemVariablePerEx(Map<CacheMode, Long> cacheModeMemVariablePerEx) {
        this.cacheModeMemVariablePerEx = cacheModeMemVariablePerEx;
    }

    /**
     * Two reports are equal when the other object is a {@code LayerMemoryReport} that accepts
     * this one through {@link #canEqual} and all twelve properties match.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LayerMemoryReport)) {
            return false;
        }
        LayerMemoryReport that = (LayerMemoryReport) o;
        if (!that.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getLayerName(), that.getLayerName())
                && Objects.equals(this.getLayerType(), that.getLayerType())
                && Objects.equals(this.getInputType(), that.getInputType())
                && Objects.equals(this.getOutputType(), that.getOutputType())
                && this.getParameterSize() == that.getParameterSize()
                && this.getUpdaterStateSize() == that.getUpdaterStateSize()
                && this.getWorkingMemoryFixedInference() == that.getWorkingMemoryFixedInference()
                && this.getWorkingMemoryVariableInference() == that.getWorkingMemoryVariableInference()
                && Objects.equals(this.getWorkingMemoryFixedTrain(), that.getWorkingMemoryFixedTrain())
                && Objects.equals(this.getWorkingMemoryVariableTrain(), that.getWorkingMemoryVariableTrain())
                && Objects.equals(this.getCacheModeMemFixed(), that.getCacheModeMemFixed())
                && Objects.equals(this.getCacheModeMemVariablePerEx(), that.getCacheModeMemVariablePerEx());
    }

    /**
     * @param other candidate for equality
     * @return {@code true} if {@code other} is a {@code LayerMemoryReport}
     */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LayerMemoryReport;
    }

    /**
     * Hash over the same twelve properties used by {@link #equals}, combined with multiplier 59
     * and with 43 standing in for {@code null} references.
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + hashOf(this.getLayerName());
        result = result * 59 + hashOf(this.getLayerType());
        result = result * 59 + hashOf(this.getInputType());
        result = result * 59 + hashOf(this.getOutputType());
        result = result * 59 + hashOf(this.getParameterSize());
        result = result * 59 + hashOf(this.getUpdaterStateSize());
        result = result * 59 + hashOf(this.getWorkingMemoryFixedInference());
        result = result * 59 + hashOf(this.getWorkingMemoryVariableInference());
        result = result * 59 + hashOf(this.getWorkingMemoryFixedTrain());
        result = result * 59 + hashOf(this.getWorkingMemoryVariableTrain());
        result = result * 59 + hashOf(this.getCacheModeMemFixed());
        result = result * 59 + hashOf(this.getCacheModeMemVariablePerEx());
        return result;
    }

    /** Hash of a reference, with 43 used for {@code null}. */
    private static int hashOf(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /** Hash of a long: upper and lower 32 bits XORed together. */
    private static int hashOf(long value) {
        return (int) ((value >>> 32) ^ value);
    }

    /**
     * Fluent builder for {@link LayerMemoryReport}. Values that are never set keep their Java
     * defaults (zero for sizes, {@code null} for maps).
     */
    public static class Builder {
        private String layerName;
        private Class<?> layerType;
        private InputType inputType;
        private InputType outputType;
        private long parameterSize;
        private long updaterStateSize;
        private long workingMemoryFixedInference;
        private long workingMemoryVariableInference;
        private Map<CacheMode, Long> workingMemoryFixedTrain;
        private Map<CacheMode, Long> workingMemoryVariableTrain;
        Map<CacheMode, Long> cacheModeMemFixed;
        Map<CacheMode, Long> cacheModeMemVariablePerEx;

        /**
         * @param layerName  name of the layer
         * @param layerType  class of the layer
         * @param inputType  input type of the layer
         * @param outputType output type of the layer
         */
        public Builder(String layerName, Class<?> layerType, InputType inputType, InputType outputType) {
            this.layerName = layerName;
            this.layerType = layerType;
            this.inputType = inputType;
            this.outputType = outputType;
        }

        /**
         * Sets the parameter and updater state sizes.
         *
         * @param parameterSize    parameter size, in elements
         * @param updaterStateSize updater state size, in elements
         * @return this builder
         */
        public Builder standardMemory(long parameterSize, long updaterStateSize) {
            this.parameterSize = parameterSize;
            this.updaterStateSize = updaterStateSize;
            return this;
        }

        /**
         * Sets working memory, converting the two training values to per-cache-mode maps with
         * {@code MemoryReport.cacheModeMapFor} (presumably the same value for every cache mode -
         * verify against that helper).
         *
         * @param fixedInference         fixed working memory for inference
         * @param variableInferencePerEx variable working memory for inference
         * @param fixedTrain             fixed working memory for training
         * @param variableTrainPerEx     per-example working memory for training
         * @return this builder
         */
        public Builder workingMemory(long fixedInference, long variableInferencePerEx, long fixedTrain, long variableTrainPerEx) {
            return this.workingMemory(fixedInference, variableInferencePerEx, MemoryReport.cacheModeMapFor(fixedTrain), MemoryReport.cacheModeMapFor(variableTrainPerEx));
        }

        /**
         * Sets working memory with explicit per-cache-mode training values. The maps are stored
         * as given, without copying.
         *
         * @param fixedInference         fixed working memory for inference
         * @param variableInferencePerEx variable working memory for inference
         * @param fixedTrain             fixed working memory for training, per cache mode
         * @param variableTrainPerEx     per-example working memory for training, per cache mode
         * @return this builder
         */
        public Builder workingMemory(long fixedInference, long variableInferencePerEx, Map<CacheMode, Long> fixedTrain, Map<CacheMode, Long> variableTrainPerEx) {
            this.workingMemoryFixedInference = fixedInference;
            this.workingMemoryVariableInference = variableInferencePerEx;
            this.workingMemoryFixedTrain = fixedTrain;
            this.workingMemoryVariableTrain = variableTrainPerEx;
            return this;
        }

        /**
         * Sets cache memory, converting both values to per-cache-mode maps with
         * {@code MemoryReport.cacheModeMapFor}.
         *
         * @param cacheModeMemoryFixed         fixed cache memory
         * @param cacheModeMemoryVariablePerEx per-example cache memory
         * @return this builder
         */
        public Builder cacheMemory(long cacheModeMemoryFixed, long cacheModeMemoryVariablePerEx) {
            return this.cacheMemory(MemoryReport.cacheModeMapFor(cacheModeMemoryFixed), MemoryReport.cacheModeMapFor(cacheModeMemoryVariablePerEx));
        }

        /**
         * Sets cache memory with explicit per-cache-mode values. The maps are stored as given,
         * without copying.
         *
         * @param cacheModeMemoryFixed         fixed cache memory, per cache mode
         * @param cacheModeMemoryVariablePerEx per-example cache memory, per cache mode
         * @return this builder
         */
        public Builder cacheMemory(Map<CacheMode, Long> cacheModeMemoryFixed, Map<CacheMode, Long> cacheModeMemoryVariablePerEx) {
            this.cacheModeMemFixed = cacheModeMemoryFixed;
            this.cacheModeMemVariablePerEx = cacheModeMemoryVariablePerEx;
            return this;
        }

        /**
         * @return a new {@link LayerMemoryReport} holding the values currently set on this builder
         */
        public LayerMemoryReport build() {
            return new LayerMemoryReport(this);
        }
    }
}

