package adams.flow.transformer;

import adams.core.MessageCollection;
import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.Range;
import adams.core.VariableName;
import adams.core.option.AbstractOption;
import adams.data.spreadsheet.Row;
import adams.data.spreadsheet.SpreadSheet;
import adams.event.VariableChangeEvent;
import adams.flow.container.DL4JModelContainer;
import adams.flow.core.CallableActorHelper;
import adams.flow.core.CallableActorReference;
import adams.flow.core.Token;
import adams.flow.core.VariableMonitor;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.source.DL4JModelConfigurator;
import adams.ml.dl4j.DataSetHelper;
import adams.ml.dl4j.EvaluationStatistic;
import adams.ml.dl4j.datasetiterator.ShufflingDataSetIterator;
import adams.ml.dl4j.iterationlistener.IterationListenerConfigurator;
import adams.ml.dl4j.model.ModelConfigurator;
import adams.ml.dl4j.trainstopcriterion.AbstractTrainStopCriterion;
import adams.ml.dl4j.trainstopcriterion.MaxEpoch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:adams/flow/transformer/DL4JTrainModel.class */
/**
 * Flow transformer that trains a DL4J model on the incoming {@link DataSet} and forwards
 * the model in a {@link DL4JModelContainer}.
 *
 * <p>Training runs epoch by epoch until the configured stop criterion fires. When an
 * output interval is set, {@link #iterate()} returns after every interval with an
 * intermediate container; {@link #hasPendingOutput()} and {@link #output()} then resume
 * training on demand. The model state is discarded whenever an option changes or the
 * monitored variable is added/modified.
 */
public class DL4JTrainModel extends AbstractTransformer implements ProvenanceSupporter, Randomizable, VariableMonitor {
    private static final long serialVersionUID = -3019442578354930841L;

    /** Backup key for the model being trained. */
    public static final String BACKUP_MODEL = "model";
    /** Backup key for the best model. */
    public static final String BACKUP_BEST_MODEL = "best model";
    /** Backup key for the epoch counter. */
    public static final String BACKUP_EPOCH = "epoch";
    /** Backup key for the training data. */
    public static final String BACKUP_TRAINDATA = "traindata";
    /** Backup key for the test data. */
    public static final String BACKUP_TESTDATA = "testdata";
    /** Backup key for the statistics of the best model. */
    public static final String BACKUP_BEST_STATISTICS = "best statistics";
    /** Backup key for the "training finished" flag. */
    public static final String BACKUP_TRAINING_FINISHED = "training finished";

    /** Reference to the callable actor that supplies the model configurator. */
    protected CallableActorReference m_Model;
    /** The model currently being trained; null until configured in doExecute(). */
    protected Model m_ActualModel;
    /** The model recorded by updateBestModel(); null if none recorded yet. */
    protected Model m_BestModel;
    /** Statistic name to value, as recorded alongside the best model. */
    protected Map<String, Double> m_BestStatistics;
    /** Criterion deciding when training stops. */
    protected AbstractTrainStopCriterion m_TrainStop;
    /** Mini-batch size; values below 1 fit on the full training set at once. */
    protected int m_MiniBatchSize;
    /** Seed used for the train/test split and for shuffling mini-batches. */
    protected long m_Seed;
    /** Configurators that produce the iteration listeners attached to the model. */
    protected IterationListenerConfigurator[] m_IterationListeners;
    /** Number of epochs between intermediate outputs; values below 1 disable them. */
    protected int m_OutputInterval;
    /** Variable whose addition/modification resets the model. */
    protected VariableName m_VariableName;
    /** Fraction (0-1) of the input held out for testing; 0 disables testing. */
    protected double m_TestPercentage;
    /** Kind of evaluation run on the test data. */
    protected DL4JEvaluationType m_Type;
    /** Whether the best model is placed in the output container. */
    protected boolean m_OutputBestModel;
    /** Statistics extracted from an evaluation to judge the best model. */
    protected EvaluationStatistic[] m_StatisticValues;
    /** Class label indices handed to the statistics extraction. */
    protected Range m_ClassIndex;
    /** Regression columns handed to the statistics extraction. */
    protected Range m_RegressionColumns;
    /** Number of epochs trained since the last model reset. */
    protected int m_Epoch;
    /** Data the model is fitted on. */
    protected DataSet m_TrainData;
    /** Held-out data for evaluation; null when no testing is performed. */
    protected DataSet m_TestData;
    /** True once the stop criterion fired (or training failed). */
    protected boolean m_TrainingFinished;

