/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NormalizerStandardize
implements DataNormalization {
    private static Logger logger = LoggerFactory.getLogger(NormalizerStandardize.class);
    private int runningTotal;
    private int labelRunningTotal = 0;
    private int batchCount;
    private int labelbatchCount = 0;
    private int featureRank = 2;
    private INDArray featureMeanStd;
    private INDArray labelMeanStd;
    private INDArray featureMean;
    private INDArray featureStd;
    private INDArray labelMean;
    private INDArray labelStd;
    private boolean fitLabels = false;

    private INDArray fit(INDArray theArray) {
        INDArray theMean = theArray.mean(0);
        INDArray theStd = theArray.std(0);
        theStd.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        if (theStd.min(1) == Nd4j.scalar(Nd4j.EPS_THRESHOLD)) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
        }
        return Nd4j.vstack(theMean, theStd).dup();
    }

    private void runnningFit(INDArray thenewArray, INDArray currentMeanStd, int batchCount, int runningTotal, boolean allDone) {
        if (!allDone) {
            INDArray currentMean = currentMeanStd.getRow(0);
            INDArray currentStd = currentMeanStd.getRow(1);
            INDArray xMinusMean = thenewArray.subRowVector(currentMean);
            INDArray newMean = currentMean.add(xMinusMean.sum(0).divi(runningTotal));
            INDArray meanB = thenewArray.mean(0);
            INDArray deltaSq = Transforms.pow(meanB.subRowVector(currentMean), 2);
            INDArray deltaSqScaled = deltaSq.mul(Float.valueOf(((float)runningTotal - (float)batchCount) * (float)batchCount / (float)runningTotal));
            INDArray mtwoB = Transforms.pow(thenewArray.std(0), 2);
            mtwoB.muli(batchCount);
            currentStd.addi(mtwoB);
            currentStd.addi(deltaSqScaled);
            currentMeanStd.putRow(0, newMean);
        } else {
            currentMeanStd.getRow(1).divi(runningTotal);
            currentMeanStd.putRow(1, Transforms.sqrt(currentMeanStd.getRow(1)));
            currentMeanStd.getRow(1).addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
            if (currentMeanStd.getRow(0).min(1) == Nd4j.scalar(Nd4j.EPS_THRESHOLD)) {
                logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
            }
        }
    }

    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    @Override
    public void fit(DataSet dataSet) {
        this.featureRank = dataSet.getFeatures().rank();
        INDArray theFeatures = dataSet.getFeatures();
        if (this.featureRank == 3) {
            theFeatures = this.tailor3d2d(dataSet, true);
        }
        if (this.featureRank == 4) {
            theFeatures = this.tailor4d2d(dataSet, true);
        }
        this.featureMeanStd = this.fit(theFeatures);
        this.featureMean = this.featureMeanStd.getRow(0).dup();
        this.featureStd = this.featureMeanStd.getRow(1).dup();
        if (this.fitLabels) {
            INDArray theLabels = dataSet.getLabels();
            if (this.featureRank == 3) {
                theLabels = this.tailor3d2d(dataSet, false);
            }
            if (this.featureRank == 4) {
                theLabels = this.tailor4d2d(dataSet, false);
            }
            this.labelMeanStd = this.fit(theLabels);
            this.labelMean = this.labelMeanStd.getRow(0).dup();
            this.labelStd = this.labelMeanStd.getRow(1).dup();
        }
    }

    @Override
    public void fit(DataSetIterator iterator) {
        this.featureMeanStd = null;
        this.runningTotal = 0;
        this.labelRunningTotal = 0;
        while (iterator.hasNext()) {
            DataSet next = (DataSet)iterator.next();
            this.batchCount = next.getFeaturesMaskArray() != null ? next.getFeaturesMaskArray().sumNumber().intValue() : next.getFeatures().size(0);
            this.runningTotal += this.batchCount;
            this.labelbatchCount = next.getLabelsMaskArray() != null ? next.getLabelsMaskArray().sumNumber().intValue() : next.getFeatures().size(0);
            this.labelRunningTotal += this.batchCount;
            if (this.featureMeanStd == null) {
                this.fit(next);
                this.featureMeanStd.getRow(1).muli(this.batchCount);
                if (!this.fitLabels) continue;
                this.labelMeanStd.getRow(1).muli(this.batchCount);
                continue;
            }
            INDArray theFeatures = next.getFeatures();
            if (this.featureRank == 3) {
                theFeatures = this.tailor3d2d(next, true);
            }
            if (this.featureRank == 4) {
                theFeatures = this.tailor4d2d(next, true);
            }
            this.runnningFit(theFeatures, this.featureMeanStd, this.batchCount, this.runningTotal, false);
            if (!this.fitLabels) continue;
            INDArray theLabels = next.getLabels();
            if (this.featureRank == 3) {
                theLabels = this.tailor3d2d(next, false);
            }
            if (this.featureRank == 4) {
                theLabels = this.tailor4d2d(next, false);
            }
            this.runnningFit(theLabels, this.labelMeanStd, this.labelbatchCount, this.labelRunningTotal, false);
        }
        this.runnningFit(this.featureMeanStd, this.featureMeanStd, this.batchCount, this.runningTotal, true);
        this.featureMean = this.featureMeanStd.getRow(0).dup();
        this.featureStd = this.featureMeanStd.getRow(1).dup();
        if (this.fitLabels) {
            this.runnningFit(this.labelMeanStd, this.labelMeanStd, this.labelbatchCount, this.labelRunningTotal, true);
            this.labelMean = this.labelMeanStd.getRow(0).dup();
            this.labelStd = this.labelMeanStd.getRow(1).dup();
        }
        iterator.reset();
    }

    @Override
    public void preProcess(DataSet toPreProcess) {
        if (this.featureMean == null) {
            throw new RuntimeException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
        INDArray theFeatures = toPreProcess.getFeatures();
        INDArray theLabels = toPreProcess.getLabels();
        this.preProcess(theFeatures, true);
        if (this.fitLabels) {
            this.preProcess(theLabels, false);
        }
    }

    private void preProcess(INDArray theFeatures, boolean isFeatures) {
        INDArray std;
        INDArray mean = isFeatures ? this.featureMean : this.labelMean;
        INDArray iNDArray = std = isFeatures ? this.featureStd : this.labelStd;
        if (this.featureRank == 2) {
            theFeatures.subiRowVector(mean);
            theFeatures.diviRowVector(std);
        } else {
            Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(theFeatures, mean, theFeatures, 1));
            Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(theFeatures, std, theFeatures, 1));
        }
    }

    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    public void transform(INDArray theFeatures) {
        this.transform(theFeatures, true);
    }

    public void transform(INDArray theArray, boolean isFeatures) {
        this.preProcess(theArray, isFeatures);
    }

    public void revert(DataSet toPreProcess) {
        if (this.featureMean == null || this.featureStd == null) {
            throw new RuntimeException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
        if (this.featureRank == 2) {
            toPreProcess.getFeatures().muliRowVector(this.featureStd);
            toPreProcess.getFeatures().addiRowVector(this.featureMean);
        } else {
            Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(toPreProcess.getFeatures(), this.featureStd, toPreProcess.getFeatures(), 1));
            Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(toPreProcess.getFeatures(), this.featureMean, toPreProcess.getFeatures(), 1));
        }
    }

    public void revert(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.revert((DataSet)toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    public INDArray getMean() {
        if (this.featureMean == null) {
            throw new RuntimeException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
        return this.featureMean;
    }

    public INDArray getLabelMean() {
        if (this.featureMean == null) {
            throw new RuntimeException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
        return this.labelMean;
    }

    public INDArray getStd() {
        if (this.featureStd == null) {
            throw new RuntimeException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
        return this.featureStd;
    }

    public INDArray getLabelStd() {
        if (this.featureStd == null) {
            throw new RuntimeException("API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)");
        }
        return this.labelStd;
    }

    @Override
    public void load(File ... statistics) throws IOException {
        this.featureMean = Nd4j.readBinary(statistics[0]);
        this.featureStd = Nd4j.readBinary(statistics[1]);
        if (this.fitLabels) {
            this.labelMean = Nd4j.readBinary(statistics[2]);
            this.labelStd = Nd4j.readBinary(statistics[3]);
        }
    }

    @Override
    public void save(File ... statistics) throws IOException {
        Nd4j.saveBinary(this.featureMean, statistics[0]);
        Nd4j.saveBinary(this.featureStd, statistics[1]);
        if (this.fitLabels) {
            Nd4j.saveBinary(this.labelMean, statistics[2]);
            Nd4j.saveBinary(this.labelStd, statistics[3]);
        }
    }

    private INDArray tailor3d2d(DataSet dataset, boolean areFeatures) {
        INDArray theArray = areFeatures ? dataset.getFeatures() : dataset.getLabels();
        INDArray theMask = areFeatures ? dataset.getFeaturesMaskArray() : dataset.getLabelsMaskArray();
        int instances = theArray.size(0);
        int features = theArray.size(1);
        int timesteps = theArray.size(2);
        boolean hasMasks = theMask != null;
        INDArray in2d = Nd4j.create(features, timesteps * instances);
        int tads = theArray.tensorssAlongDimension(2, 0);
        for (int i = 0; i < tads; ++i) {
            INDArray thisTAD = theArray.tensorAlongDimension(i, 2, 0);
            if (hasMasks) {
                thisTAD.muli(theMask);
            }
            in2d.putRow(i, Nd4j.toFlattened('c', thisTAD));
        }
        in2d = in2d.transpose();
        if (hasMasks) {
            INDArray columnMask = Nd4j.toFlattened('c', theMask).transpose();
            int actualSamples = columnMask.sumNumber().intValue();
            INDArray in2dMask = Nd4j.create(actualSamples, features);
            int j = 0;
            for (int i = 0; i < timesteps * instances; ++i) {
                if (columnMask.getInt(i, 0) == 0) continue;
                in2dMask.putRow(j, in2d.getRow(i));
                ++j;
            }
            return in2dMask;
        }
        return in2d;
    }

    private INDArray tailor4d2d(DataSet dataset, boolean areFeatures) {
        INDArray theArray = areFeatures ? dataset.getFeatures() : dataset.getLabels();
        int instances = theArray.size(0);
        int channels = theArray.size(1);
        int height = theArray.size(2);
        int width = theArray.size(3);
        INDArray in2d = Nd4j.create(channels, height * width * instances);
        int tads = theArray.tensorssAlongDimension(3, 2, 0);
        for (int i = 0; i < tads; ++i) {
            INDArray thisTAD = theArray.tensorAlongDimension(i, 3, 2, 0);
            in2d.putRow(i, Nd4j.toFlattened(thisTAD));
        }
        return in2d.transposei();
    }
}

