package weka.classifiers.meta;

import adams.core.Utils;
import adams.core.base.BaseHostname;
import adams.core.option.OptionUtils;
import gnu.trove.list.array.TByteArrayList;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.meta.socketfacade.AbstractDataPreparation;
import weka.classifiers.meta.socketfacade.Simple;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;

/* loaded from: input_file:weka/classifiers/meta/SocketFacade.class */
public class SocketFacade extends AbstractClassifier {
    private static final long serialVersionUID = -7557824847573090857L;
    protected BaseHostname m_Remote = getDefaultRemote();
    protected BaseHostname m_Local = getDefaultLocal();
    protected int m_Timeout = getDefaultTimeout();
    protected AbstractDataPreparation m_Preparation = getDefaultPreparation();
    protected boolean m_SkipTrain;
    protected transient ServerSocket m_Server;

    public String globalInfo() {
        return "Uses sockets to communicate with a process for training and making predictions.\nNB: This classifier cannot be evaluated in parallel, as the local port, which receives the results, can only be bound once.";
    }

    public Enumeration listOptions() {
        Vector vector = new Vector();
        vector.addElement(new Option("\tThe address of the remote host.\n\t(default: " + getDefaultRemote() + ")", "remote", 1, "-remote <host:port>"));
        vector.addElement(new Option("\tThe return address for the remote host to use.\n\t(default: " + getDefaultLocal() + ")", "local", 1, "-local <host:port>"));
        vector.addElement(new Option("\tThe timeout for sockets in milli-second.\n\t(default: " + getDefaultTimeout() + ")", "timeout", 1, "-timeout <int>"));
        vector.addElement(new Option("\tThe scheme for preparing and parsing the data.\n\t(default: " + Utils.classToString(getDefaultPreparation()) + ")", "preparation", 1, "-preparation <classname + options>"));
        vector.addElement(new Option("\tWhether to skip the training process (eg pre-built model).\n\t(default: train not skipped)", "skip-train", 0, "-skip-train"));
        return vector.elements();
    }

    public void setOptions(String[] strArr) throws Exception {
        String option = weka.core.Utils.getOption("remote", strArr);
        if (option.isEmpty()) {
            setRemote(getDefaultRemote());
        } else {
            setRemote(new BaseHostname(option));
        }
        String option2 = weka.core.Utils.getOption("local", strArr);
        if (option2.isEmpty()) {
            setLocal(getDefaultLocal());
        } else {
            setLocal(new BaseHostname(option2));
        }
        String option3 = weka.core.Utils.getOption("timeout", strArr);
        if (option3.isEmpty()) {
            setTimeout(getDefaultTimeout());
        } else {
            setTimeout(Integer.parseInt(option3));
        }
        String option4 = weka.core.Utils.getOption("preparation", strArr);
        if (option4.isEmpty()) {
            setPreparation(getDefaultPreparation());
        } else {
            setPreparation((AbstractDataPreparation) OptionUtils.forCommandLine(AbstractDataPreparation.class, option4));
        }
        setSkipTrain(weka.core.Utils.getFlag("skip-train", strArr));
        super.setOptions(strArr);
    }

    public String[] getOptions() {
        ArrayList arrayList = new ArrayList();
        arrayList.add("-remote");
        arrayList.add("" + getRemote());
        arrayList.add("-local");
        arrayList.add("" + getLocal());
        arrayList.add("-timeout");
        arrayList.add("" + getTimeout());
        arrayList.add("-preparation");
        arrayList.add(OptionUtils.getCommandLine(getPreparation()));
        if (getSkipTrain()) {
            arrayList.add("-skip-train");
        }
        arrayList.addAll(Arrays.asList(super.getOptions()));
        return (String[]) arrayList.toArray(new String[arrayList.size()]);
    }

    protected BaseHostname getDefaultRemote() {
        return new BaseHostname("127.0.0.1:8000");
    }

