package weka.classifiers.functions;

import adams.core.Utils;
import adams.core.base.BaseHostname;
import adams.data.wekapyroproxy.AbstractCommunicationProcessor;
import adams.data.wekapyroproxy.NullCommunicationProcessor;
import adams.env.Environment;
import adams.flow.core.Actor;
import adams.flow.core.ActorUtils;
import adams.flow.core.FlowContextHandler;
import adams.flow.standalone.PyroNameServer;
import net.razorvine.pyro.Config;
import net.razorvine.pyro.NameServerProxy;
import weka.classifiers.simple.AbstractSimpleClassifier;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.PyroProxyObject;

/* loaded from: input_file:weka/classifiers/functions/PyroProxy.class */
/**
 * Classifier that delegates training and prediction to a Python model exposed as a
 * Pyro4 remote object.
 * <p>
 * The nameserver is taken from the closest {@link PyroNameServer} standalone of the
 * flow context (if a context is set and that actor provides a proxy); otherwise the
 * configured nameserver address is contacted. Conversion between Weka data structures
 * and the objects sent to/received from the remote side is delegated to the configured
 * {@link AbstractCommunicationProcessor}.
 */
public class PyroProxy extends AbstractSimpleClassifier implements PyroProxyObject, FlowContextHandler, BatchPredictor {

    private static final long serialVersionUID = -4578812400878994526L;

    /** The address of the Pyro nameserver (host, optionally with port). */
    protected BaseHostname m_NameServer;

    /** The name under which the remote object is registered with the nameserver. */
    protected String m_RemoteObjectName;

    /** The remote method to call for training (only used if training is enabled). */
    protected String m_MethodNameTrain;

    /** The remote method to call for making predictions. */
    protected String m_MethodNamePrediction;

    /** The name of the model to use. */
    protected String m_ModelName;

    /** Converts data to/from the representation exchanged with the remote object. */
    protected AbstractCommunicationProcessor m_Communication;

    /** Whether to call the remote training method in {@link #buildClassifier(Instances)}. */
    protected boolean m_PerformTraining;

    /** The batch size reported via {@link #getBatchSize()} (not used by this class itself). */
    protected int m_BatchSize;

    /** The optional flow context used for locating a {@link PyroNameServer} actor. */
    protected transient Actor m_FlowContext;

    /** The nameserver actor found in the flow context, if any. */
    protected transient PyroNameServer m_NameServerActor;

    /** The nameserver proxy used for looking up the remote object. */
    protected transient NameServerProxy m_NameServerProxy;

    /** The proxy for the remote object, null if not (successfully) built. */
    protected transient net.razorvine.pyro.PyroProxy m_RemoteObject;

    /**
     * Whether {@link #m_NameServerProxy} was located by this classifier itself (and
     * therefore must be closed by it), as opposed to being owned by a flow actor.
     */
    private transient boolean m_OwnsNameServerProxy;

    /**
     * Returns a description of the classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Proxy for a Python model using Pyro4 for communication.\n\n"
            + "If a flow context is set and a " + Utils.classToString(PyroNameServer.class)
            + " can provide a Pyro NameServerProxy instance, then this will override the nameserver settings defined by the classifier.\n"
            + "For more information on Pyro see:\n"
            + "https://github.com/irmen/Pyro4";
    }

    /**
     * Adds the options of this classifier to the option manager.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("name-server", "nameServer", new BaseHostname("localhost:9090"));
        this.m_OptionManager.add("remote-object-name", "remoteObjectName", "");
        this.m_OptionManager.add("perform-training", "performTraining", false);
        this.m_OptionManager.add("method-name-train", "methodNameTrain", "");
        this.m_OptionManager.add("method-name-prediction", "methodNamePrediction", "");
        this.m_OptionManager.add("model-name", "modelName", "");
        this.m_OptionManager.add("communication", "communication", new NullCommunicationProcessor());
        this.m_OptionManager.add("batch-size", "batchSize", "1");
    }

    /**
     * Sets the flow context.
     *
     * @param actor the actor to use as context, may be null
     */
    @Override
    public void setFlowContext(Actor actor) {
        this.m_FlowContext = actor;
    }

    /**
     * Returns the flow context.
     *
     * @return the context actor, null if none set
     */
    @Override
    public Actor getFlowContext() {
        return this.m_FlowContext;
    }