    /**
     * Returns a description of this actor.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Trains a model based on the incoming dataset and outputs the built model alongside the dataset (in a model container).\nThe model can be reset using the monitor variable option, i.e, whenever this variable changes value, the model gets reset. Useful when training sequentially on multiple datasets (using the file name as monitor variable).";
    }

    /**
     * Registers all options of this actor together with their default values.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("model", "model", new CallableActorReference(DL4JModelConfigurator.class.getSimpleName()));
        this.m_OptionManager.add("train-stop", "trainStop", new MaxEpoch());
        this.m_OptionManager.add("mini-batch-size", "miniBatchSize", -1, -1, (Number) null);
        this.m_OptionManager.add("seed", "seed", 1L);
        this.m_OptionManager.add("iteration-listener", "iterationListeners", new IterationListenerConfigurator[0]);
        this.m_OptionManager.add("output-interval", "outputInterval", -1, -1, (Number) null);
        this.m_OptionManager.add("var-name", "variableName", new VariableName());
        this.m_OptionManager.add("test-percentage", "testPercentage", Double.valueOf(0.0d), Double.valueOf(0.0d), Double.valueOf(1.0d));
        this.m_OptionManager.add("type", "type", DL4JEvaluationType.CLASSIFICATION);
        this.m_OptionManager.add("output-best-model", "outputBestModel", false);
        this.m_OptionManager.add("statistic", "statisticValues", new EvaluationStatistic[]{EvaluationStatistic.ACCURACY, EvaluationStatistic.F1});
        this.m_OptionManager.add("index", "classIndex", new Range("first"));
        this.m_OptionManager.add("regression-columns", "regressionColumns", new Range("last"));
    }

    /**
     * Sets the reference to the model configurator and resets the actor.
     *
     * @param callableActorReference the reference to use
     */
    public void setModel(CallableActorReference callableActorReference) {
        this.m_Model = callableActorReference;
        reset();
    }

    /**
     * Returns the reference to the model configurator.
     *
     * @return the reference
     */
    public CallableActorReference getModel() {
        return this.m_Model;
    }

    /**
     * Returns the help text for the model option.
     *
     * @return the tip text
     */
    public String modelTipText() {
        return "The model to train on the input data.";
    }

    /**
     * Sets the stop criterion and resets the actor.
     *
     * @param abstractTrainStopCriterion the criterion to use
     */
    public void setTrainStop(AbstractTrainStopCriterion abstractTrainStopCriterion) {
        this.m_TrainStop = abstractTrainStopCriterion;
        reset();
    }

    /**
     * Returns the stop criterion.
     *
     * @return the criterion
     */
    public AbstractTrainStopCriterion getTrainStop() {
        return this.m_TrainStop;
    }

    /**
     * Returns the help text for the stop criterion option.
     *
     * @return the tip text
     */
    public String trainStopTipText() {
        return "The criterion for stopping training.";
    }

    /**
     * Sets the mini-batch size and resets the actor.
     *
     * @param i the batch size; values below 1 turn mini-batches off
     */
    public void setMiniBatchSize(int i) {
        this.m_MiniBatchSize = i;
        reset();
    }

    /**
     * Returns the mini-batch size.
     *
     * @return the batch size
     */
    public int getMiniBatchSize() {
        return this.m_MiniBatchSize;
    }

