/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for persisting and restoring a {@link MultiLayerNetwork}: its JSON
 * configuration, its flattened parameters, its updater, or the parameters of individual layers.
 *
 * <p>All methods are best-effort: {@link IOException}s (and, for updaters,
 * {@link ClassNotFoundException}s) are logged and swallowed rather than propagated. Load methods
 * therefore return {@code null} (or leave the target unchanged) when reading fails.
 */
public class NetSaverLoaderUtils {
    private static final Logger log = LoggerFactory.getLogger(NetSaverLoaderUtils.class);

    private NetSaverLoaderUtils() {
    }

    /**
     * Saves the network's configuration (as UTF-8 JSON) and its flattened parameters (in ND4J
     * binary form) into {@code basePath}.
     *
     * <p>The files are named {@code <net.toString()>-conf.json} and {@code <net.toString()>.bin}.
     *
     * @param net      the network to save
     * @param basePath the directory the two files are written into
     */
    public static void saveNetworkAndParameters(MultiLayerNetwork net, String basePath) {
        String confPath = FilenameUtils.concat(basePath, net.toString() + "-conf.json");
        String paramPath = FilenameUtils.concat(basePath, net.toString() + ".bin");
        log.info("Saving model and parameters to {} and {} ...", confPath, paramPath);
        try (DataOutputStream dos =
                new DataOutputStream(new BufferedOutputStream(new FileOutputStream(paramPath)))) {
            Nd4j.write(net.params(), dos);
            dos.flush();
            FileUtils.write(new File(confPath), net.getLayerWiseConfigurations().toJson(),
                    StandardCharsets.UTF_8);
        } catch (IOException e) {
            log.error("Failed to save model and parameters to {} and {}", confPath, paramPath, e);
        }
    }

