/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import adams.core.Utils;
import adams.core.base.BaseHostname;
import adams.data.wekapyroproxy.AbstractCommunicationProcessor;
import adams.data.wekapyroproxy.NullCommunicationProcessor;
import adams.env.Environment;
import adams.flow.core.Actor;
import adams.flow.core.ActorUtils;
import adams.flow.core.FlowContextHandler;
import adams.flow.standalone.PyroNameServer;
import net.razorvine.pyro.Config;
import net.razorvine.pyro.NameServerProxy;
import weka.classifiers.Classifier;
import weka.classifiers.simple.AbstractSimpleClassifier;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.PyroProxyObject;

public class PyroProxy
extends AbstractSimpleClassifier
implements PyroProxyObject,
FlowContextHandler,
BatchPredictor {
    private static final long serialVersionUID = -4578812400878994526L;
    protected BaseHostname m_NameServer;
    protected String m_RemoteObjectName;
    protected String m_MethodNameTrain;
    protected String m_MethodNamePrediction;
    protected String m_ModelName;
    protected AbstractCommunicationProcessor m_Communication;
    protected boolean m_PerformTraining;
    protected int m_BatchSize;
    protected transient Actor m_FlowContext;
    protected transient PyroNameServer m_NameServerActor;
    protected transient NameServerProxy m_NameServerProxy;
    protected transient net.razorvine.pyro.PyroProxy m_RemoteObject;

    public String globalInfo() {
        return "Proxy for a Python model using Pyro4 for communication.\n\nIf a flow context is set and a " + Utils.classToString(PyroNameServer.class) + " can provide a Pyro NameServerProxy instance, then this will override the namerserver settings defined by the classifier.\nFor more information see on Pyro:\nhttps://github.com/irmen/Pyro4";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("name-server", "nameServer", (Object)new BaseHostname("localhost:9090"));
        this.m_OptionManager.add("remote-object-name", "remoteObjectName", (Object)"");
        this.m_OptionManager.add("perform-training", "performTraining", (Object)false);
        this.m_OptionManager.add("method-name-train", "methodNameTrain", (Object)"");
        this.m_OptionManager.add("method-name-prediction", "methodNamePrediction", (Object)"");
        this.m_OptionManager.add("model-name", "modelName", (Object)"");
        this.m_OptionManager.add("communication", "communication", (Object)new NullCommunicationProcessor());
        this.m_OptionManager.add("batch-size", "batchSize", (Object)"1");
    }

    public void setFlowContext(Actor value) {
        this.m_FlowContext = value;
    }

    public Actor getFlowContext() {
        return this.m_FlowContext;
    }

    @Override
    public void setNameServer(BaseHostname value) {
        this.m_NameServer = value;
        this.reset();
    }

    @Override
    public BaseHostname getNameServer() {
        return this.m_NameServer;
    }

    @Override
    public String nameServerTipText() {
        return "The address of the Pyro nameserver.";
    }

    @Override
    public void setRemoteObjectName(String value) {
        this.m_RemoteObjectName = value;
        this.reset();
    }

    @Override
    public String getRemoteObjectName() {
        return this.m_RemoteObjectName;
    }

    @Override
    public String remoteObjectNameTipText() {
        return "The name of the remote object to use.";
    }

    public void setPerformTraining(boolean value) {
        this.m_PerformTraining = value;
        this.reset();
    }

    public boolean getPerformTraining() {
        return this.m_PerformTraining;
    }

    public String performTrainingTipText() {
        return "If enabled, then training is performed.";
    }

    public void setMethodNameTrain(String value) {
        this.m_MethodNameTrain = value;
        this.reset();
    }

    public String getMethodNameTrain() {
        return this.m_MethodNameTrain;
    }

    public String methodNameTrainTipText() {
        return "The name of the method to call for training.";
    }

    public void setMethodNamePrediction(String value) {
        this.m_MethodNamePrediction = value;
        this.reset();
    }

    public String getMethodNamePrediction() {
        return this.m_MethodNamePrediction;
    }

    public String methodNamePredictionTipText() {
        return "The name of the method to call for predictions.";
    }

    public void setModelName(String value) {
        this.m_ModelName = value;
        this.reset();
    }

    public String getModelName() {
        return this.m_ModelName;
    }

    public String modelNameTipText() {
        return "The name of the model to use.";
    }

    @Override
    public void setCommunication(AbstractCommunicationProcessor value) {
        this.m_Communication = value;
        this.reset();
    }

    @Override
    public AbstractCommunicationProcessor getCommunication() {
        return this.m_Communication;
    }

    @Override
    public String communicationTipText() {
        return "Handles the communication with the remote model.";
    }

    public void setBatchSize(String value) {
        int intValue = Integer.parseInt(value);
        if (this.getOptionManager().isValid("batchSize", (Number)intValue)) {
            this.m_BatchSize = intValue;
            this.reset();
        }
    }

    public String getBatchSize() {
        return "" + this.m_BatchSize;
    }

    public String batchSizeTipText() {
        return "The batch size to use for generating multiple predictions (if possible).";
    }

    public boolean implementsMoreEfficientBatchPrediction() {
        return this.m_Communication.supportsBatchPredictions();
    }

    public Capabilities getCapabilities() {
        Capabilities result = new Capabilities((CapabilitiesHandler)this);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.STRING_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    public void buildClassifier(Instances data) throws Exception {
        long end;
        long start;
        this.getCapabilities().testWithFail(data);
        if (this.m_RemoteObjectName.trim().isEmpty()) {
            throw new IllegalStateException("Remote object name is empty!");
        }
        if (this.m_PerformTraining && this.m_MethodNameTrain.trim().isEmpty()) {
            throw new IllegalStateException("Method name (train) is empty!");
        }
        if (this.m_MethodNamePrediction.trim().isEmpty()) {
            throw new IllegalStateException("Method name (prediction) is empty!");
        }
        if (this.m_ModelName.trim().isEmpty()) {
            throw new IllegalStateException("Model name is empty!");
        }
        this.m_Communication.initialize(this, data);
        this.m_NameServerProxy = null;
        this.m_NameServerActor = null;
        if (this.m_FlowContext != null) {
            if (this.isLoggingEnabled()) {
                this.getLogger().info("Using flow context (" + this.m_FlowContext.getFullName() + ") to determine nameserver...");
            }
            this.m_NameServerActor = (PyroNameServer)ActorUtils.findClosestType((Actor)this.m_FlowContext, PyroNameServer.class, (boolean)true);
            if (this.m_NameServerActor != null) {
                this.m_NameServerProxy = this.m_NameServerActor.getNameServer();
            }
            if (this.isLoggingEnabled()) {
                this.getLogger().info("Determined nameserver through flow context: " + (this.m_NameServerProxy != null));
            }
        }
        if (this.m_NameServerProxy == null) {
            try {
                if (this.isLoggingEnabled()) {
                    this.getLogger().info("Connecting to: " + this.m_NameServer);
                }
                start = System.currentTimeMillis();
                this.m_NameServerProxy = NameServerProxy.locateNS((String)this.m_NameServer.hostnameValue(), (int)this.m_NameServer.portValue(Config.NS_PORT), null);
                end = System.currentTimeMillis();
                if (this.isLoggingEnabled()) {
                    this.getLogger().info("duration/nameserver: " + (double)(end - start) / 1000.0);
                }
            }
            catch (Exception e) {
                throw new Exception("Failed to connect to Pyro nameserver: " + this.m_NameServer, e);
            }
        }
        try {
            if (this.isLoggingEnabled()) {
                this.getLogger().info("Obtaining remote object: " + this.m_RemoteObjectName);
            }
            start = System.currentTimeMillis();
            this.m_RemoteObject = new net.razorvine.pyro.PyroProxy(this.m_NameServerProxy.lookup(this.m_RemoteObjectName));
            end = System.currentTimeMillis();
            if (this.isLoggingEnabled()) {
                this.getLogger().info("duration/remoteobject: " + (double)(end - start) / 1000.0);
            }
        }
        catch (Exception e) {
            throw new Exception("Failed to obtain remote object: " + this.m_RemoteObjectName, e);
        }
        if (this.m_PerformTraining) {
            Object train = this.m_Communication.convertDataset(this, data);
            start = System.currentTimeMillis();
            this.m_RemoteObject.call(this.m_MethodNameTrain, new Object[]{train});
            end = System.currentTimeMillis();
            if (this.isLoggingEnabled()) {
                this.getLogger().info("duration/buildClassifier: " + (double)(end - start) / 1000.0);
            }
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_RemoteObject == null) {
            throw new IllegalStateException("No remote object available for remote calls!");
        }
        Object data = this.m_Communication.convertInstance(this, instance);
        long start = System.currentTimeMillis();
        Object prediction = this.m_RemoteObject.call(this.m_MethodNamePrediction, new Object[]{data});
        long end = System.currentTimeMillis();
        if (this.isLoggingEnabled()) {
            this.getLogger().info("duration/distributionForInstance: " + (double)(end - start) / 1000.0);
        }
        try {
            return this.m_Communication.parsePrediction(this, prediction);
        }
        catch (Exception e) {
            throw new Exception("Failed to process prediction:\n" + prediction, e);
        }
    }

    public double[][] distributionsForInstances(Instances insts) throws Exception {
        if (this.m_RemoteObject == null) {
            throw new IllegalStateException("No remote object available for remote calls!");
        }
        Object data = this.m_Communication.convertDataset(this, insts);
        long start = System.currentTimeMillis();
        Object predictions = this.m_RemoteObject.call(this.m_MethodNamePrediction, new Object[]{data});
        long end = System.currentTimeMillis();
        if (this.isLoggingEnabled()) {
            this.getLogger().info("duration/distributionForInstance: " + (double)(end - start) / 1000.0);
        }
        try {
            return this.m_Communication.parsePredictions(this, predictions);
        }
        catch (Exception e) {
            throw new Exception("Failed to process predictions:\n" + predictions, e);
        }
    }

    public String toString() {
        return "Flow context: " + (this.m_FlowContext == null ? "-none-" : this.m_FlowContext.getFullName()) + "\nNameserver: " + this.m_NameServer + "\nRemote object name: " + this.m_RemoteObjectName + "\nPerform training: " + this.m_PerformTraining + "\nMethod name (train): " + this.m_MethodNameTrain + "\nMethod name (prediction): " + this.m_MethodNamePrediction + "\nModel name: " + this.m_ModelName + "\nConnected: " + (this.m_RemoteObject != null);
    }

    public static void main(String[] args) throws Exception {
        Environment.setEnvironmentClass(Environment.class);
        PyroProxy.runClassifier((Classifier)new PyroProxy(), (String[])args);
    }
}