    /**
     * Returns the help text for the mini-batch size option.
     *
     * @return the tip text
     */
    public String miniBatchSizeTipText() {
        return "The mini-batch size to use; -1 to turn off.";
    }

    /**
     * Sets the seed and resets the actor.
     *
     * @param j the seed value
     */
    public void setSeed(long j) {
        this.m_Seed = j;
        reset();
    }

    /**
     * Returns the seed.
     *
     * @return the seed value
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the help text for the seed option.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value to use for randomization.";
    }

    /**
     * Sets the iteration listener configurators and resets the actor.
     *
     * @param iterationListenerConfiguratorArr the configurators to use
     */
    public void setIterationListeners(IterationListenerConfigurator[] iterationListenerConfiguratorArr) {
        this.m_IterationListeners = iterationListenerConfiguratorArr;
        reset();
    }

    /**
     * Returns the iteration listener configurators.
     *
     * @return the configurators
     */
    public IterationListenerConfigurator[] getIterationListeners() {
        return this.m_IterationListeners;
    }

    /**
     * Returns the help text for the iteration listeners option.
     *
     * @return the tip text
     */
    public String iterationListenersTipText() {
        return "The iteration listeners to use (configurators).";
    }

    /**
     * Sets the output interval and resets the actor.
     *
     * @param i the interval in epochs; values below 1 turn intermediate output off
     */
    public void setOutputInterval(int i) {
        this.m_OutputInterval = i;
        reset();
    }

    /**
     * Returns the output interval.
     *
     * @return the interval in epochs
     */
    public int getOutputInterval() {
        return this.m_OutputInterval;
    }

    /**
     * Returns the help text for the output interval option.
     *
     * @return the tip text
     */
    public String outputIntervalTipText() {
        return "The interval (of epochs) to output the model (use <1 to turn off).";
    }

    /**
     * Sets the variable to monitor and resets the actor.
     *
     * @param variableName the variable name
     */
    public void setVariableName(VariableName variableName) {
        this.m_VariableName = variableName;
        reset();
    }

    /**
     * Returns the monitored variable.
     *
     * @return the variable name
     */
    public VariableName getVariableName() {
        return this.m_VariableName;
    }

    /**
     * Returns the help text for the variable name option.
     *
     * @return the tip text
     */
    public String variableNameTipText() {
        return "The variable to monitor.";
    }

    /**
     * Sets the test percentage and resets the actor; values rejected by the option
     * manager are ignored.
     *
     * @param d the fraction (0-1) of the data to hold out for testing
     */
    public void setTestPercentage(double d) {
        if (getOptionManager().isValid("testPercentage", Double.valueOf(d))) {
            this.m_TestPercentage = d;
            reset();
        }
    }

    /**
     * Returns the test percentage.
     *
     * @return the fraction (0-1) held out for testing
     */
    public double getTestPercentage() {
        return this.m_TestPercentage;
    }

    /**
     * Returns the help text for the test percentage option.
     *
     * @return the tip text
     */
    public String testPercentageTipText() {
        return "The percentage (0-1) of the training data to set aside for evaluating the model; no testing performed if 0.";
    }

    /**
     * Sets the evaluation type and resets the actor.
     *
     * @param dL4JEvaluationType the type to use
     */
    public void setType(DL4JEvaluationType dL4JEvaluationType) {
        this.m_Type = dL4JEvaluationType;
        reset();
    }

    /**
     * Returns the evaluation type.
     *
     * @return the type
     */
    public DL4JEvaluationType getType() {
        return this.m_Type;
    }

    /**
     * Returns the help text for the evaluation type option.
     *
     * @return the tip text
     */
    public String typeTipText() {
        return "The type of evaluation to perform.";
    }

    /**
     * Sets whether the best model is output and resets the actor.
     *
     * @param z true to include the best model in the container
     */
    public void setOutputBestModel(boolean z) {
        this.m_OutputBestModel = z;
        reset();
    }

