/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor.stats;

import java.io.File;
import java.io.IOException;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Per-column distribution statistics (mean and standard deviation) used for standardizing data.
 *
 * <p>Instances can be created directly, restored with {@link #load(File, File)}, or accumulated
 * incrementally over many batches with {@link Builder}.
 */
public class DistributionStats
implements NormalizerStats {
    private static final Logger logger = LoggerFactory.getLogger(DistributionStats.class);
    private final INDArray mean;
    private final INDArray std;

    /**
     * Creates statistics from precomputed mean and standard deviation arrays.
     *
     * <p>Note: {@code std} is modified in place. Every element smaller than
     * {@link Nd4j#EPS_THRESHOLD} is raised to that threshold so that a later division by the
     * standard deviation cannot produce NaNs. Neither array is copied.
     *
     * @param mean the mean array
     * @param std the standard deviation array (clamped in place)
     * @throws NullPointerException if either argument is {@code null}
     */
    public DistributionStats(@NonNull INDArray mean, @NonNull INDArray std) {
        if (mean == null) {
            throw new NullPointerException("mean");
        }
        if (std == null) {
            throw new NullPointerException("std");
        }
        // Inspect the raw minimum *before* clamping; afterwards every element is >= EPS_THRESHOLD.
        // A scalar comparison is required here: comparing two INDArray objects with == only
        // compares references and is therefore always false.
        if (std.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round up to epsilon to avoid nans.");
        }
        // dup=false: clamp the caller's array in place.
        Transforms.max(std, Nd4j.EPS_THRESHOLD, false);
        this.mean = mean;
        this.std = std;
    }

    /**
     * Restores statistics previously written by {@link #save(File, File)}.
     *
     * @param meanFile binary file holding the mean array
     * @param stdFile binary file holding the standard deviation array
     * @return the restored statistics
     * @throws IOException if either file cannot be read
     * @throws NullPointerException if either argument is {@code null}
     */
    public static DistributionStats load(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        if (meanFile == null) {
            throw new NullPointerException("meanFile");
        }
        if (stdFile == null) {
            throw new NullPointerException("stdFile");
        }
        return new DistributionStats(Nd4j.readBinary(meanFile), Nd4j.readBinary(stdFile));
    }

    /**
     * Writes the mean and standard deviation arrays to two binary files.
     *
     * @param meanFile destination for the mean array
     * @param stdFile destination for the standard deviation array
     * @throws IOException if either file cannot be written
     * @throws NullPointerException if either argument is {@code null}
     */
    public void save(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        if (meanFile == null) {
            throw new NullPointerException("meanFile");
        }
        if (stdFile == null) {
            throw new NullPointerException("stdFile");
        }
        Nd4j.saveBinary(this.getMean(), meanFile);
        Nd4j.saveBinary(this.getStd(), stdFile);
    }

    /** @return the mean array (not a copy) */
    public INDArray getMean() {
        return this.mean;
    }

    /** @return the standard deviation array (not a copy) */
    public INDArray getStd() {
        return this.std;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof DistributionStats)) {
            return false;
        }
        DistributionStats other = (DistributionStats) o;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray thisMean = this.getMean();
        INDArray otherMean = other.getMean();
        if (thisMean == null ? otherMean != null : !thisMean.equals(otherMean)) {
            return false;
        }
        INDArray thisStd = this.getStd();
        INDArray otherStd = other.getStd();
        return thisStd == null ? otherStd == null : thisStd.equals(otherStd);
    }

    /** Symmetry hook for {@link #equals(Object)} so subclasses can opt out of equality. */
    protected boolean canEqual(Object other) {
        return other instanceof DistributionStats;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        INDArray meanArr = this.getMean();
        result = result * prime + (meanArr == null ? 43 : meanArr.hashCode());
        INDArray stdArr = this.getStd();
        result = result * prime + (stdArr == null ? 43 : stdArr.hashCode());
        return result;
    }

    /**
     * Incrementally accumulates per-column mean and uncorrected (biasCorrected=false) variance
     * over any number of batches.
     *
     * <p>Batches are merged with the pairwise update
     * {@code M2 = M2_a + M2_b + delta^2 * n_a * n_b / (n_a + n_b)}, expressed in terms of
     * variances. This class is not synchronized.
     */
    public static class Builder
    implements NormalizerStats.Builder<DistributionStats> {
        private int runningCount = 0;
        private INDArray runningMean;
        private INDArray runningVariance;

        /**
         * Adds the features (and features mask) of a data set.
         *
         * @param dataSet data set whose features are accumulated
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is {@code null}
         */
        public Builder addFeatures(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            return this.add(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        }

        /**
         * Adds the labels (and labels mask) of a data set.
         *
         * @param dataSet data set whose labels are accumulated
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is {@code null}
         */
        public Builder addLabels(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            return this.add(dataSet.getLabels(), dataSet.getLabelsMaskArray());
        }

        /**
         * Merges one batch of data into the running statistics.
         *
         * <p>The data is first passed through {@code DataSetUtil.tailor2d(data, mask)}. Presumably
         * this flattens higher-rank input to 2d rows and drops masked entries (TODO confirm). When
         * that returns {@code null}, the batch is ignored.
         *
         * @param data the batch to add
         * @param mask optional mask array, may be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code data} is {@code null}
         */
        public Builder add(@NonNull INDArray data, INDArray mask) {
            if (data == null) {
                throw new NullPointerException("data");
            }
            data = DataSetUtil.tailor2d(data, mask);
            if (data == null) {
                return this;
            }
            INDArray mean = data.mean(0);
            INDArray variance = data.var(false, 0);
            int count = data.size(0);
            if (this.runningMean == null) {
                this.runningMean = mean;
                this.runningVariance = variance;
                this.runningCount = count;
                // For a single row the reductions may alias the input; copy so that later in-place
                // updates never write back into the caller's data.
                if (count == 1) {
                    this.runningMean = this.runningMean.dup();
                    this.runningVariance = this.runningVariance.dup();
                }
            } else {
                INDArray deltaSquared = Transforms.pow(mean.subRowVector(this.runningMean), 2);
                INDArray mB = variance.muli(count);
                int totalCount = this.runningCount + count;
                // Compute in double: runningCount * count overflows int for large datasets.
                double deltaWeight = (double) this.runningCount * count / totalCount;
                this.runningVariance.muli(this.runningCount)
                        .addiRowVector(mB)
                        .addiRowVector(deltaSquared.muli(deltaWeight))
                        .divi(totalCount);
                this.runningCount = totalCount;
                // mean_new = mean_old + sum(x - mean_old) / n_total
                INDArray xMinusMean = data.subRowVector(this.runningMean);
                this.runningMean.addi(xMinusMean.sum(0).divi(this.runningCount));
            }
            return this;
        }

        /**
         * Builds the statistics from everything added so far.
         *
         * @return new statistics. The mean is a copy and the std is {@code sqrt(variance)} as a new
         *     array, so the builder's state is unaffected.
         * @throws RuntimeException if no data was added
         */
        @Override
        public DistributionStats build() {
            if (this.runningMean == null) {
                throw new RuntimeException("No data was added, statistics cannot be determined");
            }
            return new DistributionStats(this.runningMean.dup(), Transforms.sqrt(this.runningVariance, true));
        }
    }
}

