/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Standardizes the feature columns of a {@link DataSet} to zero mean and unit variance.
 *
 * <p>Per-column statistics (along dimension 0 of the feature matrix) are computed by one of the
 * {@code fit} methods, or restored with {@link #load(File...)}. {@link #preProcess(DataSet)} then
 * applies {@code (x - mean) / std} to the features in place, and {@link #revertPreProcess(DataSet)}
 * applies the inverse {@code x * std + mean}. {@link Nd4j#EPS_THRESHOLD} is added to the fitted
 * standard deviation so that constant columns do not cause a division by zero.
 */
public class NormalizerStandardize
implements DataNormalization {
    private static final Logger logger = LoggerFactory.getLogger(NormalizerStandardize.class);
    private static final String NOT_FIT_MESSAGE = "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";
    private static final String ZERO_STD_MESSAGE = "API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.";
    /** Per-column mean of the fitted features; {@code null} until fit or loaded. */
    private INDArray mean;
    /** Per-column standard deviation (plus epsilon) of the fitted features; {@code null} until fit or loaded. */
    private INDArray std;

    /**
     * Computes the per-column mean and standard deviation of the given data set's features.
     *
     * @param dataSet the data whose feature matrix provides the statistics
     */
    @Override
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.mean = features.mean(0);
        this.std = features.std(0);
        this.addEpsilonToStd();
    }

    /**
     * Computes the per-column mean and standard deviation over every batch of the iterator,
     * without holding more than one batch in memory, then resets the iterator.
     *
     * <p>Batches are merged with the pairwise update of Chan, Golub and LeVeque. With
     * {@code A} = data seen so far and {@code B} = the current batch:
     * {@code mean = mean_A + delta * n_B / n} and
     * {@code M2 = M2_A + M2_B + delta^2 * n_A * n_B / n}, where {@code delta = mean_B - mean_A}
     * and {@code M2} is the sum of squared deviations from the mean. Every call starts from
     * scratch; statistics from earlier fits are replaced, not merged.
     *
     * @param iterator source of batches; it is consumed and then reset
     * @throws IllegalStateException if the iterator yields no examples
     */
    @Override
    public void fit(DataSetIterator iterator) {
        int runningTotal = 0;
        INDArray runningMean = null;
        INDArray runningM2 = null;
        while (iterator.hasNext()) {
            DataSet next = (DataSet)iterator.next();
            INDArray features = next.getFeatureMatrix();
            int batchCount = features.size(0);
            if (batchCount == 0) {
                // An empty batch carries no information and would yield NaN means.
                continue;
            }
            INDArray batchMean = features.mean(0);
            // Computed directly so that a single-row batch contributes exactly zero (no 0/0 from a
            // bias-corrected std) and no assumption about std's normalization is needed.
            INDArray batchM2 = Transforms.pow(features.subRowVector(batchMean), 2).sum(0);
            if (runningMean == null) {
                runningMean = batchMean;
                runningM2 = batchM2;
                runningTotal = batchCount;
                continue;
            }
            int previousTotal = runningTotal;
            runningTotal += batchCount;
            INDArray delta = batchMean.sub(runningMean);
            double deltaScale = (double)previousTotal * (double)batchCount / (double)runningTotal;
            runningM2 = runningM2.add(batchM2).addi(Transforms.pow(delta, 2).muli(deltaScale));
            runningMean = runningMean.add(delta.mul((double)batchCount / (double)runningTotal));
        }
        if (runningMean == null) {
            iterator.reset();
            throw new IllegalStateException("API_USE_ERROR: Cannot fit a normalizer on an iterator with no examples");
        }
        // Divide by n - 1 to obtain the sample variance, intended to match the bias-corrected
        // std(0) used by fit(DataSet). A single example has zero spread, so guard the divisor.
        INDArray variance = runningM2.divi(Math.max(runningTotal - 1, 1));
        this.mean = runningMean;
        this.std = Transforms.sqrt(variance);
        this.addEpsilonToStd();
        iterator.reset();
    }

    /**
     * Logs when any column of {@link #std} is exactly zero, then adds {@link Nd4j#EPS_THRESHOLD}
     * to every column in place so later divisions are safe.
     */
    private void addEpsilonToStd() {
        // Compare the numeric minimum; comparing INDArray references with == is always false.
        if (this.std.minNumber().doubleValue() == 0.0) {
            logger.info(ZERO_STD_MESSAGE);
        }
        this.std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    }

    /** Throws the standard "not fit" error unless both statistics are present. */
    private void requireFitted() {
        if (this.mean == null || this.std == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
    }

    /** Verifies that a mean file and a std file were supplied. */
    private static void requireStatisticsFiles(File ... statistics) {
        int count = statistics == null ? 0 : statistics.length;
        if (count < 2) {
            throw new IllegalArgumentException("Expected two statistics files (mean, std) but got " + count);
        }
    }

    /**
     * Standardizes the features of {@code toPreProcess} in place: {@code (x - mean) / std}.
     *
     * @throws RuntimeException if the normalizer has not been fit or loaded
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        this.requireFitted();
        toPreProcess.getFeatures().subiRowVector(this.mean);
        toPreProcess.getFeatures().diviRowVector(this.std);
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /**
     * Standardizes every remaining batch of the iterator in place, then resets it.
     * NOTE(review): this only has a lasting effect if the iterator hands back the same DataSet
     * objects on later passes — confirm against the iterator implementation.
     */
    @Override
    public void transform(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.preProcess((DataSet)toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * Undoes {@link #preProcess(DataSet)} in place: {@code x * std + mean}.
     *
     * @throws RuntimeException if the normalizer has not been fit or loaded
     */
    public void revertPreProcess(DataSet toPreProcess) {
        this.requireFitted();
        toPreProcess.getFeatures().muliRowVector(this.std);
        toPreProcess.getFeatures().addiRowVector(this.mean);
    }

    /** Same as {@link #revertPreProcess(DataSet)}. */
    public void revert(DataSet toPreProcess) {
        this.revertPreProcess(toPreProcess);
    }

    /** Reverts every remaining batch of the iterator in place, then resets it. */
    public void revert(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.revertPreProcess((DataSet)toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * @return the fitted per-column mean (the internal array, not a copy)
     * @throws RuntimeException if no mean has been fit or loaded
     */
    public INDArray getMean() {
        if (this.mean == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.mean;
    }

    /**
     * @return the fitted per-column standard deviation, including the epsilon shift (the internal
     *     array, not a copy)
     * @throws RuntimeException if no standard deviation has been fit or loaded
     */
    public INDArray getStd() {
        if (this.std == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.std;
    }

    /**
     * Restores statistics previously written by {@link #save(File...)}.
     *
     * @param statistics {@code [meanFile, stdFile]}
     * @throws IllegalArgumentException if fewer than two files are given
     * @throws IOException if either file cannot be read; the current statistics are then unchanged
     */
    @Override
    public void load(File ... statistics) throws IOException {
        requireStatisticsFiles(statistics);
        // Read both before assigning so a failure on the second file leaves no half-loaded state.
        INDArray loadedMean = Nd4j.readBinary(statistics[0]);
        INDArray loadedStd = Nd4j.readBinary(statistics[1]);
        this.mean = loadedMean;
        this.std = loadedStd;
    }

    /**
     * Writes the fitted statistics in ND4J binary format.
     *
     * @param statistics {@code [meanFile, stdFile]}
     * @throws IllegalArgumentException if fewer than two files are given
     * @throws RuntimeException if the normalizer has not been fit or loaded
     * @throws IOException if writing fails
     */
    @Override
    public void save(File ... statistics) throws IOException {
        requireStatisticsFiles(statistics);
        this.requireFitted();
        Nd4j.saveBinary(this.mean, statistics[0]);
        Nd4j.saveBinary(this.std, statistics[1]);
    }
}

