/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.MessageCollection;
import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.Range;
import adams.core.VariableName;
import adams.core.option.AbstractOption;
import adams.core.option.OptionHandler;
import adams.data.spreadsheet.Row;
import adams.data.spreadsheet.SpreadSheet;
import adams.event.VariableChangeEvent;
import adams.flow.container.DL4JModelContainer;
import adams.flow.core.Actor;
import adams.flow.core.CallableActorHelper;
import adams.flow.core.CallableActorReference;
import adams.flow.core.Token;
import adams.flow.core.VariableMonitor;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.source.DL4JModelConfigurator;
import adams.flow.transformer.AbstractTransformer;
import adams.flow.transformer.DL4JEvaluationType;
import adams.flow.transformer.DL4JEvaluationValues;
import adams.ml.dl4j.DataSetHelper;
import adams.ml.dl4j.EvaluationStatistic;
import adams.ml.dl4j.datasetiterator.ShufflingDataSetIterator;
import adams.ml.dl4j.iterationlistener.IterationListenerConfigurator;
import adams.ml.dl4j.model.ModelConfigurator;
import adams.ml.dl4j.trainstopcriterion.AbstractTrainStopCriterion;
import adams.ml.dl4j.trainstopcriterion.MaxEpoch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.dataset.api.DataSet;

