/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.option.OptionHandler;
import adams.flow.container.DL4JPredictionContainer;
import adams.flow.core.Token;
import adams.flow.transformer.AbstractProcessDL4JDatasetWithModel;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

/**
 * Transformer that scores the examples of each {@link DataSet} passing through,
 * using a DL4J model that is either deserialized from a file or obtained from a
 * callable actor.
 * <p>
 * Only {@link MultiLayerNetwork} models are supported. For those, the scores are
 * computed with {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} and
 * forwarded together with the input data in a {@link DL4JPredictionContainer}.
 */
public class DL4JScoring
extends AbstractProcessDL4JDatasetWithModel<Model> {
    private static final long serialVersionUID = -3019442578354930841L;

    /** whether to include regularization terms when scoring (option "add-regularization-terms"). */
    protected boolean m_AddRegularizationTerms;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI or on the command line
     */
    @Override
    public String globalInfo() {
        return "Uses a serialized model to perform predictions on the data being passed through.\nThe model can also be obtained from a callable actor, if the model file is pointing to a directory.";
    }

    /**
     * Adds the "add-regularization-terms" option (default: false) to the
     * options inherited from the superclass.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("add-regularization-terms", "addRegularizationTerms", false);
    }

    /**
     * Sets whether to add regularization terms when scoring. Resets the actor.
     *
     * @param value true if regularization terms are to be added
     */
    public void setAddRegularizationTerms(boolean value) {
        this.m_AddRegularizationTerms = value;
        this.reset();
    }

    /**
     * Returns whether to add regularization terms when scoring.
     *
     * @return true if regularization terms are added
     */
    public boolean getAddRegularizationTerms() {
        return this.m_AddRegularizationTerms;
    }

    /**
     * Returns the tip text for the "addRegularizationTerms" property.
     *
     * @return tip text for displaying in the GUI or on the command line
     */
    public String addRegularizationTermsTipText() {
        return "Whether to add regularization terms.";
    }

    /**
     * Returns the superclass's quick info. The text "add regularization terms"
     * is appended only when the helper returns a non-null value.
     *
     * @return the quick info string
     */
    @Override
    public String getQuickInfo() {
        String result = super.getQuickInfo();
        String value = QuickInfoHelper.toString(this, "addRegularizationTerms", this.m_AddRegularizationTerms, "add regularization terms", ", ");
        if (value != null) {
            result = result + value;
        }
        return result;
    }

    /**
     * Returns the class of objects that this actor generates.
     *
     * @return the {@link DL4JPredictionContainer} class
     */
    @Override
    public Class[] generates() {
        return new Class[]{DL4JPredictionContainer.class};
    }

    /**
     * Scores the examples of the given dataset with the current model.
     *
     * @param data the dataset to score
     * @return a token wrapping a {@link DL4JPredictionContainer} with the dataset
     *         and its scores, or {@code null} if the model is not a
     *         {@link MultiLayerNetwork} (a severe message is logged in that case)
     * @throws Exception if the model raises an error during scoring
     */
    @Override
    protected Token processDataset(DataSet data) throws Exception {
        // this actor only supports MultiLayerNetwork models; other model types are rejected with a log message
        if (!(this.m_Model instanceof MultiLayerNetwork)) {
            this.getLogger().severe("Can only use " + MultiLayerNetwork.class.getName() + " for scoring!");
            return null;
        }
        MultiLayerNetwork network = (MultiLayerNetwork) this.m_Model;
        INDArray scores = network.scoreExamples(data, this.m_AddRegularizationTerms);
        return new Token(new DL4JPredictionContainer(data, scores));
    }
}