    /**
     * Rebuilds a network from a JSON configuration file and a parameter file written by
     * {@link #saveNetworkAndParameters(MultiLayerNetwork, String)}.
     *
     * @param confPath  path of the JSON configuration file
     * @param paramPath path of the ND4J binary parameter file
     * @return the initialised network with the saved parameters applied, or {@code null} if
     *         either file could not be read
     */
    public static MultiLayerNetwork loadNetworkAndParameters(String confPath, String paramPath) {
        log.info("Loading saved model and parameters...");
        try {
            // fromJson parses JSON text, not a file path, so the file contents must be read first.
            String json = FileUtils.readFileToString(new File(confPath), StandardCharsets.UTF_8);
            MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
            INDArray newParams;
            try (DataInputStream dis =
                    new DataInputStream(new BufferedInputStream(new FileInputStream(paramPath)))) {
                newParams = Nd4j.read(dis);
            }
            MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson);
            savedNetwork.init();
            savedNetwork.setParams(newParams);
            return savedNetwork;
        } catch (IOException e) {
            log.error("Failed to load model from {} and parameters from {}", confPath, paramPath, e);
            return null;
        }
    }

    /**
     * Serialises the network's updater with Java serialisation to
     * {@code <basePath>/<net.toString()>updators.bin}.
     *
     * @param net      the network whose updater is saved
     * @param basePath the directory the file is written into
     */
    public static void saveUpdators(MultiLayerNetwork net, String basePath) {
        String paramPath = FilenameUtils.concat(basePath, net.toString() + "updators.bin");
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(paramPath)))) {
            oos.writeObject(net.getUpdater());
        } catch (IOException e) {
            log.error("Failed to save updater to {}", paramPath, e);
        }
    }

    /**
     * Reads an updater written by {@link #saveUpdators(MultiLayerNetwork, String)}.
     *
     * <p><b>Security:</b> this uses Java native deserialisation. Only load files from trusted
     * sources.
     *
     * @param updatorPath path of the serialised updater file
     * @return the deserialised updater, or {@code null} if it could not be read
     * @throws ClassCastException if the file holds an object that is not an {@link Updater}
     */
    public static Updater loadUpdators(String updatorPath) {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File(updatorPath)))) {
            return (Updater) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            log.error("Failed to load updater from {}", updatorPath, e);
            return null;
        }
    }

    /**
     * Writes a single parameter array to {@code paramPath} in ND4J binary form.
     *
     * @param param     the parameters to write
     * @param paramPath destination file path
     */
    public static void saveLayerParameters(INDArray param, String paramPath) {
        log.info("Saving parameters to {} ...", paramPath);
        try (DataOutputStream dos =
                new DataOutputStream(new BufferedOutputStream(new FileOutputStream(paramPath)))) {
            Nd4j.write(param, dos);
            dos.flush();
        } catch (IOException e) {
            log.error("Failed to save parameters to {}", paramPath, e);
        }
    }

    /**
     * Reads an ND4J binary parameter file and sets it as the parameters of {@code layer}.
     *
     * @param layer     the layer to update in place
     * @param paramPath path of the parameter file
     * @return the same {@code layer} instance; its parameters are unchanged if reading failed
     */
    public static Layer loadLayerParameters(Layer layer, String paramPath) {
        String name = layer.conf().getLayer().getLayerName();
        log.info("Loading saved parameters for layer {} ...", name);
        try (DataInputStream dis =
                new DataInputStream(new BufferedInputStream(new FileInputStream(paramPath)))) {
            INDArray param = Nd4j.read(dis);
            layer.setParams(param);
        } catch (IOException e) {
            log.error("Failed to load parameters for layer {} from {}", name, paramPath, e);
        }
        return layer;
    }

    /**
     * Saves the parameters of each layer whose index is listed in {@code layerIds}. Layers
     * without parameters are skipped.
     *
     * @param net        the source network
     * @param layerIds   indices of the layers to save
     * @param paramPaths destination file path for each layer index
     */
    public static void saveParameters(MultiLayerNetwork net, int[] layerIds, Map<Integer, String> paramPaths) {
        for (int layerId : layerIds) {
            Layer layer = net.getLayer(layerId);
            if (layer.paramTable().isEmpty()) {
                continue;
            }
            saveLayerParameters(layer.params(), paramPaths.get(layerId));
        }
    }

    /**
     * Saves the parameters of each layer whose name is listed in {@code layerIds}. Layers
     * without parameters are skipped.
     *
     * @param net        the source network
     * @param layerIds   names of the layers to save
     * @param paramPaths destination file path for each layer name
     */
    public static void saveParameters(MultiLayerNetwork net, String[] layerIds, Map<String, String> paramPaths) {
        for (String layerId : layerIds) {
            Layer layer = net.getLayer(layerId);
            if (layer.paramTable().isEmpty()) {
                continue;
            }
            saveLayerParameters(layer.params(), paramPaths.get(layerId));
        }
    }

    /**
     * Loads parameters into each layer whose index is listed in {@code layerIds}.
     *
     * <p>Layers without parameters are skipped, mirroring
     * {@link #saveParameters(MultiLayerNetwork, int[], Map)}, which writes no file for them.
     *
     * @param net        the network to update in place
     * @param layerIds   indices of the layers to load
     * @param paramPaths source file path for each layer index
     * @return {@code net}
     */
    public static MultiLayerNetwork loadParameters(MultiLayerNetwork net, int[] layerIds, Map<Integer, String> paramPaths) {
        for (int layerId : layerIds) {
            Layer layer = net.getLayer(layerId);
            if (layer.paramTable().isEmpty()) {
                continue;
            }
            loadLayerParameters(layer, paramPaths.get(layerId));
        }
        return net;
    }

    /**
     * Loads parameters into each layer whose name is listed in {@code layerIds}.
     *
     * <p>Layers without parameters are skipped, mirroring
     * {@link #saveParameters(MultiLayerNetwork, String[], Map)}, which writes no file for them.
     *
     * @param net        the network to update in place
     * @param layerIds   names of the layers to load
     * @param paramPaths source file path for each layer name
     * @return {@code net}
     */
    public static MultiLayerNetwork loadParameters(MultiLayerNetwork net, String[] layerIds, Map<String, String> paramPaths) {
        for (String layerId : layerIds) {
            Layer layer = net.getLayer(layerId);
            if (layer.paramTable().isEmpty()) {
                continue;
            }
            loadLayerParameters(layer, paramPaths.get(layerId));
        }
        return net;
    }

    /**
     * Builds a {@code layerIndex -> <basePath>/<layerIndex>.bin} map.
     *
     * @param basePath directory the files live in
     * @param layerIds layer indices
     * @return a new mutable map from each index to its parameter file path
     */
    public static Map<Integer, String> getIdParamPaths(String basePath, int[] layerIds) {
        Map<Integer, String> paramPaths = new HashMap<Integer, String>();
        for (int id : layerIds) {
            paramPaths.put(id, FilenameUtils.concat(basePath, id + ".bin"));
        }
        return paramPaths;
    }

    /**
     * Builds a {@code layerName -> <basePath>/<layerName>.bin} map.
     *
     * @param basePath directory the files live in
     * @param layerIds layer names
     * @return a new mutable map from each name to its parameter file path
     */
    public static Map<String, String> getStringParamPaths(String basePath, String[] layerIds) {
        Map<String, String> paramPaths = new HashMap<String, String>();
        for (String name : layerIds) {
            paramPaths.put(name, FilenameUtils.concat(basePath, name + ".bin"));
        }
        return paramPaths;
    }

    /**
     * Returns {@code <java.io.tmpdir>/<networkType>/output}, creating the directory if needed.
     *
     * @param networkType name of the subdirectory under the system temp directory
     * @return the output directory path; a warning is logged if it could not be created
     */
    public static String defineOutputDir(String networkType) {
        String tmpDir = System.getProperty("java.io.tmpdir");
        String outputPath = File.separator + networkType + File.separator + "output";
        File dataDir = new File(tmpDir, outputPath);
        // Test the target directory itself (not its parent) so it is still created when
        // <tmp>/<networkType> already exists.
        if (!dataDir.exists() && !dataDir.mkdirs()) {
            log.warn("Could not create output directory {}", dataDir);
        }
        return dataDir.toString();
    }
}