    /**
     * Returns whether the best model is output.
     *
     * @return true if the best model is included in the container
     */
    public boolean getOutputBestModel() {
        return this.m_OutputBestModel;
    }

    /**
     * Returns the help text for the output-best-model option.
     *
     * @return the tip text
     */
    public String outputBestModelTipText() {
        return "If enabled and testing is performed, the best model found so far is output in the container as well.";
    }

    /**
     * Sets the statistics to extract and resets the actor.
     *
     * @param evaluationStatisticArr the statistics to use
     */
    public void setStatisticValues(EvaluationStatistic[] evaluationStatisticArr) {
        this.m_StatisticValues = evaluationStatisticArr;
        reset();
    }

    /**
     * Returns the statistics to extract.
     *
     * @return the statistics
     */
    public EvaluationStatistic[] getStatisticValues() {
        return this.m_StatisticValues;
    }

    /**
     * Returns the help text for the statistics option.
     *
     * @return the tip text
     */
    public String statisticValuesTipText() {
        return "The evaluation values to extract and turn into a spreadsheet.";
    }

    /**
     * Sets the class label index range and resets the actor.
     *
     * @param range the range to use
     */
    public void setClassIndex(Range range) {
        this.m_ClassIndex = range;
        reset();
    }

    /**
     * Returns the class label index range.
     *
     * @return the range
     */
    public Range getClassIndex() {
        return this.m_ClassIndex;
    }

    /**
     * Returns the help text for the class index option.
     *
     * @return the tip text
     */
    public String classIndexTipText() {
        return "The range of class label indices (eg used for AUC).";
    }

    /**
     * Sets the regression column range and resets the actor.
     *
     * @param range the range to use
     */
    public void setRegressionColumns(Range range) {
        this.m_RegressionColumns = range;
        reset();
    }

    /**
     * Returns the regression column range.
     *
     * @return the range
     */
    public Range getRegressionColumns() {
        return this.m_RegressionColumns;
    }

    /**
     * Returns the help text for the regression columns option.
     *
     * @return the tip text
     */
    public String regressionColumnsTipText() {
        return "The range of columns to get regression statistics for.";
    }

    /**
     * Builds the short summary of the main options.
     *
     * @return the concatenated quick info fragments
     */
    public String getQuickInfo() {
        String info = QuickInfoHelper.toString(this, "model", this.m_Model, "model: ");
        info += QuickInfoHelper.toString(this, "trainStop", this.m_TrainStop, ", stop: ");
        info += QuickInfoHelper.toString(this, "miniBatchSize", Integer.valueOf(this.m_MiniBatchSize), ", minibatch: ");
        info += QuickInfoHelper.toString(this, "variableName", this.m_VariableName.paddedValue(), ", monitor: ");
        info += QuickInfoHelper.toString(this, "testPercentage", Double.valueOf(this.m_TestPercentage), ", test %: ");
        info += QuickInfoHelper.toString(this, "outputBestModel", this.m_OutputBestModel, "best model", ", ");
        return info;
    }

    /**
     * Removes all entries owned by this actor from the backup.
     */
    protected void pruneBackup() {
        super.pruneBackup();
        pruneBackup(BACKUP_MODEL);
        pruneBackup(BACKUP_BEST_MODEL);
        pruneBackup(BACKUP_BEST_STATISTICS);
        pruneBackup(BACKUP_EPOCH);
        pruneBackup(BACKUP_TRAINDATA);
        pruneBackup(BACKUP_TESTDATA);
        pruneBackup(BACKUP_TRAINING_FINISHED);
    }

