package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.flow.container.DL4JPredictionContainer;
import adams.flow.core.Token;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:adams/flow/transformer/DL4JScoring.class */
/**
 * Flow transformer that scores the examples of each incoming {@link DataSet}
 * with a {@link MultiLayerNetwork} model.
 * <p>
 * The per-example scores are wrapped, together with the data set, in a
 * {@link DL4JPredictionContainer}. Models of any other type are rejected with a
 * severe log message and no token is produced.
 */
public class DL4JScoring extends AbstractProcessDL4JDatasetWithModel<Model> {
    private static final long serialVersionUID = -3019442578354930841L;

    /** Whether regularization terms are added when scoring each example. */
    protected boolean m_AddRegularizationTerms;

    /**
     * Returns a short description of this actor.
     * <p>
     * NOTE(review): the text talks about "predictions", while
     * {@link #processDataset(DataSet)} computes per-example scores; consider
     * aligning the wording.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Uses a serialized model to perform predictions on the data being passed through.\n"
            + "The model can also be obtained from a callable actor, if the model file is pointing to a directory.";
    }

    /**
     * Registers the options of the superclass plus the
     * {@code add-regularization-terms} flag (default {@code false}).
     */
    @Override // adams.flow.transformer.AbstractProcessDL4JDatasetWithModel
    public void defineOptions() {
        super.defineOptions();
        m_OptionManager.add("add-regularization-terms", "addRegularizationTerms", false);
    }

    /**
     * Sets whether to add regularization terms when scoring; resets the actor afterwards.
     *
     * @param z true to include regularization terms
     */
    public void setAddRegularizationTerms(boolean z) {
        m_AddRegularizationTerms = z;
        reset();
    }

    /**
     * Returns whether regularization terms are added when scoring.
     *
     * @return true if regularization terms are included
     */
    public boolean getAddRegularizationTerms() {
        return m_AddRegularizationTerms;
    }

    /**
     * Returns the tip text for the {@code addRegularizationTerms} property.
     *
     * @return the tip text
     */
    public String addRegularizationTermsTipText() {
        return "Whether to add regularization terms.";
    }

    /**
     * Extends the superclass quick info with the regularization-terms flag
     * whenever the helper returns a non-null fragment for it.
     *
     * @return the quick info string
     */
    @Override // adams.flow.transformer.AbstractProcessDL4JDatasetWithModel
    public String getQuickInfo() {
        String base = super.getQuickInfo();
        String flagInfo = QuickInfoHelper.toString(
            this, "addRegularizationTerms", m_AddRegularizationTerms, "add regularization terms", ", ");
        return (flagInfo == null) ? base : base + flagInfo;
    }

    /**
     * Declares the output type of this transformer.
     *
     * @return a fresh array containing {@link DL4JPredictionContainer}
     */
    @Override // adams.flow.transformer.AbstractProcessDL4JDatasetWithModel
    public Class[] generates() {
        Class[] outputTypes = {DL4JPredictionContainer.class};
        return outputTypes;
    }

    /**
     * Scores every example in the given data set with the current model.
     *
     * @param dataSet the data to score
     * @return a token wrapping a {@link DL4JPredictionContainer} with the data set and
     *         its scores, or {@code null} if the model is not a {@link MultiLayerNetwork}
     * @throws Exception if scoring fails
     */
    @Override // adams.flow.transformer.AbstractProcessDL4JDatasetWithModel
    protected Token processDataset(DataSet dataSet) throws Exception {
        // Only MultiLayerNetwork exposes scoreExamples; anything else is logged and skipped.
        if (!(m_Model instanceof MultiLayerNetwork)) {
            getLogger().severe("Can only use " + MultiLayerNetwork.class.getName() + " for scoring!");
            return null;
        }
        MultiLayerNetwork network = (MultiLayerNetwork) m_Model;
        DL4JPredictionContainer container =
            new DL4JPredictionContainer(dataSet, network.scoreExamples(dataSet, m_AddRegularizationTerms));
        return new Token(container);
    }
}
