/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import adams.core.SerializationHelper;
import adams.core.io.PlaceholderFile;
import adams.ml.dl4j.trainstopcriterion.AbstractTrainStopCriterion;
import adams.ml.dl4j.trainstopcriterion.MaxEpoch;
import java.io.File;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import weka.classifiers.functions.DL4JMultiLayerNetwork;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.dl4j.iterators.DefaultInstancesIterator;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;
import weka.filters.SupervisedFilter;

public class DL4JMultiLayerNetworkFilter
extends SimpleBatchFilter
implements SupervisedFilter,
TechnicalInformationHandler {
    static final long serialVersionUID = -3335106965521265631L;
    protected static String SEED = "seed";
    protected static String PRE_TRAINED = "pre-trained";
    protected static String MODEL_FILE = "model-file";
    protected static String EPOCHS = "epochs";
    protected static String LAYER = "layer";
    protected Filter m_Filter = null;
    protected DL4JMultiLayerNetwork m_model;
    protected boolean m_model_loaded = false;
    DefaultInstancesIterator m_Iterator = new DefaultInstancesIterator();
    protected int m_epochs = 500;
    protected int m_layer = -1;
    protected int m_seed = 1;
    protected PlaceholderFile m_modelFile = new PlaceholderFile();
    protected boolean m_preTrained = true;

    protected boolean loadModel() {
        if (!this.m_model_loaded) {
            try {
                this.m_model = (DL4JMultiLayerNetwork)SerializationHelper.read((String)this.getModelFile().getAbsolutePath());
                this.m_Filter = this.m_model.getPreFilter();
                this.m_model_loaded = true;
            }
            catch (Exception e) {
                System.err.println(e.toString());
                return false;
            }
            return true;
        }
        return true;
    }

    public String globalInfo() {
        return "\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.BOOK);
        return result;
    }

    public boolean getPreTrained() {
        return this.m_preTrained;
    }

    public int getSeed() {
        return this.m_seed;
    }

    public PlaceholderFile getModelFile() {
        return this.m_modelFile;
    }

    public int getEpochs() {
        return this.m_epochs;
    }

    public int getLayer() {
        return this.m_layer;
    }

    public void setPreTrained(boolean b) {
        this.m_preTrained = b;
    }

    public void setSeed(int seed) {
        this.m_seed = seed;
    }

    public void setModelFile(PlaceholderFile mf) {
        this.m_modelFile = mf;
    }

    public void setEpochs(int e) {
        this.m_epochs = e;
    }

    public void setLayer(int l) {
        this.m_layer = l;
    }

    public String modelFileTipText() {
        return "Model file to load";
    }

    public String logFileTipText() {
        return "Log file for Weka Classifier";
    }

    public String seedTipText() {
        return "Seed";
    }

    public String preTrainedTipText() {
        return "Is the model already built?";
    }

    public String epochsTipText() {
        return "epochs to train with if not pre-trained";
    }

    public String layerTipText() {
        return "layer to extract features from. -1 is last. Starts from 1. (0 is input layer)";
    }

    public Enumeration<Option> listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, (String)this.modelFileTipText(), (String)("" + this.getModelFile()), (String)MODEL_FILE);
        WekaOptionUtils.addOption(result, (String)this.preTrainedTipText(), (String)("" + this.getPreTrained()), (String)PRE_TRAINED);
        WekaOptionUtils.addOption(result, (String)this.seedTipText(), (String)("" + this.getSeed()), (String)SEED);
        WekaOptionUtils.addOption(result, (String)this.epochsTipText(), (String)("" + this.getEpochs()), (String)EPOCHS);
        WekaOptionUtils.add(result, (Enumeration)super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    public String[] getOptions() {
        ArrayList result = new ArrayList();
        WekaOptionUtils.add(result, (String)SEED, (int)this.getSeed());
        WekaOptionUtils.add(result, (String)EPOCHS, (int)this.getEpochs());
        WekaOptionUtils.add(result, (String)LAYER, (int)this.getLayer());
        WekaOptionUtils.add(result, (String)PRE_TRAINED, (boolean)this.getPreTrained());
        WekaOptionUtils.add(result, (String)MODEL_FILE, (File)this.getModelFile());
        WekaOptionUtils.add(result, (String[])super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    public void setOptions(String[] options) throws Exception {
        this.setSeed(WekaOptionUtils.parse((String[])options, (String)SEED, (int)1));
        this.setEpochs(WekaOptionUtils.parse((String[])options, (String)EPOCHS, (int)500));
        this.setModelFile(new PlaceholderFile((File)WekaOptionUtils.parse((String[])options, (String)MODEL_FILE, (PlaceholderFile)new PlaceholderFile())));
        this.setPreTrained(Utils.getFlag((String)PRE_TRAINED, (String[])options));
        this.setLayer(WekaOptionUtils.parse((String[])options, (String)LAYER, (int)-1));
        super.setOptions(options);
    }

    public int getUnitsFinalLayer() {
        MultiLayerNetwork mln = this.m_model.getMultiLayerNetwork();
        Layer[] layers = mln.getLayers();
        int numLayers = layers.length;
        List ll = mln.getLayerWiseConfigurations().getConfs();
        int numAtts = ((FeedForwardLayer)((NeuralNetConfiguration)ll.get(ll.size() - 1)).getLayer()).getNIn();
        return numAtts;
    }

    public int getUnitsFromLayer(int layernum) {
        MultiLayerNetwork mln = this.m_model.getMultiLayerNetwork();
        Layer[] layers = mln.getLayers();
        int numLayers = layers.length;
        List ll = mln.getLayerWiseConfigurations().getConfs();
        if (layernum >= ll.size()) {
            return 1;
        }
        if (layernum == -1) {
            layernum = ll.size() - 1;
        }
        int numAtts = ((FeedForwardLayer)((NeuralNetConfiguration)ll.get(layernum)).getLayer()).getNIn();
        return numAtts;
    }

    public synchronized List<INDArray> getUnitScores(Instance i) throws Exception {
        MultiLayerNetwork mln = this.m_model.getMultiLayerNetwork();
        Instances t_insts = new Instances(i.dataset(), 0);
        t_insts.add(i);
        if (this.m_Filter != null) {
            t_insts = Filter.useFilter((Instances)t_insts, (Filter)this.m_Filter);
        }
        DataSet ds = (DataSet)this.m_Iterator.getIterator(t_insts, this.getSeed(), 1).next();
        List list_of_vals = mln.feedForward(ds.getFeatureMatrix());
        return list_of_vals;
    }

    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        this.loadModel();
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        int numAtts = this.getUnitsFromLayer(this.m_layer);
        String prefix = "unit";
        for (int i = 0; i < numAtts; ++i) {
            atts.add(new Attribute(prefix + "_" + (i + 1)));
        }
        atts.add(new Attribute("Class"));
        Instances result = new Instances(prefix, atts, 0);
        result.setClassIndex(result.numAttributes() - 1);
        return result;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.STRING_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        return result;
    }

    protected Instances process(Instances instances) throws Exception {
        Instances result = null;
        this.loadModel();
        if (!this.getPreTrained() && !this.isFirstBatchDone()) {
            MaxEpoch trainstop = new MaxEpoch();
            trainstop.setNumEpochs(this.getEpochs());
            this.m_model.setTrainStop((AbstractTrainStopCriterion)trainstop);
            this.m_model.buildClassifier(instances);
        }
        Instances header = this.getOutputFormat();
        result = new Instances(header, 0);
        for (int i = 0; i < instances.numInstances(); ++i) {
            double[] values = new double[header.numAttributes()];
            List<INDArray> list_of_vals = this.getUnitScores(instances.get(i));
            INDArray ia = this.m_layer < 0 ? list_of_vals.get(list_of_vals.size() - 2) : (this.m_layer >= list_of_vals.size() - 1 ? list_of_vals.get(list_of_vals.size() - 1) : list_of_vals.get(this.m_layer));
            for (int j = 0; j < ia.length(); ++j) {
                values[j] = ia.getDouble(j);
            }
            values[values.length - 1] = instances.get(i).classValue();
            result.add((Instance)new DenseInstance(1.0, values));
        }
        return result;
    }

    public String getRevision() {
        return RevisionUtils.extract((String)"$Revision: 10364 $");
    }

    public static void main(String[] args) {
        DL4JMultiLayerNetworkFilter.runFilter((Filter)new DL4JMultiLayerNetworkFilter(), (String[])args);
    }
}

