/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.option.OptionHandler;
import adams.flow.container.DL4JEvaluationContainer;
import adams.flow.core.Actor;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.transformer.AbstractCallableDL4JModelEvaluator;
import adams.flow.transformer.DL4JEvaluationType;
import adams.ml.dl4j.datasetiterator.ShufflingDataSetIterator;
import adams.ml.dl4j.iterationlistener.IterationListenerConfigurator;
import adams.ml.dl4j.model.ModelConfigurator;
import java.util.ArrayList;
import java.util.Random;
import org.deeplearning4j.eval.BaseEvaluation;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.KFoldIterator;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Cross-validates a model on the incoming {@link DataSet}.
 * <p>
 * The model setup is obtained from the callable actor that returns a
 * {@link ModelConfigurator}. For every fold a fresh model is configured,
 * trained on the training part for the configured number of epochs and then
 * evaluated on the held-out part. The results of all folds are accumulated in
 * a single evaluation object. The output token holds either that evaluation's
 * statistics string or, when the container is always to be used, a
 * {@link DL4JEvaluationContainer}.
 */
public class DL4JCrossValidationEvaluator
extends AbstractCallableDL4JModelEvaluator
implements Randomizable,
ProvenanceSupporter {
    private static final long serialVersionUID = -1092101024095887007L;

    /** The seed used for shuffling the data and the mini-batches. */
    protected long m_Seed;

    /** The number of cross-validation folds. */
    protected int m_Folds;

    /** The number of training epochs per fold. */
    protected int m_NumEpochs;

    /** The mini-batch size; values below 1 train on the whole training fold at once. */
    protected int m_MiniBatchSize;

    /** The type of evaluation (classification or regression). */
    protected DL4JEvaluationType m_Type;

    /** The configurators for the iteration listeners to attach to each model. */
    protected IterationListenerConfigurator[] m_IterationListeners;

    /**
     * Returns a string describing the actor.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "Cross-validates a model on the incoming dataset.\nThe model setup being used in the evaluation is obtained from the callable actor returning a model configurator.";
    }

    /**
     * Adds the options of this actor: seed, folds (at least 2), number of
     * epochs (at least 1), mini-batch size (at least -1), evaluation type and
     * iteration listeners.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("seed", "seed", 1L);
        this.m_OptionManager.add("folds", "folds", 10, 2, null);
        this.m_OptionManager.add("num-epochs", "numEpochs", 1000, 1, null);
        this.m_OptionManager.add("mini-batch-size", "miniBatchSize", -1, -1, null);
        this.m_OptionManager.add("type", "type", DL4JEvaluationType.CLASSIFICATION);
        this.m_OptionManager.add("iteration-listener", "iterationListeners", new IterationListenerConfigurator[0]);
    }

    /**
     * Returns a short summary of the current setup for display in the flow editor.
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        String result = super.getQuickInfo();
        result += QuickInfoHelper.toString(this, "seed", this.m_Seed, ", seed: ");
        result += QuickInfoHelper.toString(this, "folds", this.m_Folds, ", folds: ");
        result += QuickInfoHelper.toString(this, "type", this.m_Type, ", type: ");
        result += QuickInfoHelper.toString(this, "numEpochs", this.m_NumEpochs, ", epochs: ");
        result += QuickInfoHelper.toString(this, "miniBatchSize", this.m_MiniBatchSize, ", minibatch: ");
        return result;
    }

    /**
     * Returns the tip text for the model option.
     *
     * @return the tip text
     */
    @Override
    public String modelTipText() {
        return "The callable model configurator actor to obtain the model from to train and evaluate on the test data.";
    }

    /**
     * Sets the seed value.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed value.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed option.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value for the randomization.";
    }

    /**
     * Sets the number of folds. Values violating the option's bounds are ignored.
     *
     * @param value the number of folds (at least 2)
     */
    public void setFolds(int value) {
        if (this.getOptionManager().isValid("folds", value)) {
            this.m_Folds = value;
            this.reset();
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Returns the tip text for the folds option.
     *
     * @return the tip text
     */
    public String foldsTipText() {
        return "The folds to use.";
    }

    /**
     * Sets the number of epochs. Values violating the option's bounds are
     * ignored, consistent with {@link #setFolds(int)}.
     *
     * @param value the number of epochs (at least 1)
     */
    public void setNumEpochs(int value) {
        if (this.getOptionManager().isValid("numEpochs", value)) {
            this.m_NumEpochs = value;
            this.reset();
        }
    }

    /**
     * Returns the number of epochs.
     *
     * @return the number of epochs
     */
    public int getNumEpochs() {
        return this.m_NumEpochs;
    }

    /**
     * Returns the tip text for the number of epochs option.
     *
     * @return the tip text
     */
    public String numEpochsTipText() {
        return "The number of epochs to perform.";
    }

    /**
     * Sets the mini-batch size. Values violating the option's bounds are
     * ignored, consistent with {@link #setFolds(int)}.
     *
     * @param value the mini-batch size (at least -1; values below 1 disable mini-batches)
     */
    public void setMiniBatchSize(int value) {
        if (this.getOptionManager().isValid("miniBatchSize", value)) {
            this.m_MiniBatchSize = value;
            this.reset();
        }
    }

    /**
     * Returns the mini-batch size.
     *
     * @return the mini-batch size
     */
    public int getMiniBatchSize() {
        return this.m_MiniBatchSize;
    }

    /**
     * Returns the tip text for the mini-batch size option.
     *
     * @return the tip text
     */
    public String miniBatchSizeTipText() {
        return "The mini-batch size to use; -1 to turn off.";
    }

    /**
     * Sets the type of evaluation to perform.
     *
     * @param value the evaluation type
     */
    public void setType(DL4JEvaluationType value) {
        this.m_Type = value;
        this.reset();
    }

    /**
     * Returns the type of evaluation to perform.
     *
     * @return the evaluation type
     */
    public DL4JEvaluationType getType() {
        return this.m_Type;
    }

    /**
     * Returns the tip text for the type option.
     *
     * @return the tip text
     */
    public String typeTipText() {
        return "The type of evaluation to perform.";
    }

    /**
     * Sets the iteration listener configurators.
     *
     * @param value the configurators
     */
    public void setIterationListeners(IterationListenerConfigurator[] value) {
        this.m_IterationListeners = value;
        this.reset();
    }

    /**
     * Returns the iteration listener configurators.
     *
     * @return the configurators
     */
    public IterationListenerConfigurator[] getIterationListeners() {
        return this.m_IterationListeners;
    }

    /**
     * Returns the tip text for the iteration listeners option.
     *
     * @return the tip text
     */
    public String iterationListenersTipText() {
        return "The iteration listeners to use (configurators).";
    }

    /**
     * Returns the classes accepted as input.
     *
     * @return {@link DataSet}
     */
    @Override
    public Class[] accepts() {
        return new Class[]{DataSet.class};
    }

    /**
     * Performs the cross-validation on the incoming dataset.
     *
     * @return null if successful, otherwise an error message
     */
    @Override
    protected String doExecute() {
        String result = null;
        try {
            ModelConfigurator conf = this.getModelConfiguratorInstance();
            if (conf == null) {
                throw new IllegalStateException("Model configurator '" + this.getModel() + "' not found!");
            }
            DataSet data = (DataSet)this.m_InputToken.getPayload();
            // Features and labels are shuffled with identically seeded generators,
            // presumably so that both receive the same permutation and stay aligned.
            // NOTE(review): Nd4j.shuffle is documented as operating in place, so this
            // also reorders the payload of the incoming token.
            Nd4j.shuffle(data.getFeatureMatrix(), new Random(this.m_Seed), new int[]{1});
            if (data.getLabels() != null) {
                Nd4j.shuffle(data.getLabels(), new Random(this.m_Seed), new int[]{1});
            }
            if (!(conf.configureModel(data.numInputs(), data.numOutcomes()) instanceof MultiLayerNetwork)) {
                result = "Can only evaluate " + MultiLayerNetwork.class.getName() + "!";
            }
            // BaseEvaluation is the common supertype of Evaluation and RegressionEvaluation.
            BaseEvaluation eval = null;
            if (result == null) {
                eval = this.newEvaluation(data);
                this.crossValidate(conf, data, eval);
                if (this.isStopped()) {
                    eval = null;
                }
            }
            if (eval != null) {
                if (this.m_AlwaysUseContainer) {
                    // The untrained model is only needed for the container, so it is
                    // only built when it is actually output.
                    Model model = conf.configureModel(data.numInputs(), data.numOutcomes());
                    this.m_OutputToken = new Token(new DL4JEvaluationContainer(eval, model, this.m_NumEpochs));
                } else {
                    this.m_OutputToken = new Token(eval.stats());
                }
            }
        }
        catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to evaluate: ", e);
        }
        if (this.m_OutputToken != null) {
            this.updateProvenance((ProvenanceContainer)this.m_OutputToken);
        }
        return result;
    }

    /**
     * Creates an empty evaluation object for the configured evaluation type.
     *
     * @param data the dataset whose number of outcomes sizes the evaluation
     * @return the new evaluation
     * @throws IllegalStateException if the evaluation type is not handled
     */
    private BaseEvaluation newEvaluation(DataSet data) {
        switch (this.m_Type) {
            case CLASSIFICATION:
                return new Evaluation(data.numOutcomes());
            case REGRESSION:
                return new RegressionEvaluation(data.numOutcomes());
            default:
                throw new IllegalStateException("Unhandled evaluation type: " + this.m_Type);
        }
    }

    /**
     * Trains a fresh model on each training fold and accumulates the results on
     * the corresponding test fold in the supplied evaluation. Stops early if the
     * actor gets stopped.
     *
     * @param conf the configurator used for creating a new model per fold
     * @param data the full dataset to split into folds
     * @param eval the evaluation to accumulate the per-fold results in
     * @throws Exception if configuring, training or evaluating fails
     */
    private void crossValidate(ModelConfigurator conf, DataSet data, BaseEvaluation eval) throws Exception {
        KFoldIterator folds = new KFoldIterator(this.m_Folds, data);
        while (folds.hasNext() && !this.isStopped()) {
            DataSet train = folds.next();
            DataSet test = folds.testFold();
            // The model type was already checked in doExecute().
            MultiLayerNetwork model = (MultiLayerNetwork)conf.configureModel(data.numInputs(), data.numOutcomes());
            model.setListeners(this.configureListeners());
            this.trainModel(model, train);
            eval.eval(test.getLabels(), model.output(test.getFeatureMatrix(), Layer.TrainingMode.TEST));
            model.clear();
        }
    }

    /**
     * Instantiates the iteration listeners from all configured configurators,
     * using this actor as their flow context.
     *
     * @return the listeners
     * @throws Exception if a configurator fails
     */
    private ArrayList<IterationListener> configureListeners() throws Exception {
        ArrayList<IterationListener> listeners = new ArrayList<>();
        for (IterationListenerConfigurator configurator : this.m_IterationListeners) {
            configurator.setFlowContext(this);
            listeners.addAll(configurator.configureIterationListeners());
        }
        return listeners;
    }

    /**
     * Trains the model on the training data for the configured number of
     * epochs. Mini-batches are used if the mini-batch size is at least 1.
     * Stops early if the actor gets stopped.
     *
     * @param model the model to train
     * @param train the training data
     * @throws Exception if training fails
     */
    private void trainModel(MultiLayerNetwork model, DataSet train) throws Exception {
        for (int epoch = 0; epoch < this.m_NumEpochs; ++epoch) {
            if (this.m_MiniBatchSize < 1) {
                model.fit((org.nd4j.linalg.dataset.api.DataSet)train);
            } else {
                // NOTE(review): each epoch uses a new iterator with the same seed
                // (truncated to int), so the mini-batch order presumably repeats
                // across epochs -- confirm whether that is intended.
                ShufflingDataSetIterator batches = new ShufflingDataSetIterator(train, this.m_MiniBatchSize, (int)this.m_Seed);
                while (batches.hasNext() && !this.isStopped()) {
                    model.fit((org.nd4j.linalg.dataset.api.DataSet)batches.next());
                }
            }
            if (this.isStopped()) {
                break;
            }
        }
    }

    /**
     * Copies any provenance of the input token into the container and records
     * this actor as an evaluator step, if provenance tracking is enabled.
     *
     * @param cont the container to update
     */
    @Override
    public void updateProvenance(ProvenanceContainer cont) {
        if (!Provenance.getSingleton().isEnabled()) {
            return;
        }
        if (this.m_InputToken.hasProvenance()) {
            cont.setProvenance(this.m_InputToken.getProvenance().getClone());
        }
        cont.addProvenance(new ProvenanceInformation(ActorType.EVALUATOR, this.m_InputToken.getPayload().getClass(), this, this.m_OutputToken.getPayload().getClass()));
    }
}

