package adams.ml.dl4j;

import java.util.Random;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:adams/ml/dl4j/DataSetHelper.class */
public class DataSetHelper {

    /**
     * Checks whether two datasets share the same structure.
     *
     * @param dataSet  the first dataset
     * @param dataSet2 the second dataset
     * @return true if {@link #equalStructureMsg(DataSet, DataSet)} reports no difference
     */
    public static boolean equalStructure(DataSet dataSet, DataSet dataSet2) {
        return equalStructureMsg(dataSet, dataSet2) == null;
    }

    /**
     * Compares the structure of two datasets.
     *
     * <p>The checks run in this order: number of inputs, presence of labels,
     * number of outcomes, and number of label names. Only the first difference
     * found is reported.
     *
     * @param dataSet  the first dataset
     * @param dataSet2 the second dataset
     * @return null if the structures match, otherwise a message describing the
     *         first difference
     */
    public static String equalStructureMsg(DataSet dataSet, DataSet dataSet2) {
        if (dataSet.numInputs() != dataSet2.numInputs()) {
            return "Number of inputs differ: " + dataSet.numInputs() + " != " + dataSet2.numInputs();
        }

        // Check label presence BEFORE anything that reads the label matrix
        // (numOutcomes() does), so a missing label matrix yields a message
        // rather than a failure.
        boolean firstHasLabels = dataSet.getLabels() != null;
        boolean secondHasLabels = dataSet2.getLabels() != null;
        if (!firstHasLabels && secondHasLabels) {
            return "First dataset has no labels, but second does!";
        }
        if (firstHasLabels && !secondHasLabels) {
            return "First dataset has labels, but second doesn't!";
        }
        if (!firstHasLabels) {
            // Neither dataset has labels: nothing label-related left to compare.
            return null;
        }

        if (dataSet.numOutcomes() != dataSet2.numOutcomes()) {
            return "Number of outcomes differ: " + dataSet.numOutcomes() + " != " + dataSet2.numOutcomes();
        }

        int firstLabelNames = labelNameCount(dataSet);
        int secondLabelNames = labelNameCount(dataSet2);
        if (firstLabelNames != secondLabelNames) {
            return "Number of labels differ: " + firstLabelNames + " != " + secondLabelNames;
        }
        return null;
    }

    /**
     * Returns the number of label names of a dataset, treating a missing
     * (null) label-name list as empty.
     */
    private static int labelNameCount(DataSet data) {
        return data.getLabelNamesList() == null ? 0 : data.getLabelNamesList().size();
    }

    /**
     * Shuffles the features and, if present, the labels of a dataset.
     *
     * @param dataSet the dataset to shuffle
     * @param seed    seed for the random number generator
     * @param copy    if true, a copy is shuffled and the input is left
     *                untouched; otherwise the input is shuffled in place
     * @return the shuffled dataset (the input itself when copy is false)
     */
    public static DataSet randomize(DataSet dataSet, long seed, boolean copy) {
        DataSet result = copy ? dataSet.copy() : dataSet;
        // Both matrices are shuffled along dimension 1 with identically seeded
        // generators. The intent is that features and labels receive the same
        // permutation so rows stay aligned. NOTE(review): this presumably
        // relies on both matrices having the same number of rows and being
        // 2D -- confirm for time-series (3D) data.
        Nd4j.shuffle(result.getFeatureMatrix(), new Random(seed), new int[]{1});
        if (result.getLabels() != null) {
            Nd4j.shuffle(result.getLabels(), new Random(seed), new int[]{1});
        }
        return result;
    }

    /**
     * Splits a dataset into train and test parts without randomizing first.
     *
     * @param dataSet       the dataset to split
     * @param fractionTrain the value passed on to {@code DataSet.splitTestAndTrain}
     * @return array of {train, test}
     */
    public static DataSet[] split(DataSet dataSet, double fractionTrain) {
        return split(dataSet, fractionTrain, null);
    }

    /**
     * Splits a dataset into train and test parts, optionally randomizing a
     * copy of it first.
     *
     * @param dataSet       the dataset to split
     * @param fractionTrain the value passed on to {@code DataSet.splitTestAndTrain}
     * @param seed          seed for randomizing a copy before splitting, or
     *                      null to split in the current order
     * @return array of {train, test}
     */
    public static DataSet[] split(DataSet dataSet, double fractionTrain, Long seed) {
        DataSet source = dataSet;
        if (seed != null) {
            // Randomize a copy so the caller's dataset keeps its original order.
            source = randomize(dataSet, seed.longValue(), true);
        }
        SplitTestAndTrain parts = source.splitTestAndTrain(fractionTrain);
        return new DataSet[]{parts.getTrain(), parts.getTest()};
    }
}
