package adams.ml.dl4j.model;

import adams.data.InPlaceProcessing;
import adams.flow.control.StorageName;
import adams.flow.control.StorageUser;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

/* loaded from: input_file:adams/ml/dl4j/model/FromStorage.class */
/**
 * Model configurator that retrieves an already existing DL4J model from the
 * flow's storage (or from a named storage cache) and forwards it, either as
 * the stored instance itself or as a serialized/deserialized copy.
 */
public class FromStorage extends AbstractModelConfigurator implements StorageUser, InPlaceProcessing {
    private static final long serialVersionUID = -5856765502127602083L;

    /** The storage cache to read from; null/empty means the regular storage. */
    protected String m_Cache;

    /** The name under which the model is stored. */
    protected StorageName m_StorageName;

    /** Whether to forward the stored instance itself instead of a copy. */
    protected boolean m_NoCopy;

    public String globalInfo() {
        return "Retrieves the model simply from storage and forwards it.";
    }

    /**
     * Registers the "cache", "storage-name" and "no-copy" options on top of
     * the options defined by the superclass.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("cache", "cache", "");
        this.m_OptionManager.add("storage-name", "storageName", new StorageName());
        this.m_OptionManager.add("no-copy", "noCopy", false);
    }

    /**
     * Sets the name of the storage cache to retrieve the model from.
     *
     * @param str the cache name; empty (or null) uses the regular storage
     */
    public void setCache(String str) {
        this.m_Cache = str;
        reset();
    }

    /**
     * Returns the name of the storage cache to retrieve the model from.
     *
     * @return the cache name; empty means the regular storage
     */
    public String getCache() {
        return this.m_Cache;
    }

    public String cacheTipText() {
        return "The name of the cache to retrieve the value from; uses the regular storage if left empty.";
    }

    /**
     * Sets the name of the stored value that holds the model.
     *
     * @param storageName the storage name
     */
    public void setStorageName(StorageName storageName) {
        this.m_StorageName = storageName;
        reset();
    }

    /**
     * Returns the name of the stored value that holds the model.
     *
     * @return the storage name
     */
    public StorageName getStorageName() {
        return this.m_StorageName;
    }

    public String storageNameTipText() {
        return "The name of the stored value to retrieve.";
    }

    /**
     * Sets whether to skip copying the model before returning it.
     *
     * @param z true to forward the stored instance itself
     */
    public void setNoCopy(boolean z) {
        this.m_NoCopy = z;
        reset();
    }

    /**
     * Returns whether copying the model before returning it is skipped.
     *
     * @return true if the stored instance itself is forwarded
     */
    public boolean getNoCopy() {
        return this.m_NoCopy;
    }

    public String noCopyTipText() {
        return "If enabled, no copy of the model is created before returning it.";
    }

    /**
     * This configurator always reads from storage.
     *
     * @return always true
     */
    @Override
    public boolean isUsingStorage() {
        return true;
    }

    /**
     * Checks the setup, additionally ensuring that the model is present in
     * the same location (regular storage or the configured cache) that
     * {@link #doConfigureModel(int, int)} will read from.
     *
     * @return null if OK, otherwise an error message
     */
    @Override // adams.ml.dl4j.model.AbstractModelConfigurator
    public String check() {
        String result = super.check();
        if (result == null && !isModelStored()) {
            result = "Model not available from storage: " + describeLocation();
        }
        return result;
    }

    /**
     * Retrieves the model from storage and, unless no-copy is enabled,
     * returns a deep copy obtained by serializing and deserializing it.
     *
     * @param i  not used by this configurator
     * @param i2 not used by this configurator
     * @return the stored model or a copy of it
     * @throws IllegalStateException if the model is missing, is not a
     *         {@link Model}, is of an unsupported type, or cannot be copied
     */
    @Override // adams.ml.dl4j.model.AbstractModelConfigurator
    protected Model doConfigureModel(int i, int i2) {
        Model model = retrieveModel();
        if (this.m_NoCopy) {
            return model;
        }
        return copyModel(model);
    }

    /** Returns true if a non-empty cache name has been configured. */
    private boolean usesStorageCache() {
        return (this.m_Cache != null) && !this.m_Cache.isEmpty();
    }

    /** Returns true if the model is present in the configured storage location. */
    private boolean isModelStored() {
        if (usesStorageCache()) {
            return this.m_FlowContext.getStorageHandler().getStorage().has(this.m_Cache, this.m_StorageName);
        }
        return this.m_FlowContext.getStorageHandler().getStorage().has(this.m_StorageName);
    }

    /** Describes the storage location for error messages. */
    private String describeLocation() {
        if (usesStorageCache()) {
            return this.m_StorageName + " (cache: " + this.m_Cache + ")";
        }
        return String.valueOf(this.m_StorageName);
    }

    /** Fetches the stored object and verifies that it is a non-null Model. */
    private Model retrieveModel() {
        Object stored;
        if (usesStorageCache()) {
            stored = this.m_FlowContext.getStorageHandler().getStorage().get(this.m_Cache, this.m_StorageName);
        } else {
            stored = this.m_FlowContext.getStorageHandler().getStorage().get(this.m_StorageName);
        }
        if (stored == null) {
            throw new IllegalStateException("Model not available from storage: " + describeLocation());
        }
        if (!(stored instanceof Model)) {
            throw new IllegalStateException(
                "Stored object is not a model (" + describeLocation() + "): " + stored.getClass().getName());
        }
        return (Model) stored;
    }

    /** Deep-copies a MultiLayerNetwork or ComputationGraph via in-memory serialization. */
    private Model copyModel(Model model) {
        boolean isNetwork = model instanceof MultiLayerNetwork;
        // Validate the type up front, before any serialization work is attempted.
        if (!isNetwork && !(model instanceof ComputationGraph)) {
            throw new IllegalStateException("Unhandled model type: " + model.getClass().getName());
        }
        try {
            // NOTE: this initializes the *stored* instance in place; presumably
            // required so that parameters exist before serialization — confirm.
            if (isNetwork && !((MultiLayerNetwork) model).isInitCalled()) {
                ((MultiLayerNetwork) model).init();
            }
            ByteArrayOutputStream serialized = new ByteArrayOutputStream();
            ModelSerializer.writeModel(model, serialized, true);
            ByteArrayInputStream source = new ByteArrayInputStream(serialized.toByteArray());
            if (isNetwork) {
                return ModelSerializer.restoreMultiLayerNetwork(source);
            }
            return ModelSerializer.restoreComputationGraph(source);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create copy of model!", e);
        }
    }
}