    /**
     * Stores the training state in the backup table. Null members are left out; the
     * epoch counter and the finished flag are always stored.
     *
     * @return the table with the state of the superclass plus this actor's state
     */
    protected Hashtable<String, Object> backupState() {
        Hashtable<String, Object> backupState = super.backupState();
        if (this.m_ActualModel != null) {
            backupState.put(BACKUP_MODEL, this.m_ActualModel);
        }
        if (this.m_BestModel != null) {
            backupState.put(BACKUP_BEST_MODEL, this.m_BestModel);
        }
        if (this.m_BestStatistics != null) {
            backupState.put(BACKUP_BEST_STATISTICS, this.m_BestStatistics);
        }
        backupState.put(BACKUP_EPOCH, Integer.valueOf(this.m_Epoch));
        if (this.m_TrainData != null) {
            backupState.put(BACKUP_TRAINDATA, this.m_TrainData);
        }
        if (this.m_TestData != null) {
            backupState.put(BACKUP_TESTDATA, this.m_TestData);
        }
        backupState.put(BACKUP_TRAINING_FINISHED, Boolean.valueOf(this.m_TrainingFinished));
        return backupState;
    }

    /**
     * Restores the training state from the backup table. Every entry consumed here is
     * removed from the table before it is handed on to the superclass.
     *
     * @param hashtable the backup table to restore from
     */
    protected void restoreState(Hashtable<String, Object> hashtable) {
        if (hashtable.containsKey(BACKUP_MODEL)) {
            this.m_ActualModel = (Model) hashtable.get(BACKUP_MODEL);
            hashtable.remove(BACKUP_MODEL);
        }
        if (hashtable.containsKey(BACKUP_BEST_MODEL)) {
            this.m_BestModel = (Model) hashtable.get(BACKUP_BEST_MODEL);
            hashtable.remove(BACKUP_BEST_MODEL);
        }
        if (hashtable.containsKey(BACKUP_BEST_STATISTICS)) {
            // Safe: backupState() only ever stores m_BestStatistics under this key.
            @SuppressWarnings("unchecked")
            Map<String, Double> bestStatistics = (Map<String, Double>) hashtable.get(BACKUP_BEST_STATISTICS);
            this.m_BestStatistics = bestStatistics;
            hashtable.remove(BACKUP_BEST_STATISTICS);
        }
        if (hashtable.containsKey(BACKUP_EPOCH)) {
            this.m_Epoch = ((Integer) hashtable.get(BACKUP_EPOCH)).intValue();
            hashtable.remove(BACKUP_EPOCH);
        }
        if (hashtable.containsKey(BACKUP_TRAINDATA)) {
            this.m_TrainData = (DataSet) hashtable.get(BACKUP_TRAINDATA);
            hashtable.remove(BACKUP_TRAINDATA);
        }
        if (hashtable.containsKey(BACKUP_TESTDATA)) {
            this.m_TestData = (DataSet) hashtable.get(BACKUP_TESTDATA);
            hashtable.remove(BACKUP_TESTDATA);
        }
        if (hashtable.containsKey(BACKUP_TRAINING_FINISHED)) {
            this.m_TrainingFinished = ((Boolean) hashtable.get(BACKUP_TRAINING_FINISHED)).booleanValue();
            hashtable.remove(BACKUP_TRAINING_FINISHED);
        }
        super.restoreState(hashtable);
    }

    /**
     * Resets the model when the monitored variable has been added or modified.
     *
     * @param variableChangeEvent the change event
     */
    public void variableChanged(VariableChangeEvent variableChangeEvent) {
        super.variableChanged(variableChangeEvent);
        if ((variableChangeEvent.getType() == VariableChangeEvent.Type.MODIFIED || variableChangeEvent.getType() == VariableChangeEvent.Type.ADDED) && variableChangeEvent.getName().equals(this.m_VariableName.getValue())) {
            resetModel();
            if (isLoggingEnabled()) {
                getLogger().info("Reset model");
            }
        }
    }

    /**
     * Resets the actor, discarding the model state as well.
     */
    protected void reset() {
        super.reset();
        resetModel();
    }

    /**
     * Clears and drops the current and best model and returns all training state
     * (statistics, epoch counter, data, finished flag) to its initial values.
     */
    protected void resetModel() {
        if (this.m_ActualModel != null) {
            this.m_ActualModel.clear();
        }
        this.m_ActualModel = null;
        if (this.m_BestModel != null) {
            this.m_BestModel.clear();
        }
        this.m_BestModel = null;
        this.m_BestStatistics = new HashMap<>();
        this.m_Epoch = 0;
        this.m_TrainData = null;
        this.m_TestData = null;
        this.m_TrainingFinished = false;
    }

