/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.cntk.modelapplier;

import adams.core.MessageCollection;
import adams.core.QuickInfoHelper;
import adams.core.QuickInfoSupporter;
import adams.core.Range;
import adams.core.Utils;
import adams.core.base.BaseString;
import adams.core.io.PlaceholderFile;
import adams.core.logging.LoggingLevel;
import adams.core.logging.LoggingSupporter;
import adams.core.option.AbstractOptionHandler;
import adams.core.option.OptionHandler;
import adams.flow.control.StorageName;
import adams.flow.core.AbstractModelLoader;
import adams.flow.core.Actor;
import adams.flow.core.CNTKModelLoader;
import adams.flow.core.CallableActorReference;
import adams.flow.core.ModelLoaderSupporter;
import adams.ml.cntk.CNTKPredictionWrapper;
import adams.ml.cntk.DeviceType;
import com.microsoft.CNTK.Function;

public abstract class AbstractModelApplier<I, O>
extends AbstractOptionHandler
implements ModelLoaderSupporter,
QuickInfoSupporter {
    private static final long serialVersionUID = 7541008225536782803L;
    protected DeviceType m_DeviceType;
    protected long m_GPUDeviceID;
    protected int m_NumClasses;
    protected CNTKPredictionWrapper m_Wrapper;
    protected CNTKModelLoader m_ModelLoader;

    public String automaticOrderInfo() {
        return this.m_ModelLoader.automaticOrderInfo();
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("model-loading-type", "modelLoadingType", (Object)AbstractModelLoader.ModelLoadingType.AUTO);
        this.m_OptionManager.add("model-file", "modelFile", (Object)new PlaceholderFile());
        this.m_OptionManager.add("device-type", "deviceType", (Object)DeviceType.DEFAULT);
        this.m_OptionManager.add("gpu-device-id", "GPUDeviceID", (Object)0L);
        this.m_OptionManager.add("model-actor", "modelActor", (Object)new CallableActorReference());
        this.m_OptionManager.add("model-storage", "modelStorage", (Object)new StorageName());
        this.m_OptionManager.add("input", "inputs", (Object)this.getDefaultInputs());
        this.m_OptionManager.add("input-name", "inputNames", (Object)this.getDefaultInputNames());
        this.m_OptionManager.add("class-name", "className", (Object)this.getDefaultClassName());
        this.m_OptionManager.add("output-name", "outputName", (Object)this.getDefaultOutputName());
        this.m_OptionManager.add("num-classes", "numClasses", (Object)this.getDefaultNumClasses());
    }

    protected void initialize() {
        super.initialize();
        this.m_Wrapper = new CNTKPredictionWrapper();
        this.m_ModelLoader = new CNTKModelLoader();
    }

    protected void reset() {
        super.reset();
        this.m_Wrapper.reset();
        this.m_ModelLoader.reset();
    }

    public synchronized void setLoggingLevel(LoggingLevel value) {
        super.setLoggingLevel(value);
        this.m_Wrapper.setLoggingLevel(value);
        this.m_ModelLoader.setLoggingLevel(value);
    }

    public void setModelLoadingType(AbstractModelLoader.ModelLoadingType value) {
        this.m_ModelLoader.setModelLoadingType(value);
        this.reset();
    }

    public AbstractModelLoader.ModelLoadingType getModelLoadingType() {
        return this.m_ModelLoader.getModelLoadingType();
    }

    public String modelLoadingTypeTipText() {
        return this.m_ModelLoader.modelLoadingTypeTipText();
    }

    public void setModelFile(PlaceholderFile value) {
        this.m_ModelLoader.setModelFile(value);
        this.reset();
    }

    public PlaceholderFile getModelFile() {
        return this.m_ModelLoader.getModelFile();
    }

    public String modelFileTipText() {
        return this.m_ModelLoader.modelFileTipText();
    }

    public void setDeviceType(DeviceType value) {
        this.m_DeviceType = value;
        this.reset();
    }

    public DeviceType getDeviceType() {
        return this.m_DeviceType;
    }

    public String deviceTypeTipText() {
        return "The device type to use.";
    }

    public void setGPUDeviceID(long value) {
        this.m_GPUDeviceID = value;
        this.reset();
    }

    public long getGPUDeviceID() {
        return this.m_GPUDeviceID;
    }

    public String GPUDeviceIDTipText() {
        return "The GPU device ID.";
    }

    public void setModelActor(CallableActorReference value) {
        this.m_ModelLoader.setModelActor(value);
        this.reset();
    }

    public CallableActorReference getModelActor() {
        return this.m_ModelLoader.getModelActor();
    }

    public String modelActorTipText() {
        return this.m_ModelLoader.modelActorTipText();
    }

    public void setModelStorage(StorageName value) {
        this.m_ModelLoader.setModelStorage(value);
        this.reset();
    }

    public StorageName getModelStorage() {
        return this.m_ModelLoader.getModelStorage();
    }

    public String modelStorageTipText() {
        return this.m_ModelLoader.modelStorageTipText();
    }

    protected Range[] getDefaultInputs() {
        return new Range[0];
    }

    public void setInputs(Range[] value) {
        this.m_Wrapper.setInputs(value);
        this.reset();
    }

    public Range[] getInputs() {
        return this.m_Wrapper.getInputs();
    }

    public String inputsTipText() {
        return this.m_Wrapper.inputsTipText();
    }

    protected BaseString[] getDefaultInputNames() {
        return new BaseString[0];
    }

    public void setInputNames(BaseString[] value) {
        this.m_Wrapper.setInputNames(value);
        this.reset();
    }

    public BaseString[] getInputNames() {
        return this.m_Wrapper.getInputNames();
    }

    public String inputNamesTipText() {
        return this.m_Wrapper.inputNamesTipText();
    }

    protected String getDefaultClassName() {
        return "";
    }

    public void setClassName(String value) {
        this.m_Wrapper.setClassName(value);
        this.reset();
    }

    public String getClassName() {
        return this.m_Wrapper.getClassName();
    }

    public String classNameTipText() {
        return this.m_Wrapper.classNameTipText();
    }

    protected String getDefaultOutputName() {
        return "";
    }

    public void setOutputName(String value) {
        this.m_Wrapper.setOutputName(value);
        this.reset();
    }

    public String getOutputName() {
        return this.m_Wrapper.getOutputName();
    }

    public String outputNameTipText() {
        return this.m_Wrapper.outputNameTipText();
    }

    protected int getDefaultNumClasses() {
        return 1;
    }

    public void setNumClasses(int value) {
        this.m_NumClasses = value;
        this.reset();
    }

    public int getNumClasses() {
        return this.m_NumClasses;
    }

    public String numClassesTipText() {
        return "The number of classes (numeric class = 1, otherwise number of class labels).";
    }

    public String getQuickInfo() {
        String result;
        switch (this.getModelLoadingType()) {
            case AUTO: {
                result = "automatic";
                break;
            }
            case FILE: {
                result = QuickInfoHelper.toString((OptionHandler)this, (String)"modelFile", (Object)this.getModelFile(), (String)"file: ");
                break;
            }
            case SOURCE_ACTOR: {
                result = QuickInfoHelper.toString((OptionHandler)this, (String)"modelSource", (Object)this.getModelActor(), (String)"source: ");
                break;
            }
            case STORAGE: {
                result = QuickInfoHelper.toString((OptionHandler)this, (String)"modelStorage", (Object)this.getModelStorage(), (String)"storage: ");
                break;
            }
            default: {
                throw new IllegalStateException("Unhandled location type: " + this.getModelLoadingType());
            }
        }
        return result;
    }

    public void setFlowContext(Actor value) {
        this.m_ModelLoader.setFlowContext(value);
    }

    public Actor getFlowContext() {
        return this.m_ModelLoader.getFlowContext();
    }

    public abstract Class accepts();

    public abstract Class generates();

    protected String initModel() {
        this.m_Wrapper.initDevice(this.m_DeviceType, this.m_GPUDeviceID);
        MessageCollection errors = new MessageCollection();
        this.m_ModelLoader.setDevice(this.m_Wrapper.getDevice());
        Function model = (Function)this.m_ModelLoader.getModel(errors);
        if (model == null) {
            return errors.toString();
        }
        this.m_Wrapper.setModel(model);
        try {
            this.m_Wrapper.initModel(this.getNumClasses());
        }
        catch (Exception e) {
            return Utils.handleException((LoggingSupporter)this, (String)"Failed to initialize model!", (Throwable)e);
        }
        return null;
    }

    protected String check(I input) {
        String result;
        if (this.getFlowContext() == null) {
            return "No flow context set!";
        }
        if (!this.m_Wrapper.isInitialized() && (result = this.initModel()) != null) {
            return result;
        }
        return null;
    }

    protected abstract O doApplyModel(I var1);

    public O applyModel(I input) {
        String msg = this.check(input);
        if (msg != null) {
            throw new IllegalStateException("Failed check: " + msg);
        }
        return this.doApplyModel(input);
    }
}

