package adams.flow.transformer;

import adams.core.Utils;
import adams.flow.core.Token;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import weka.classifiers.functions.DL4JMultiLayerNetwork;
import weka.core.Instances;
import weka.filters.Filter;

/* loaded from: input_file:adams/flow/transformer/DL4JInitWekaClassifier.class */
/**
 * Transformer that turns an incoming {@code Object[]} of length 3 into a
 * {@link DL4JMultiLayerNetwork} Weka classifier.
 * <p>
 * Expected array layout:
 * <ol>
 *   <li>a DL4J {@link MultiLayerNetwork}</li>
 *   <li>a trained Weka {@link Filter} (pre-filter)</li>
 *   <li>the original training data as {@link Instances}</li>
 * </ol>
 * On success a single token containing the configured classifier is emitted;
 * otherwise an error message is returned from {@link #doExecute()}.
 */
public class DL4JInitWekaClassifier extends AbstractTransformer {
    private static final long serialVersionUID = 7480758744491379628L;

    /** Suffix appended to every validation error, describing the expected array layout. */
    public static final String DEFAULT_ERROR = "Requires array of length 3: DL4J multi-layer network, trained Weka filter, original training data.";

    /**
     * Returns a short description of this transformer.
     *
     * @return the description
     */
    public String globalInfo() {
        // getName() is used because concatenating the Class object itself (Class.toString())
        // would render as "class weka.classifiers...".
        return "Uses the incoming array to initialize a " + DL4JMultiLayerNetwork.class.getName() + " classifier.\n"
            + "Array must consist of:\n"
            + "- DL4J multi-layer network\n"
            + "- trained Weka filter (used for preprocessing data)\n"
            + "- original training data";
    }

    /**
     * Returns the classes of input tokens this transformer accepts.
     *
     * @return an array containing only {@code Object[].class}
     */
    public Class[] accepts() {
        return new Class[]{Object[].class};
    }

    /**
     * Returns the classes of output tokens this transformer generates.
     *
     * @return an array containing only {@link DL4JMultiLayerNetwork}
     */
    public Class[] generates() {
        return new Class[]{DL4JMultiLayerNetwork.class};
    }

    /**
     * Validates the incoming array and, if it is well-formed, builds a
     * {@link DL4JMultiLayerNetwork} from its three elements and stores it in
     * the output token.
     *
     * @return {@code null} on success, otherwise a description of what was wrong
     */
    protected String doExecute() {
        Object payload = this.m_InputToken.getPayload();
        // accepts() declares Object[], but a null or non-array payload would otherwise
        // surface as a NullPointerException/ClassCastException rather than an error message.
        if (!(payload instanceof Object[])) {
            return "Expected an array (found: " + (payload == null ? "null" : payload.getClass().getName()) + ")! " + DEFAULT_ERROR;
        }
        Object[] objArr = (Object[]) payload;

        // Check length first, then the type of each element in order; report the first problem only.
        String str = null;
        if (objArr.length != 3) {
            str = "Array of length " + objArr.length + "! " + DEFAULT_ERROR;
        } else if (!(objArr[0] instanceof MultiLayerNetwork)) {
            str = "1st array element wrong (found: " + Utils.classToString(objArr[0]) + ")! " + DEFAULT_ERROR;
        } else if (!(objArr[1] instanceof Filter)) {
            str = "2nd array element wrong (found: " + Utils.classToString(objArr[1]) + ")! " + DEFAULT_ERROR;
        } else if (!(objArr[2] instanceof Instances)) {
            str = "3rd array element wrong (found: " + Utils.classToString(objArr[2]) + ")! " + DEFAULT_ERROR;
        }

        if (str == null) {
            DL4JMultiLayerNetwork dL4JMultiLayerNetwork = new DL4JMultiLayerNetwork();
            dL4JMultiLayerNetwork.setTrainedMultiLayerNetwork((MultiLayerNetwork) objArr[0]);
            dL4JMultiLayerNetwork.setTrainedPreFilter((Filter) objArr[1]);
            dL4JMultiLayerNetwork.setTrainingData((Instances) objArr[2]);
            this.m_OutputToken = new Token(dL4JMultiLayerNetwork);
        }
        return str;
    }
}
