/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.option.OptionHandler;
import adams.flow.container.DL4JEvaluationContainer;
import adams.flow.container.DL4JModelContainer;
import adams.flow.core.Actor;
import adams.flow.core.CallableActorReference;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.source.CallableSource;
import adams.flow.transformer.AbstractDL4JModelEvaluator;
import adams.flow.transformer.DL4JEvaluationType;
import org.deeplearning4j.eval.BaseEvaluation;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;

/**
 * Evaluates a trained model, obtained from the input token, on the dataset
 * produced by a callable actor.
 * <p>
 * The input is either a {@link Model} or a {@link DL4JModelContainer}. Only
 * {@link MultiLayerNetwork} models are evaluated. Depending on the inherited
 * "always use container" setting, the output is either the evaluation
 * statistics string or a {@link DL4JEvaluationContainer}.
 */
public class DL4JTestSetEvaluator
extends AbstractDL4JModelEvaluator
implements ProvenanceSupporter {
    private static final long serialVersionUID = -8528709957864675275L;

    /** the callable actor that supplies the test set. */
    protected CallableActorReference m_Testset;

    /** the kind of evaluation to perform (classification/regression). */
    protected DL4JEvaluationType m_Type;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Evaluates a trained model (obtained from input) on the dataset obtained from the callable actor.";
    }

    /**
     * Adds the "testset" and "type" options to the inherited ones.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("testset", "testset", new CallableActorReference("Testset"));
        this.m_OptionManager.add("type", "type", DL4JEvaluationType.CLASSIFICATION);
    }

    /**
     * Returns a short summary of the current test set and evaluation type.
     *
     * @return the quick info string
     */
    public String getQuickInfo() {
        String result = QuickInfoHelper.toString(this, "testset", this.m_Testset, "test: ");
        result += QuickInfoHelper.toString(this, "type", this.m_Type, ", type: ");
        return result;
    }

    /**
     * Sets the callable actor that supplies the test set.
     *
     * @param value the reference to the callable actor
     */
    public void setTestset(CallableActorReference value) {
        this.m_Testset = value;
        this.reset();
    }

    /**
     * Returns the callable actor that supplies the test set.
     *
     * @return the reference to the callable actor
     */
    public CallableActorReference getTestset() {
        return this.m_Testset;
    }

    /**
     * Returns the tip text for the "testset" property.
     *
     * @return the tip text
     */
    public String testsetTipText() {
        return "The callable actor to use for obtaining the test set.";
    }

    /**
     * Sets the type of evaluation to perform.
     *
     * @param value the evaluation type
     */
    public void setType(DL4JEvaluationType value) {
        this.m_Type = value;
        this.reset();
    }

    /**
     * Returns the type of evaluation to perform.
     *
     * @return the evaluation type
     */
    public DL4JEvaluationType getType() {
        return this.m_Type;
    }

    /**
     * Returns the tip text for the "type" property.
     *
     * @return the tip text
     */
    public String typeTipText() {
        return "The type of evaluation to perform.";
    }

    /**
     * Returns the classes that this transformer accepts as input.
     *
     * @return a plain model or a model container
     */
    public Class[] accepts() {
        return new Class[]{Model.class, DL4JModelContainer.class};
    }

    /**
     * Obtains the test set from the callable actor and evaluates the input
     * model on it.
     *
     * @return null if everything went fine, otherwise an error message
     */
    protected String doExecute() {
        String result = null;

        try {
            DataSet test = null;

            // Run the callable actor to obtain the test set. Its setUp/execute
            // error messages are propagated instead of being silently dropped.
            CallableSource gs = new CallableSource();
            gs.setCallableName(this.m_Testset);
            gs.setParent(this.getParent());
            try {
                result = gs.setUp();
                if (result == null) {
                    result = gs.execute();
                }
                if (result == null) {
                    Token output = gs.output();
                    if (output == null) {
                        result = "No test set available!";
                    } else if (output.getPayload() instanceof DataSet) {
                        test = (DataSet) output.getPayload();
                    } else {
                        result = "Callable actor did not produce a " + DataSet.class.getName() + ", but: "
                                + (output.getPayload() == null ? "null" : output.getPayload().getClass().getName());
                    }
                }
            } finally {
                // always release the callable actor, even if execution threw
                gs.wrapUp();
            }

            if (result == null) {
                Object payload = this.m_InputToken.getPayload();
                Model model;
                if (payload instanceof Model) {
                    model = (Model) payload;
                } else {
                    model = (Model) ((DL4JModelContainer) payload).getValue("Model");
                }

                if (model instanceof MultiLayerNetwork) {
                    // Evaluation and RegressionEvaluation only share BaseEvaluation
                    // as a common supertype, so the local must use that type.
                    BaseEvaluation eval;
                    switch (this.m_Type) {
                        case CLASSIFICATION:
                            eval = new Evaluation(test.numOutcomes());
                            break;
                        case REGRESSION:
                            eval = new RegressionEvaluation(test.numOutcomes());
                            break;
                        default:
                            throw new IllegalStateException("Unhandled evaluation type: " + this.m_Type);
                    }
                    eval.eval(test.getLabels(),
                            ((MultiLayerNetwork) model).output(test.getFeatureMatrix(), Layer.TrainingMode.TEST));

                    if (this.m_AlwaysUseContainer) {
                        this.m_OutputToken = new Token(new DL4JEvaluationContainer(eval, model));
                    } else {
                        this.m_OutputToken = new Token(eval.stats());
                    }
                } else {
                    result = "Can only evaluate " + MultiLayerNetwork.class.getName() + "!";
                }
            }
        } catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to evaluate: ", e);
        }

        if (this.m_OutputToken != null) {
            this.updateProvenance((ProvenanceContainer) this.m_OutputToken);
        }
        return result;
    }

    /**
     * Updates the provenance information in the provided container, if
     * provenance tracking is enabled. Any provenance attached to the input
     * token is copied over first.
     *
     * @param cont the provenance container to update
     */
    public void updateProvenance(ProvenanceContainer cont) {
        if (!Provenance.getSingleton().isEnabled()) {
            return;
        }
        if (this.m_InputToken.hasProvenance()) {
            cont.setProvenance(this.m_InputToken.getProvenance().getClone());
        }
        cont.addProvenance(new ProvenanceInformation(ActorType.EVALUATOR, this.m_InputToken.getPayload().getClass(), (Actor) this, this.m_OutputToken.getPayload().getClass()));
    }
}

