package adams.ml.dl4j;

import adams.core.ClassLocator;
import adams.core.SerializationHelper;
import java.io.File;
import java.lang.reflect.Constructor;
import java.util.Objects;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:adams/ml/dl4j/ModelSerialization.class */
/**
 * Helper methods for writing DL4J {@link Model} instances to disk and restoring them.
 * <p>
 * A model is stored via {@link SerializationHelper#writeAll} as an array of exactly three
 * objects:
 * <ol>
 *   <li>the fully qualified class name of the model ({@code String})</li>
 *   <li>the JSON of the model's configuration ({@code String}): the
 *       {@link MultiLayerConfiguration} for a {@link MultiLayerNetwork} (or subclass), the
 *       {@link NeuralNetConfiguration} otherwise</li>
 *   <li>the model's parameters ({@link INDArray})</li>
 * </ol>
 * <p>
 * NOTE(review): {@link #read(File)} resolves and instantiates a class whose name is taken from
 * the file; only read files from trusted sources.
 */
public class ModelSerialization {

    /** Number of objects a serialized model file must contain. */
    private static final int NUM_OBJECTS = 3;

    /**
     * Writes the model (class name, configuration JSON, parameters) to the given file.
     *
     * @param file  the file to write to
     * @param model the model to serialize
     * @throws NullPointerException if {@code file} or {@code model} is null
     * @throws Exception            if serialization fails
     */
    public static void write(File file, Model model) throws Exception {
        Objects.requireNonNull(file, "file");
        Objects.requireNonNull(model, "model");
        String confJson;
        if (model instanceof MultiLayerNetwork) {
            // store the layer-wise (whole network) configuration, which is what read()
            // uses to rebuild a MultiLayerNetwork
            confJson = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
        } else {
            confJson = model.conf().toJson();
        }
        SerializationHelper.writeAll(file.getAbsolutePath(), new Object[]{model.getClass().getName(), confJson, model.params()});
    }

    /**
     * Reads a model previously written with {@link #write(File, Model)}.
     * <p>
     * {@link MultiLayerNetwork} and its subclasses are rebuilt from the stored
     * {@link MultiLayerConfiguration}, initialized and given the stored parameters. Subclasses
     * must offer a public constructor taking a {@link MultiLayerConfiguration}. Subclasses of
     * {@link BaseOutputLayer} are rebuilt via a public constructor taking a
     * {@link NeuralNetConfiguration}, then given the stored parameters.
     *
     * @param file the file to read from
     * @return the restored model
     * @throws NullPointerException  if {@code file} is null
     * @throws IllegalStateException if the file content is malformed, the stored class is not
     *                               supported, or it lacks the required constructor
     * @throws Exception             if deserialization or instantiation fails
     */
    public static Model read(File file) throws Exception {
        Objects.requireNonNull(file, "file");
        Object[] readAll = SerializationHelper.readAll(file.getAbsolutePath());
        if (readAll.length != NUM_OBJECTS) {
            throw new IllegalStateException("Unexpected number of objects in serialized file (instead of 3): " + readAll.length);
        }
        String className = expect(readAll, 0, String.class, false);
        String confJson = expect(readAll, 1, String.class, false);
        INDArray params = expect(readAll, 2, INDArray.class, true);

        // Load without initializing: the static initializers of a class named in the file must
        // not run before the class has been vetted below. This uses the same loader as
        // Class.forName(String) would.
        Class<?> cls = Class.forName(className, false, ModelSerialization.class.getClassLoader());

        // Mirrors write(), which uses instanceof: subclasses are accepted too.
        if (MultiLayerNetwork.class.isAssignableFrom(cls)) {
            MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(confJson);
            MultiLayerNetwork network;
            if (cls == MultiLayerNetwork.class) {
                network = new MultiLayerNetwork(conf);
            } else {
                network = (MultiLayerNetwork) instantiate(cls, MultiLayerConfiguration.class, conf);
            }
            network.init();
            network.setParams(params);
            return network;
        }
        if (!ClassLocator.isSubclass(BaseOutputLayer.class, cls)) {
            throw new IllegalStateException("Don't know how to re-instantiate: " + cls.getName());
        }
        Model model = (Model) instantiate(cls, NeuralNetConfiguration.class, NeuralNetConfiguration.fromJson(confJson));
        model.setParams(params);
        return model;
    }

    /**
     * Returns the element at {@code index}, checked against the expected type.
     *
     * @param objects  the deserialized objects
     * @param index    the position to read
     * @param type     the required type
     * @param nullable whether a null element is acceptable
     * @return the element, cast to {@code type}
     * @throws IllegalStateException if the element is null (and not allowed) or of the wrong type
     */
    private static <T> T expect(Object[] objects, int index, Class<T> type, boolean nullable) {
        Object obj = objects[index];
        if (obj == null) {
            if (nullable) {
                return null;
            }
            throw new IllegalStateException("Expected " + type.getName() + " at position " + index + " in serialized file, but found: null");
        }
        if (!type.isInstance(obj)) {
            throw new IllegalStateException("Expected " + type.getName() + " at position " + index + " in serialized file, but found: " + obj.getClass().getName());
        }
        return type.cast(obj);
    }

    /**
     * Instantiates {@code cls} through its public constructor that takes a single
     * {@code paramType} argument.
     *
     * @param cls       the class to instantiate
     * @param paramType the constructor's parameter type
     * @param arg       the argument to pass
     * @return the new instance
     * @throws IllegalStateException if no such public constructor exists
     * @throws Exception             if the constructor cannot be invoked or throws
     */
    private static Object instantiate(Class<?> cls, Class<?> paramType, Object arg) throws Exception {
        Constructor<?> constructor;
        try {
            constructor = cls.getConstructor(paramType);
        } catch (NoSuchMethodException e) {
            // getConstructor never returns null; absence is signalled by this exception
            throw new IllegalStateException("Failed to find constructor in class " + cls.getName() + " that takes a " + paramType.getName() + " object!", e);
        }
        return constructor.newInstance(arg);
    }
}