    /**
     * Returns the classes this actor accepts as input.
     *
     * @return the accepted classes
     */
    public Class[] accepts() {
        return new Class[]{DataSet.class};
    }

    /**
     * Returns the classes this actor generates.
     *
     * @return the generated classes
     */
    public Class[] generates() {
        return new Class[]{DL4JModelContainer.class};
    }

    /**
     * Obtains the model configurator from the referenced callable actor.
     *
     * @return the configurator, never null
     * @throws IllegalStateException if no configurator could be obtained; collected
     *         messages are appended when available
     * @throws Exception declared for overriding implementations
     */
    protected ModelConfigurator getModelConfiguratorInstance() throws Exception {
        MessageCollection messageCollection = new MessageCollection();
        ModelConfigurator modelConfigurator = (ModelConfigurator) CallableActorHelper.getSetup(ModelConfigurator.class, this.m_Model, this, messageCollection);
        if (modelConfigurator != null) {
            return modelConfigurator;
        }
        if (messageCollection.isEmpty()) {
            throw new IllegalStateException("Failed to obtain model configurator from '" + this.m_Model + "'!");
        }
        throw new IllegalStateException("Failed to obtain model configurator from '" + this.m_Model + "':\n" + messageCollection);
    }

    /**
     * Extracts the configured statistics from the given evaluation and, unless one of
     * them compares as worse than the stored best value, records the current model and
     * its statistics as the best ones. Nothing happens if the statistics extraction
     * reports an error.
     *
     * @param evaluation the classification evaluation, or null
     * @param regressionEvaluation the regression evaluation, used when {@code evaluation} is null
     */
    protected void updateBestModel(Evaluation evaluation, RegressionEvaluation regressionEvaluation) {
        DL4JEvaluationValues dL4JEvaluationValues = new DL4JEvaluationValues();
        dL4JEvaluationValues.setStatisticValues(this.m_StatisticValues);
        dL4JEvaluationValues.setClassIndex(this.m_ClassIndex);
        dL4JEvaluationValues.setRegressionColumns(this.m_RegressionColumns);
        dL4JEvaluationValues.input(new Token(evaluation != null ? evaluation : regressionEvaluation));
        if (dL4JEvaluationValues.execute() != null) {
            return;
        }
        SpreadSheet spreadSheet = (SpreadSheet) dL4JEvaluationValues.output().getPayload();
        boolean better = true;
        if (!this.m_BestStatistics.isEmpty()) {
            for (Row row : spreadSheet.rows()) {
                String content = row.getCell(0).getContent();
                Double best = this.m_BestStatistics.get(content);
                if (best == null) {
                    // No previous value for this statistic, so there is nothing to be worse than.
                    continue;
                }
                if (EvaluationStatistic.valueOf((AbstractOption) null, content).compare(row.getCell(1).toDouble().doubleValue(), best.doubleValue()) > 0) {
                    better = false;
                    break;
                }
            }
        }
        if (better) {
            // NOTE(review): this stores a reference to the model that keeps training, not a
            // snapshot, so "best" and "current" model are the same object - confirm intended.
            this.m_BestModel = this.m_ActualModel;
            for (Row row : spreadSheet.rows()) {
                this.m_BestStatistics.put(row.getCell(0).getContent(), row.getCell(1).toDouble());
            }
        }
    }

