/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.dl4j;

import java.util.Random;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helper methods for ND4J {@link DataSet} objects: structural comparison, seeded
 * randomization and train/test splitting.
 */
public class DataSetHelper {

    /**
     * Checks whether two datasets have the same structure.
     *
     * @param data1 the first dataset
     * @param data2 the second dataset
     * @return true if {@link #equalStructureMsg(DataSet, DataSet)} reports no difference
     */
    public static boolean equalStructure(DataSet data1, DataSet data2) {
        return equalStructureMsg(data1, data2) == null;
    }

    /**
     * Compares the structure of two datasets.
     *
     * <p>The checks run in this order: number of inputs, presence of labels, number of outcomes,
     * and number of label names.
     *
     * @param data1 the first dataset
     * @param data2 the second dataset
     * @return null if the structures match, otherwise a message describing the first difference
     */
    public static String equalStructureMsg(DataSet data1, DataSet data2) {
        if (data1.numInputs() != data2.numInputs()) {
            return "Number of inputs differ: " + data1.numInputs() + " != " + data2.numInputs();
        }
        boolean hasLabels1 = data1.getLabels() != null;
        boolean hasLabels2 = data2.getLabels() != null;
        // Label presence is checked before numOutcomes(): ND4J derives the outcome count from the
        // label matrix, so querying it on a dataset without labels fails instead of yielding a message.
        if (!hasLabels1 && hasLabels2) {
            return "First dataset has no labels, but second does!";
        }
        if (hasLabels1 && !hasLabels2) {
            return "First dataset has labels, but second doesn't!";
        }
        if (!hasLabels1) {
            // Neither dataset has labels: nothing label-related is left to compare.
            return null;
        }
        if (data1.numOutcomes() != data2.numOutcomes()) {
            return "Number of outcomes differ: " + data1.numOutcomes() + " != " + data2.numOutcomes();
        }
        int names1 = labelNameCount(data1);
        int names2 = labelNameCount(data2);
        if (names1 != names2) {
            return "Number of labels differ: " + names1 + " != " + names2;
        }
        return null;
    }

    /** Returns the number of label names of the dataset, treating a null list as empty. */
    private static int labelNameCount(DataSet data) {
        return data.getLabelNamesList() == null ? 0 : data.getLabelNamesList().size();
    }

    /**
     * Shuffles the dataset along dimension 1 using the given seed.
     *
     * <p>Dimension 1 corresponds to the rows of a 2-D matrix. Features and labels are each
     * shuffled with a freshly created {@code Random(seed)}. The intent is that both receive the
     * same permutation, so that rows stay paired.
     *
     * @param data the dataset to shuffle
     * @param seed the seed for the random number generators
     * @param copy if true, a copy is shuffled and returned; otherwise {@code data} is shuffled in place
     * @return the shuffled dataset (either the copy or {@code data} itself)
     */
    public static DataSet randomize(DataSet data, long seed, boolean copy) {
        if (copy) {
            data = data.copy();
        }
        Nd4j.shuffle(data.getFeatureMatrix(), new Random(seed), 1);
        if (data.getLabels() != null) {
            Nd4j.shuffle(data.getLabels(), new Random(seed), 1);
        }
        return data;
    }

    /**
     * Splits the dataset into train and test parts without randomizing it first.
     *
     * @param data the dataset to split
     * @param perc the value passed to {@link DataSet#splitTestAndTrain(double)}
     * @return a two-element array: {train, test}
     */
    public static DataSet[] split(DataSet data, double perc) {
        return split(data, perc, null);
    }

    /**
     * Splits the dataset into train and test parts.
     *
     * <p>When a seed is given, the data is first randomized on a copy, so the caller's dataset
     * is not reordered by that step.
     *
     * @param data the dataset to split
     * @param perc the value passed to {@link DataSet#splitTestAndTrain(double)}
     * @param seed the seed for randomizing a copy before splitting, or null to split as-is
     * @return a two-element array: {train, test}
     */
    public static DataSet[] split(DataSet data, double perc, Long seed) {
        DataSet source = (seed != null) ? randomize(data, seed, true) : data;
        SplitTestAndTrain split = source.splitTestAndTrain(perc);
        return new DataSet[]{split.getTrain(), split.getTest()};
    }
}