public class DL4JTrainModel
extends AbstractTransformer
implements ProvenanceSupporter,
Randomizable,
VariableMonitor {
    private static final long serialVersionUID = -3019442578354930841L;
    public static final String BACKUP_MODEL = "model";
    public static final String BACKUP_BEST_MODEL = "best model";
    public static final String BACKUP_EPOCH = "epoch";
    public static final String BACKUP_TRAINDATA = "traindata";
    public static final String BACKUP_TESTDATA = "testdata";
    public static final String BACKUP_BEST_STATISTICS = "best statistics";
    public static final String BACKUP_TRAINING_FINISHED = "training finished";
    protected CallableActorReference m_Model;
    protected Model m_ActualModel;
    protected Model m_BestModel;
    protected Map<String, Double> m_BestStatistics;
    protected AbstractTrainStopCriterion m_TrainStop;
    protected int m_MiniBatchSize;
    protected long m_Seed;
    protected IterationListenerConfigurator[] m_IterationListeners;
    protected int m_OutputInterval;
    protected VariableName m_VariableName;
    protected double m_TestPercentage;
    protected DL4JEvaluationType m_Type;
    protected boolean m_OutputBestModel;
    protected EvaluationStatistic[] m_StatisticValues;
    protected Range m_ClassIndex;
    protected Range m_RegressionColumns;
    protected int m_Epoch;
    protected org.nd4j.linalg.dataset.DataSet m_TrainData;
    protected org.nd4j.linalg.dataset.DataSet m_TestData;
    protected boolean m_TrainingFinished;

    public String globalInfo() {
        return "Trains a model based on the incoming dataset and outputs the built model alongside the dataset (in a model container).\nThe model can be reset using the monitor variable option, i.e, whenever this variable changes value, the model gets reset. Useful when training sequentially on multiple datasets (using the file name as monitor variable).";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add(BACKUP_MODEL, BACKUP_MODEL, (Object)new CallableActorReference(DL4JModelConfigurator.class.getSimpleName()));
        this.m_OptionManager.add("train-stop", "trainStop", (Object)new MaxEpoch());
        this.m_OptionManager.add("mini-batch-size", "miniBatchSize", (Object)-1, (Number)-1, null);
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        this.m_OptionManager.add("iteration-listener", "iterationListeners", (Object)new IterationListenerConfigurator[0]);
        this.m_OptionManager.add("output-interval", "outputInterval", (Object)-1, (Number)-1, null);
        this.m_OptionManager.add("var-name", "variableName", (Object)new VariableName());
        this.m_OptionManager.add("test-percentage", "testPercentage", (Object)0.0, (Number)0.0, (Number)1.0);
        this.m_OptionManager.add("type", "type", (Object)DL4JEvaluationType.CLASSIFICATION);
        this.m_OptionManager.add("output-best-model", "outputBestModel", (Object)false);
        this.m_OptionManager.add("statistic", "statisticValues", (Object)new EvaluationStatistic[]{EvaluationStatistic.ACCURACY, EvaluationStatistic.F1});
        this.m_OptionManager.add("index", "classIndex", (Object)new Range("first"));
        this.m_OptionManager.add("regression-columns", "regressionColumns", (Object)new Range("last"));
    }

    public void setModel(CallableActorReference value) {
        this.m_Model = value;
        this.reset();
    }

    public CallableActorReference getModel() {
        return this.m_Model;
    }

    public String modelTipText() {
        return "The model to train on the input data.";
    }

    public void setTrainStop(AbstractTrainStopCriterion value) {
        this.m_TrainStop = value;
        this.reset();
    }

    public AbstractTrainStopCriterion getTrainStop() {
        return this.m_TrainStop;
    }

    public String trainStopTipText() {
        return "The criterion for stopping training.";
    }

    public void setMiniBatchSize(int value) {
        this.m_MiniBatchSize = value;
        this.reset();
    }

    public int getMiniBatchSize() {
        return this.m_MiniBatchSize;
    }

    public String miniBatchSizeTipText() {
        return "The mini-batch size to use; -1 to turn off.";
    }

    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    public long getSeed() {
        return this.m_Seed;
    }

    public String seedTipText() {
        return "The seed value to use for randomization.";
    }

    public void setIterationListeners(IterationListenerConfigurator[] value) {
        this.m_IterationListeners = value;
        this.reset();
    }

    public IterationListenerConfigurator[] getIterationListeners() {
        return this.m_IterationListeners;
    }

    public String iterationListenersTipText() {
        return "The iteration listeners to use (configurators).";
    }

    public void setOutputInterval(int value) {
        this.m_OutputInterval = value;
        this.reset();
    }

    public int getOutputInterval() {
        return this.m_OutputInterval;
    }

    public String outputIntervalTipText() {
        return "The interval (of epochs) to output the model (use <1 to turn off).";
    }

    public void setVariableName(VariableName value) {
        this.m_VariableName = value;
        this.reset();
    }

    public VariableName getVariableName() {
        return this.m_VariableName;
    }

    public String variableNameTipText() {
        return "The variable to monitor.";
    }

    public void setTestPercentage(double value) {
        if (this.getOptionManager().isValid("testPercentage", (Number)value)) {
            this.m_TestPercentage = value;
            this.reset();
        }
    }

    public double getTestPercentage() {
        return this.m_TestPercentage;
    }

    public String testPercentageTipText() {
        return "The percentage (0-1) of the training data to set aside for evaluating the model; no testing performed if 0.";
    }

    public void setType(DL4JEvaluationType value) {
        this.m_Type = value;
        this.reset();
    }

    public DL4JEvaluationType getType() {
        return this.m_Type;
    }

    public String typeTipText() {
        return "The type of evaluation to perform.";
    }

    public void setOutputBestModel(boolean value) {
        this.m_OutputBestModel = value;
        this.reset();
    }

    public boolean getOutputBestModel() {
        return this.m_OutputBestModel;
    }

    public String outputBestModelTipText() {
        return "If enabled and testing is performed, the best model found so far is output in the container as well.";
    }

    public void setStatisticValues(EvaluationStatistic[] value) {
        this.m_StatisticValues = value;
        this.reset();
    }

    public EvaluationStatistic[] getStatisticValues() {
        return this.m_StatisticValues;
    }

    public String statisticValuesTipText() {
        return "The evaluation values to extract and turn into a spreadsheet.";
    }

    public void setClassIndex(Range value) {
        this.m_ClassIndex = value;
        this.reset();
    }

    public Range getClassIndex() {
        return this.m_ClassIndex;
    }

    public String classIndexTipText() {
        return "The range of class label indices (eg used for AUC).";
    }

    public void setRegressionColumns(Range value) {
        this.m_RegressionColumns = value;
        this.reset();
    }

    public Range getRegressionColumns() {
        return this.m_RegressionColumns;
    }

    public String regressionColumnsTipText() {
        return "The range of columns to get regression statistics for.";
    }

    public String getQuickInfo() {
        String result = QuickInfoHelper.toString((OptionHandler)this, (String)BACKUP_MODEL, (Object)this.m_Model, (String)"model: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"trainStop", (Object)((Object)this.m_TrainStop), (String)", stop: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"miniBatchSize", (Object)this.m_MiniBatchSize, (String)", minibatch: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"variableName", (Object)this.m_VariableName.paddedValue(), (String)", monitor: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"testPercentage", (Object)this.m_TestPercentage, (String)", test %: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"outputBestModel", (boolean)this.m_OutputBestModel, (String)BACKUP_BEST_MODEL, (String)", ");
        return result;
    }

    protected void pruneBackup() {
        super.pruneBackup();
        this.pruneBackup(BACKUP_MODEL);
        this.pruneBackup(BACKUP_BEST_MODEL);
        this.pruneBackup(BACKUP_BEST_STATISTICS);
        this.pruneBackup(BACKUP_EPOCH);
        this.pruneBackup(BACKUP_TRAINDATA);
        this.pruneBackup(BACKUP_TESTDATA);
        this.pruneBackup(BACKUP_TRAINING_FINISHED);
    }

    protected Hashtable<String, Object> backupState() {
        Hashtable result = super.backupState();
        if (this.m_ActualModel != null) {
            result.put(BACKUP_MODEL, this.m_ActualModel);
        }
        if (this.m_BestModel != null) {
            result.put(BACKUP_BEST_MODEL, this.m_BestModel);
        }
        if (this.m_BestStatistics != null) {
            result.put(BACKUP_BEST_STATISTICS, this.m_BestStatistics);
        }
        result.put(BACKUP_EPOCH, this.m_Epoch);
        if (this.m_TrainData != null) {
            result.put(BACKUP_TRAINDATA, this.m_TrainData);
        }
        if (this.m_TestData != null) {
            result.put(BACKUP_TESTDATA, this.m_TestData);
        }
        result.put(BACKUP_TRAINING_FINISHED, this.m_TrainingFinished);
        return result;
    }

    protected void restoreState(Hashtable<String, Object> state) {
        if (state.containsKey(BACKUP_MODEL)) {
            this.m_ActualModel = (Model)state.get(BACKUP_MODEL);
            state.remove(BACKUP_MODEL);
        }
        if (state.containsKey(BACKUP_BEST_MODEL)) {
            this.m_BestModel = (Model)state.get(BACKUP_BEST_MODEL);
            state.remove(BACKUP_BEST_MODEL);
        }
        if (state.containsKey(BACKUP_BEST_STATISTICS)) {
            this.m_BestStatistics = (Map)state.get(BACKUP_BEST_STATISTICS);
            state.remove(BACKUP_BEST_STATISTICS);
        }
        if (state.containsKey(BACKUP_EPOCH)) {
            this.m_Epoch = (Integer)state.get(BACKUP_EPOCH);
            state.remove(BACKUP_EPOCH);
        }
        if (state.containsKey(BACKUP_TRAINDATA)) {
            this.m_TrainData = (org.nd4j.linalg.dataset.DataSet)state.get(BACKUP_TRAINDATA);
            state.remove(BACKUP_TRAINDATA);
        }
        if (state.containsKey(BACKUP_TESTDATA)) {
            this.m_TestData = (org.nd4j.linalg.dataset.DataSet)state.get(BACKUP_TESTDATA);
            state.remove(BACKUP_TESTDATA);
        }
        if (state.containsKey(BACKUP_TRAINING_FINISHED)) {
            this.m_TrainingFinished = (Boolean)state.get(BACKUP_TRAINING_FINISHED);
            state.remove(BACKUP_TRAINING_FINISHED);
        }
        super.restoreState(state);
    }

    public void variableChanged(VariableChangeEvent e) {
        super.variableChanged(e);
        if ((e.getType() == VariableChangeEvent.Type.MODIFIED || e.getType() == VariableChangeEvent.Type.ADDED) && e.getName().equals(this.m_VariableName.getValue())) {
            this.resetModel();
            if (this.isLoggingEnabled()) {
                this.getLogger().info("Reset model");
            }
        }
    }

    protected void reset() {
        super.reset();
        this.resetModel();
    }

    protected void resetModel() {
        if (this.m_ActualModel != null) {
            this.m_ActualModel.clear();
        }
        this.m_ActualModel = null;
        if (this.m_BestModel != null) {
            this.m_BestModel.clear();
        }
        this.m_BestModel = null;
        this.m_BestStatistics = new HashMap<String, Double>();
        this.m_Epoch = 0;
        this.m_TrainData = null;
        this.m_TestData = null;
        this.m_TrainingFinished = false;
    }

    public Class[] accepts() {
        return new Class[]{org.nd4j.linalg.dataset.DataSet.class};
    }

    public Class[] generates() {
        return new Class[]{DL4JModelContainer.class};
    }

    protected ModelConfigurator getModelConfiguratorInstance() throws Exception {
        MessageCollection errors = new MessageCollection();
        ModelConfigurator result = (ModelConfigurator)CallableActorHelper.getSetup(ModelConfigurator.class, (CallableActorReference)this.m_Model, (Actor)this, (MessageCollection)errors);
        if (result == null) {
            if (errors.isEmpty()) {
                throw new IllegalStateException("Failed to obtain model configurator from '" + this.m_Model + "'!");
            }
            throw new IllegalStateException("Failed to obtain model configurator from '" + this.m_Model + "':\n" + errors);
        }
        return result;
    }

    protected void updateBestModel(Evaluation evalCls, RegressionEvaluation evalReg) {
        DL4JEvaluationValues eval = new DL4JEvaluationValues();
        eval.setStatisticValues(this.m_StatisticValues);
        eval.setClassIndex(this.m_ClassIndex);
        eval.setRegressionColumns(this.m_RegressionColumns);
        eval.input(new Token(evalCls != null ? evalCls : evalReg));
        String result = eval.execute();
        if (result == null) {
            Double statValue;
            String statName;
            SpreadSheet stats = (SpreadSheet)eval.output().getPayload();
            boolean better = true;
            if (!this.m_BestStatistics.isEmpty()) {
                for (Row row : stats.rows()) {
                    statName = row.getCell(0).getContent();
                    statValue = row.getCell(1).toDouble();
                    EvaluationStatistic statEnum = EvaluationStatistic.valueOf((AbstractOption)null, statName);
                    if (statEnum.compare(statValue, this.m_BestStatistics.get(statName)) <= 0) continue;
                    better = false;
                    break;
                }
            }
            if (better) {
                this.m_BestModel = this.m_ActualModel;
                for (Row row : stats.rows()) {
                    statName = row.getCell(0).getContent();
                    statValue = row.getCell(1).toDouble();
                    this.m_BestStatistics.put(statName, statValue);
                }
            }
        }
    }

    protected String iterate() {
        String result = null;
        Evaluation evalCls = null;
        RegressionEvaluation evalReg = null;
        MessageCollection triggers = new MessageCollection();
        try {
            do {
                ShufflingDataSetIterator iter;
                ++this.m_Epoch;
                if (this.isLoggingEnabled() && this.m_Epoch % 100 == 0) {
                    this.getLogger().info("#epoch: " + this.m_Epoch);
                }
                if (this.m_ActualModel instanceof MultiLayerNetwork) {
                    if (this.m_MiniBatchSize < 1) {
                        ((MultiLayerNetwork)this.m_ActualModel).fit((DataSet)this.m_TrainData);
                    } else {
                        iter = new ShufflingDataSetIterator(this.m_TrainData, this.m_MiniBatchSize, (int)this.m_Seed);
                        while (iter.hasNext()) {
                            ((MultiLayerNetwork)this.m_ActualModel).fit((DataSet)iter.next());
                        }
                    }
                } else if (this.m_MiniBatchSize < 1) {
                    this.m_ActualModel.fit(this.m_TrainData.getFeatureMatrix());
                } else {
                    iter = new ShufflingDataSetIterator(this.m_TrainData, this.m_MiniBatchSize, (int)this.m_Seed);
                    while (iter.hasNext() && !this.isStopped()) {
                        this.m_ActualModel.fit(iter.next().getFeatureMatrix());
                    }
                }
                if (this.m_OutputInterval > 0 && this.m_Epoch % this.m_OutputInterval == 0 || this.isStopped()) break;
                this.m_TrainingFinished = this.m_TrainStop.checkStopping(new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, this.m_Epoch), triggers);
            } while (!this.m_TrainingFinished);
            if (this.m_TestData != null && !this.isStopped()) {
                switch (this.m_Type) {
                    case CLASSIFICATION: {
                        evalCls = new Evaluation(this.m_TrainData.numOutcomes());
                        evalCls.eval(this.m_TestData.getLabels(), ((MultiLayerNetwork)this.m_ActualModel).output(this.m_TestData.getFeatureMatrix(), Layer.TrainingMode.TEST));
                        break;
                    }
                    case REGRESSION: {
                        evalReg = new RegressionEvaluation(this.m_TrainData.numOutcomes());
                        evalReg.eval(this.m_TestData.getLabels(), ((MultiLayerNetwork)this.m_ActualModel).output(this.m_TestData.getFeatureMatrix(), Layer.TrainingMode.TEST));
                        break;
                    }
                    default: {
                        throw new IllegalStateException("Unhandled evaluation type: " + (Object)((Object)this.m_Type));
                    }
                }
                this.updateBestModel(evalCls, evalReg);
            }
            if (!this.isStopped()) {
                DL4JModelContainer cont = evalCls != null ? new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, this.m_Epoch, evalCls, this.m_OutputBestModel ? this.m_BestModel : null, this.m_BestStatistics) : (evalReg != null ? new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, this.m_Epoch, evalReg, this.m_OutputBestModel ? this.m_BestModel : null, this.m_BestStatistics) : new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, this.m_Epoch, null, this.m_OutputBestModel ? this.m_BestModel : null, this.m_BestStatistics));
                this.m_OutputToken = new Token((Object)cont);
                if (!this.m_TrainingFinished) {
                    this.m_TrainingFinished = this.m_TrainStop.checkStopping(cont, triggers);
                }
                if (this.m_TrainingFinished) {
                    cont.setValue("Train Stop Messages", triggers.toList().toArray(new String[triggers.size()]));
                }
            }
        }
        catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to process data (epoch: " + this.m_Epoch + "):", e);
        }
        if (this.m_OutputToken != null) {
            this.updateProvenance((ProvenanceContainer)this.m_OutputToken);
        }
        return result;
    }

    protected String doExecute() {
        String result = null;
        try {
            this.m_TrainingFinished = false;
            org.nd4j.linalg.dataset.DataSet data = (org.nd4j.linalg.dataset.DataSet)this.m_InputToken.getPayload();
            if (this.m_TestPercentage > 0.0) {
                org.nd4j.linalg.dataset.DataSet[] split = DataSetHelper.split(data, 1.0 - this.m_TestPercentage, this.m_Seed);
                this.m_TrainData = split[0];
                this.m_TestData = split[1];
                if (this.isLoggingEnabled()) {
                    this.getLogger().info("Splitting data into train/test using " + this.m_TestPercentage + " for training.");
                }
            } else {
                this.m_TrainData = data;
                this.m_TestData = null;
                if (this.isLoggingEnabled()) {
                    this.getLogger().info("Using all data for training.");
                }
            }
            if (this.m_ActualModel == null) {
                ModelConfigurator conf = this.getModelConfiguratorInstance();
                this.m_ActualModel = conf.configureModel(this.m_TrainData.numInputs(), this.m_TrainData.numOutcomes());
                if (this.m_ActualModel == null) {
                    result = "Failed to obtain model?";
                }
            }
            if (result == null) {
                ArrayList<IterationListener> listeners = new ArrayList<IterationListener>();
                for (IterationListenerConfigurator l : this.m_IterationListeners) {
                    l.setFlowContext((Actor)this);
                    listeners.addAll(l.configureIterationListeners());
                }
                this.m_ActualModel.setListeners(listeners);
                this.m_ActualModel.init();
            }
        }
        catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to process data:", e);
        }
        if (result == null) {
            this.m_TrainStop.setFlowContext((Actor)this);
            this.m_TrainStop.start();
            result = this.iterate();
        }
        return result;
    }

    public void updateProvenance(ProvenanceContainer cont) {
        if (Provenance.getSingleton().isEnabled()) {
            if (this.m_InputToken.hasProvenance()) {
                cont.setProvenance(this.m_InputToken.getProvenance().getClone());
            }
            cont.addProvenance(new ProvenanceInformation(ActorType.MODEL_GENERATOR, this.m_InputToken.getPayload().getClass(), (Actor)this, this.m_OutputToken.getPayload().getClass()));
        }
    }

    public boolean hasPendingOutput() {
        return super.hasPendingOutput() || this.m_Epoch > 0 && !this.m_TrainingFinished;
    }

    public Token output() {
        if (this.m_OutputToken == null && !this.m_TrainingFinished) {
            this.iterate();
        }
        Token result = this.m_OutputToken;
        this.m_OutputToken = null;
        return result;
    }

    public void wrapUp() {
        super.wrapUp();
        this.resetModel();
    }
}