    /**
     * Sets the address of the Pyro nameserver.
     *
     * @param baseHostname the address
     */
    @Override // weka.core.PyroProxyObject
    public void setNameServer(BaseHostname baseHostname) {
        this.m_NameServer = baseHostname;
        reset();
    }

    /**
     * Returns the address of the Pyro nameserver.
     *
     * @return the address
     */
    @Override // weka.core.PyroProxyObject
    public BaseHostname getNameServer() {
        return this.m_NameServer;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    @Override // weka.core.PyroProxyObject
    public String nameServerTipText() {
        return "The address of the Pyro nameserver.";
    }

    /**
     * Sets the name of the remote object to use.
     *
     * @param str the name
     */
    @Override // weka.core.PyroProxyObject
    public void setRemoteObjectName(String str) {
        this.m_RemoteObjectName = str;
        reset();
    }

    /**
     * Returns the name of the remote object to use.
     *
     * @return the name
     */
    @Override // weka.core.PyroProxyObject
    public String getRemoteObjectName() {
        return this.m_RemoteObjectName;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    @Override // weka.core.PyroProxyObject
    public String remoteObjectNameTipText() {
        return "The name of the remote object to use.";
    }

    /**
     * Sets whether to perform training.
     *
     * @param z true if the remote training method gets called
     */
    public void setPerformTraining(boolean z) {
        this.m_PerformTraining = z;
        reset();
    }

    /**
     * Returns whether to perform training.
     *
     * @return true if the remote training method gets called
     */
    public boolean getPerformTraining() {
        return this.m_PerformTraining;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    public String performTrainingTipText() {
        return "If enabled, then training is performed.";
    }

    /**
     * Sets the name of the remote method to call for training.
     *
     * @param str the method name
     */
    public void setMethodNameTrain(String str) {
        this.m_MethodNameTrain = str;
        reset();
    }

    /**
     * Returns the name of the remote method to call for training.
     *
     * @return the method name
     */
    public String getMethodNameTrain() {
        return this.m_MethodNameTrain;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    public String methodNameTrainTipText() {
        return "The name of the method to call for training.";
    }

    /**
     * Sets the name of the remote method to call for predictions.
     *
     * @param str the method name
     */
    public void setMethodNamePrediction(String str) {
        this.m_MethodNamePrediction = str;
        reset();
    }

    /**
     * Returns the name of the remote method to call for predictions.
     *
     * @return the method name
     */
    public String getMethodNamePrediction() {
        return this.m_MethodNamePrediction;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    public String methodNamePredictionTipText() {
        return "The name of the method to call for predictions.";
    }

    /**
     * Sets the name of the model to use.
     *
     * @param str the model name
     */
    public void setModelName(String str) {
        this.m_ModelName = str;
        reset();
    }

    /**
     * Returns the name of the model to use.
     *
     * @return the model name
     */
    public String getModelName() {
        return this.m_ModelName;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    public String modelNameTipText() {
        return "The name of the model to use.";
    }

    /**
     * Sets the processor that handles the communication with the remote model.
     *
     * @param abstractCommunicationProcessor the processor
     */
    @Override // weka.core.PyroProxyObject
    public void setCommunication(AbstractCommunicationProcessor abstractCommunicationProcessor) {
        this.m_Communication = abstractCommunicationProcessor;
        reset();
    }

    /**
     * Returns the processor that handles the communication with the remote model.
     *
     * @return the processor
     */
    @Override // weka.core.PyroProxyObject
    public AbstractCommunicationProcessor getCommunication() {
        return this.m_Communication;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    @Override // weka.core.PyroProxyObject
    public String communicationTipText() {
        return "Handles the communication with the remote model.";
    }

    /**
     * Sets the batch size. Values rejected by the option manager are ignored.
     *
     * @param str the batch size as string
     * @throws NumberFormatException if the string is not an integer
     */
    @Override
    public void setBatchSize(String str) {
        int parseInt = Integer.parseInt(str);
        if (getOptionManager().isValid("batchSize", Integer.valueOf(parseInt))) {
            this.m_BatchSize = parseInt;
            reset();
        }
    }

    /**
     * Returns the batch size.
     *
     * @return the batch size as string
     */
    @Override
    public String getBatchSize() {
        return Integer.toString(this.m_BatchSize);
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the options.
     */
    public String batchSizeTipText() {
        return "The batch size to use for generating multiple predictions (if possible).";
    }

    /**
     * Returns whether the communication processor supports batch predictions.
     *
     * @return true if batch predictions are supported, false if no processor is set
     */
    @Override
    public boolean implementsMoreEfficientBatchPrediction() {
        return (this.m_Communication != null) && this.m_Communication.supportsBatchPredictions();
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = new Capabilities(this);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.STRING_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.setMinimumNumberInstances(0);
        return capabilities;
    }

    /**
     * Connects to the remote object and, if enabled, trains the remote model.
     * Connections left over from a previous build are closed first, so a failed
     * (re)build never leaves a stale remote object behind.
     *
     * @param instances the training data
     * @throws IllegalStateException if the setup is incomplete
     * @throws Exception if connecting or training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        validateSetup();
        releaseConnections();
        this.m_Communication.initialize(this, instances);
        resolveNameServer();
        obtainRemoteObject();
        if (this.m_PerformTraining) {
            trainRemoteModel(instances);
        }
    }

    /** Ensures that all options required for building/predicting are set. */
    private void validateSetup() {
        if (isBlank(this.m_RemoteObjectName)) {
            throw new IllegalStateException("Remote object name is empty!");
        }
        if (this.m_PerformTraining && isBlank(this.m_MethodNameTrain)) {
            throw new IllegalStateException("Method name (train) is empty!");
        }
        if (isBlank(this.m_MethodNamePrediction)) {
            throw new IllegalStateException("Method name (prediction) is empty!");
        }
        if (isBlank(this.m_ModelName)) {
            throw new IllegalStateException("Model name is empty!");
        }
        if (this.m_Communication == null) {
            throw new IllegalStateException("No communication processor set!");
        }
    }

    /** Returns true if the string is null or contains only whitespace. */
    private static boolean isBlank(String s) {
        return (s == null) || s.trim().isEmpty();
    }

    /** Closes and forgets the remote object and any self-located nameserver proxy. */
    private void releaseConnections() {
        closeQuietly(this.m_RemoteObject, "remote object");
        this.m_RemoteObject = null;
        // a proxy handed out by a PyroNameServer actor belongs to that actor
        if (this.m_OwnsNameServerProxy) {
            closeQuietly(this.m_NameServerProxy, "nameserver proxy");
        }
        this.m_NameServerProxy = null;
        this.m_NameServerActor = null;
        this.m_OwnsNameServerProxy = false;
    }

    /** Closes the proxy (if any) on a best-effort basis, only logging failures. */
    private void closeQuietly(net.razorvine.pyro.PyroProxy proxy, String what) {
        if (proxy == null) {
            return;
        }
        try {
            proxy.close();
        } catch (Exception e) {
            if (isLoggingEnabled()) {
                getLogger().warning("Failed to close " + what + ": " + e);
            }
        }
    }

    /** Determines the nameserver proxy, preferring one provided via the flow context. */
    private void resolveNameServer() throws Exception {
        if (this.m_FlowContext != null) {
            if (isLoggingEnabled()) {
                getLogger().info("Using flow context (" + this.m_FlowContext.getFullName() + ") to determine nameserver...");
            }
            this.m_NameServerActor = ActorUtils.findClosestType(this.m_FlowContext, PyroNameServer.class, true);
            if (this.m_NameServerActor != null) {
                this.m_NameServerProxy = this.m_NameServerActor.getNameServer();
            }
            if (isLoggingEnabled()) {
                getLogger().info("Determined nameserver through flow context: " + (this.m_NameServerProxy != null));
            }
        }
        if (this.m_NameServerProxy != null) {
            return;
        }
        try {
            if (isLoggingEnabled()) {
                getLogger().info("Connecting to: " + this.m_NameServer);
            }
            long start = System.currentTimeMillis();
            this.m_NameServerProxy = NameServerProxy.locateNS(this.m_NameServer.hostnameValue(), this.m_NameServer.portValue(Config.NS_PORT), (byte[]) null);
            this.m_OwnsNameServerProxy = true;
            logDuration("nameserver", start);
        } catch (Exception e) {
            throw new Exception("Failed to connect to Pyro nameserver: " + this.m_NameServer, e);
        }
    }

    /** Looks up the remote object via the nameserver and creates a proxy for it. */
    private void obtainRemoteObject() throws Exception {
        try {
            if (isLoggingEnabled()) {
                getLogger().info("Obtaining remote object: " + this.m_RemoteObjectName);
            }
            long start = System.currentTimeMillis();
            this.m_RemoteObject = new net.razorvine.pyro.PyroProxy(this.m_NameServerProxy.lookup(this.m_RemoteObjectName));
            logDuration("remoteobject", start);
        } catch (Exception e) {
            throw new Exception("Failed to obtain remote object: " + this.m_RemoteObjectName, e);
        }
    }

    /** Sends the converted training data to the remote training method. */
    private void trainRemoteModel(Instances instances) throws Exception {
        try {
            Object payload = this.m_Communication.convertDataset(this, instances);
            long start = System.currentTimeMillis();
            this.m_RemoteObject.call(this.m_MethodNameTrain, new Object[]{payload});
            logDuration("buildClassifier", start);
        } catch (Exception e) {
            // don't keep a connection to a model that failed to train
            closeQuietly(this.m_RemoteObject, "remote object");
            this.m_RemoteObject = null;
            throw new Exception("Failed to train model using method '" + this.m_MethodNameTrain + "' of remote object: " + this.m_RemoteObjectName, e);
        }
    }

    /** Logs the seconds elapsed since {@code start} (in msec) under "duration/LABEL". */
    private void logDuration(String label, long start) {
        long end = System.currentTimeMillis();
        if (isLoggingEnabled()) {
            getLogger().info("duration/" + label + ": " + ((end - start) / 1000.0d));
        }
    }

    /** Fails if no remote object is available, i.e., the classifier hasn't been built. */
    private void ensureRemoteObject() {
        if (this.m_RemoteObject == null) {
            throw new IllegalStateException("No remote object available for remote calls!");
        }
    }

    /**
     * Obtains the prediction for a single instance from the remote object.
     *
     * @param instance the instance to predict
     * @return the parsed prediction
     * @throws IllegalStateException if the classifier has not been built
     * @throws Exception if the remote call or parsing fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        ensureRemoteObject();
        Object payload = this.m_Communication.convertInstance(this, instance);
        long start = System.currentTimeMillis();
        Object call = this.m_RemoteObject.call(this.m_MethodNamePrediction, new Object[]{payload});
        logDuration("distributionForInstance", start);
        try {
            return this.m_Communication.parsePrediction(this, call);
        } catch (Exception e) {
            throw new Exception("Failed to process prediction:\n" + call, e);
        }
    }

    /**
     * Obtains the predictions for a whole dataset with a single remote call.
     *
     * @param instances the instances to predict
     * @return the parsed predictions
     * @throws IllegalStateException if the classifier has not been built
     * @throws Exception if the remote call or parsing fails
     */
    @Override
    public double[][] distributionsForInstances(Instances instances) throws Exception {
        ensureRemoteObject();
        Object payload = this.m_Communication.convertDataset(this, instances);
        long start = System.currentTimeMillis();
        Object call = this.m_RemoteObject.call(this.m_MethodNamePrediction, new Object[]{payload});
        logDuration("distributionsForInstances", start);
        try {
            return this.m_Communication.parsePredictions(this, call);
        } catch (Exception e) {
            throw new Exception("Failed to process predictions:\n" + call, e);
        }
    }

    /**
     * Returns a short description of the setup and connection state.
     *
     * @return the description
     */
    @Override
    public String toString() {
        return "Flow context: " + (this.m_FlowContext == null ? "-none-" : this.m_FlowContext.getFullName())
            + "\nNameserver: " + this.m_NameServer
            + "\nRemote object name: " + this.m_RemoteObjectName
            + "\nPerform training: " + this.m_PerformTraining
            + "\nMethod name (train): " + this.m_MethodNameTrain
            + "\nMethod name (prediction): " + this.m_MethodNamePrediction
            + "\nModel name: " + this.m_ModelName
            + "\nConnected: " + (this.m_RemoteObject != null);
    }

    /**
     * Runs the classifier from the command-line.
     *
     * @param strArr the command-line options
     * @throws Exception if execution fails
     */
    public static void main(String[] strArr) throws Exception {
        Environment.setEnvironmentClass(Environment.class);
        runClassifier(new PyroProxy(), strArr);
    }
}