    /**
     * Fits the model on the training data for one epoch, either on the whole set or in
     * shuffled mini-batches. Mini-batch fitting ends early once the actor is stopped.
     */
    private void fitEpoch() {
        boolean multiLayer = this.m_ActualModel instanceof MultiLayerNetwork;
        if (this.m_MiniBatchSize < 1) {
            if (multiLayer) {
                this.m_ActualModel.fit(this.m_TrainData);
            } else {
                this.m_ActualModel.fit(this.m_TrainData.getFeatureMatrix());
            }
            return;
        }
        ShufflingDataSetIterator batches = new ShufflingDataSetIterator(this.m_TrainData, this.m_MiniBatchSize, (int) this.m_Seed);
        while (batches.hasNext() && !isStopped()) {
            if (multiLayer) {
                this.m_ActualModel.fit(batches.m15next());
            } else {
                this.m_ActualModel.fit(batches.m15next().getFeatureMatrix());
            }
        }
    }

    /**
     * Trains until the stop criterion fires, the output interval is reached or the
     * actor is stopped, then evaluates on the test data (if any) and places a model
     * container in the output token.
     *
     * <p>If training fails, no output token is produced and training is marked as
     * finished so that {@link #output()} does not re-enter a failing run.
     *
     * @return null if successful, otherwise the error message
     */
    protected String iterate() {
        DL4JModelContainer dL4JModelContainer;
        Evaluation evaluation = null;
        RegressionEvaluation regressionEvaluation = null;
        MessageCollection messageCollection = new MessageCollection();
        try {
            do {
                this.m_Epoch++;
                if (isLoggingEnabled() && this.m_Epoch % 100 == 0) {
                    getLogger().info("#epoch: " + this.m_Epoch);
                }
                fitEpoch();
                if (isStopped()) {
                    break;
                }
                if (this.m_OutputInterval > 0 && this.m_Epoch % this.m_OutputInterval == 0) {
                    // Intermediate output is due; the stop criterion is checked below on the full container.
                    break;
                }
                this.m_TrainingFinished = this.m_TrainStop.checkStopping(new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, Integer.valueOf(this.m_Epoch)), messageCollection);
            } while (!this.m_TrainingFinished);
        } catch (Exception e) {
            this.m_OutputToken = null;
            this.m_TrainingFinished = true;
            return handleException("Failed to process data (epoch: " + this.m_Epoch + "):", e);
        }
        if (this.m_TestData != null && !isStopped()) {
            switch (this.m_Type) {
                case CLASSIFICATION:
                    evaluation = new Evaluation(this.m_TrainData.numOutcomes());
                    evaluation.eval(this.m_TestData.getLabels(), this.m_ActualModel.output(this.m_TestData.getFeatureMatrix(), Layer.TrainingMode.TEST));
                    break;
                case REGRESSION:
                    regressionEvaluation = new RegressionEvaluation(this.m_TrainData.numOutcomes());
                    regressionEvaluation.eval(this.m_TestData.getLabels(), this.m_ActualModel.output(this.m_TestData.getFeatureMatrix(), Layer.TrainingMode.TEST));
                    break;
                default:
                    throw new IllegalStateException("Unhandled evaluation type: " + this.m_Type);
            }
            updateBestModel(evaluation, regressionEvaluation);
        }
        if (!isStopped()) {
            if (evaluation != null) {
                dL4JModelContainer = new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, Integer.valueOf(this.m_Epoch), evaluation, this.m_OutputBestModel ? this.m_BestModel : null, this.m_BestStatistics);
            } else if (regressionEvaluation != null) {
                dL4JModelContainer = new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, Integer.valueOf(this.m_Epoch), regressionEvaluation, this.m_OutputBestModel ? this.m_BestModel : null, this.m_BestStatistics);
            } else {
                dL4JModelContainer = new DL4JModelContainer(this.m_ActualModel, this.m_TrainData, Integer.valueOf(this.m_Epoch), null, this.m_OutputBestModel ? this.m_BestModel : null, this.m_BestStatistics);
            }
            this.m_OutputToken = new Token(dL4JModelContainer);
            if (!this.m_TrainingFinished) {
                this.m_TrainingFinished = this.m_TrainStop.checkStopping(dL4JModelContainer, messageCollection);
            }
            if (this.m_TrainingFinished) {
                dL4JModelContainer.setValue(DL4JModelContainer.VALUE_TRAIN_STOP_MESSAGES, messageCollection.toList().toArray(new String[messageCollection.size()]));
            }
        }
        if (this.m_OutputToken != null) {
            updateProvenance(this.m_OutputToken);
        }
        return null;
    }

    /**
     * Splits the incoming data (if a test percentage is set), configures the model if
     * none exists yet, attaches the iteration listeners and starts training.
     *
     * @return null if successful, otherwise the error message
     */
    protected String doExecute() {
        String str = null;
        try {
            this.m_TrainingFinished = false;
            DataSet dataSet = (DataSet) this.m_InputToken.getPayload();
            if (this.m_TestPercentage > 0.0d) {
                double trainPercentage = 1.0d - this.m_TestPercentage;
                DataSet[] split = DataSetHelper.split(dataSet, trainPercentage, Long.valueOf(this.m_Seed));
                this.m_TrainData = split[0];
                this.m_TestData = split[1];
                if (isLoggingEnabled()) {
                    // Report the fraction that actually goes into training, not the held-out part.
                    getLogger().info("Splitting data into train/test using " + trainPercentage + " for training.");
                }
            } else {
                this.m_TrainData = dataSet;
                this.m_TestData = null;
                if (isLoggingEnabled()) {
                    getLogger().info("Using all data for training.");
                }
            }
            if (this.m_ActualModel == null) {
                this.m_ActualModel = getModelConfiguratorInstance().configureModel(this.m_TrainData.numInputs(), this.m_TrainData.numOutcomes());
                if (this.m_ActualModel == null) {
                    str = "Failed to obtain model?";
                }
            }
            if (str == null) {
                ArrayList arrayList = new ArrayList();
                for (IterationListenerConfigurator iterationListenerConfigurator : this.m_IterationListeners) {
                    iterationListenerConfigurator.setFlowContext(this);
                    arrayList.addAll(iterationListenerConfigurator.configureIterationListeners());
                }
                this.m_ActualModel.setListeners(arrayList);
                this.m_ActualModel.init();
            }
        } catch (Exception e) {
            this.m_OutputToken = null;
            str = handleException("Failed to process data:", e);
        }
        if (str == null) {
            this.m_TrainStop.setFlowContext(this);
            this.m_TrainStop.start();
            str = iterate();
        }
        return str;
    }

    /**
     * Adds this actor's provenance information to the container, if provenance is
     * enabled. Provenance carried by the input token is cloned into the container first.
     *
     * @param provenanceContainer the container to update
     */
    public void updateProvenance(ProvenanceContainer provenanceContainer) {
        if (Provenance.getSingleton().isEnabled()) {
            if (this.m_InputToken.hasProvenance()) {
                provenanceContainer.setProvenance(this.m_InputToken.getProvenance().getClone());
            }
            provenanceContainer.addProvenance(new ProvenanceInformation(ActorType.MODEL_GENERATOR, this.m_InputToken.getPayload().getClass(), this, this.m_OutputToken.getPayload().getClass()));
        }
    }

    /**
     * Checks whether more output is available: either the superclass reports pending
     * output, or training has started and is not finished yet.
     *
     * @return true if {@link #output()} can be called
     */
    public boolean hasPendingOutput() {
        return super.hasPendingOutput() || (this.m_Epoch > 0 && !this.m_TrainingFinished);
    }

    /**
     * Returns the pending output token and clears it. If no token is pending and
     * training is not finished, training is resumed first to produce one.
     *
     * @return the token, or null if none was produced
     */
    public Token output() {
        if (this.m_OutputToken == null && !this.m_TrainingFinished) {
            iterate();
        }
        Token token = this.m_OutputToken;
        this.m_OutputToken = null;
        return token;
    }

    /**
     * Cleans up after the flow finished and discards the model state.
     */
    public void wrapUp() {
        super.wrapUp();
        resetModel();
    }
}