    public void setRemote(BaseHostname baseHostname) {
        this.m_Remote = baseHostname;
    }

    public BaseHostname getRemote() {
        return this.m_Remote;
    }

    public String remoteTipText() {
        return "The address of the remote process.";
    }

    protected BaseHostname getDefaultLocal() {
        return new BaseHostname("127.0.0.1:8001");
    }

    public void setLocal(BaseHostname baseHostname) {
        this.m_Local = baseHostname;
    }

    public BaseHostname getLocal() {
        return this.m_Local;
    }

    public String localTipText() {
        return "The return address for the remote process to use.";
    }

    protected int getDefaultTimeout() {
        return 3000;
    }

    public void setTimeout(int i) {
        this.m_Timeout = i;
    }

    public int getTimeout() {
        return this.m_Timeout;
    }

    public String timeoutTipText() {
        return "The timeout in milli-second for waiting on responses from the process.";
    }

    protected AbstractDataPreparation getDefaultPreparation() {
        return new Simple();
    }

    public void setPreparation(AbstractDataPreparation abstractDataPreparation) {
        this.m_Preparation = abstractDataPreparation;
    }

    public AbstractDataPreparation getPreparation() {
        return this.m_Preparation;
    }

    public String preparationTipText() {
        return "The data preparation scheme to use for sending/receiving the data.";
    }

    public void setSkipTrain(boolean z) {
        this.m_SkipTrain = z;
    }

    public boolean getSkipTrain() {
        return this.m_SkipTrain;
    }

    public String skipTrainTipText() {
        return "If enabled, the training is skipped; useful when using a pre-built model.";
    }

    protected synchronized void initServer() throws Exception {
        if (this.m_Server == null) {
            this.m_Server = new ServerSocket(this.m_Local.portValue());
            this.m_Server.setSoTimeout(this.m_Timeout);
        }
    }

    protected synchronized void closeServer() {
        if (this.m_Server != null) {
            try {
                this.m_Server.close();
                this.m_Server = null;
            } catch (Exception e) {
            }
        }
    }

    protected ServerSocket getServer() throws Exception {
        initServer();
        return this.m_Server;
    }

    protected byte[] receive() throws Exception {
        initServer();
        Socket accept = this.m_Server.accept();
        InputStream inputStream = accept.getInputStream();
        TByteArrayList tByteArrayList = new TByteArrayList();
        while (true) {
            int read = inputStream.read();
            if (read == -1) {
                accept.close();
                closeServer();
                return tByteArrayList.toArray();
            }
            tByteArrayList.add((byte) read);
        }
    }

    protected byte[] send(byte[] bArr) throws Exception {
        initServer();
        Socket socket = new Socket(this.m_Remote.hostnameValue(), this.m_Remote.portValue());
        socket.setSoTimeout(this.m_Timeout);
        socket.getOutputStream().write(bArr);
        socket.getOutputStream().flush();
        socket.close();
        return receive();
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = new Capabilities(this);
        capabilities.enableAll();
        return capabilities;
    }

    public void buildClassifier(Instances instances) throws Exception {
        if (this.m_SkipTrain) {
            return;
        }
        try {
            String parseTrain = this.m_Preparation.parseTrain(send(this.m_Preparation.prepareTrain(instances, this)));
            if (parseTrain != null) {
                throw new Exception("Failed to perform remote build:\n" + parseTrain);
            }
        } finally {
            closeServer();
        }
    }

    public double classifyInstance(Instance instance) throws Exception {
        try {
            double parseClassify = this.m_Preparation.parseClassify(send(this.m_Preparation.prepareClassify(instance, this)));
            closeServer();
            return parseClassify;
        } catch (Throwable th) {
            closeServer();
            throw th;
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        try {
            return this.m_Preparation.parseDistribution(send(this.m_Preparation.prepareDistribution(instance, this)), instance.numClasses());
        } finally {
            closeServer();
        }
    }

    public String toString() {
        return OptionUtils.getCommandLine(this);
    }
}
