/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import adams.core.io.PlaceholderFile;
import java.io.File;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionMetadata;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.dl4j.iterators.AbstractDataSetIterator;
import weka.dl4j.iterators.DefaultInstancesIterator;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;
import weka.filters.SupervisedFilter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/**
 * Supervised batch filter that replaces the input attributes with the activations of a DL4J
 * {@link MultiLayerNetwork} restored from {@link #getModelFile()}.
 *
 * <p>The output format has one numeric attribute per unit of the layer that feeds the network's
 * last layer ({@code unit_1} ... {@code unit_N}), followed by a numeric {@code Class} attribute that
 * carries over each input instance's class value.
 *
 * <p>When {@link #getPreTrained()} is {@code false}, the network is re-initialised and fitted on
 * the first batch only. The preprocessing fitted there (missing-value replacement, nominal to
 * binary, standardisation/normalisation) is re-applied to every batch before features are
 * extracted, so later batches (e.g. test data) are transformed consistently with the training data.
 */
public class DLFilter
extends SimpleBatchFilter
implements SupervisedFilter,
TechnicalInformationHandler {
    static final long serialVersionUID = -3335106965521265631L;
    /** Option name for the random seed. */
    protected static final String SEED = "seed";
    /** Option name (flag) indicating that the loaded model is used as-is, without training. */
    protected static final String PRE_TRAINED = "pre-trained";
    /** Option name for the serialized model file. */
    protected static final String MODEL_FILE = "model-file";
    /** The network; restored from {@link #m_modelFile} by {@link #loadModel()}. */
    protected Model m_model;
    /** Whether {@link #m_model} currently holds the network restored from {@link #m_modelFile}. */
    protected boolean m_model_loaded = false;
    protected int m_seed = 1;
    protected PlaceholderFile m_modelFile = new PlaceholderFile();
    protected boolean m_preTrained = true;
    protected AbstractDataSetIterator m_iterator = new DefaultInstancesIterator();
    protected boolean m_standardizeInsteadOfNormalize = true;
    /** Missing-value replacement fitted during training; null when no training happened. */
    protected Filter m_replaceMissing;
    /** Nominal-to-binary conversion fitted during training; null when no training happened. */
    protected Filter m_nominalToBinary;
    /** Standardize/Normalize filter fitted during training; null when no training happened. */
    protected Filter m_normalize;
    /** Slope of the map from the preprocessed class value z back to the original: y = x0 + x1 * z. */
    protected double m_x1 = 1.0;
    /** Intercept of the map from the preprocessed class value z back to the original. */
    protected double m_x0 = 0.0;

    /**
     * Restores the network from {@link #getModelFile()} unless it is already loaded.
     *
     * @return {@code true} if a model is available, {@code false} if restoring it failed (the
     *     error is printed to stderr)
     */
    protected boolean loadModel() {
        if (this.m_model_loaded) {
            return true;
        }
        try {
            this.m_model = ModelSerializer.restoreMultiLayerNetwork((File) this.getModelFile());
            this.m_model_loaded = true;
            return true;
        } catch (Exception e) {
            System.err.println(e.toString());
            return false;
        }
    }

    /**
     * Ensures the model is loaded, failing loudly instead of letting a null model surface later.
     *
     * @throws Exception if the model file could not be restored
     */
    private void requireModel() throws Exception {
        if (!this.loadModel()) {
            throw new Exception("Failed to load model from: " + this.getModelFile());
        }
    }

    public String globalInfo() {
        return "\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        return new TechnicalInformation(TechnicalInformation.Type.BOOK);
    }

    public boolean getPreTrained() {
        return this.m_preTrained;
    }

    public int getSeed() {
        return this.m_seed;
    }

    public PlaceholderFile getModelFile() {
        return this.m_modelFile;
    }

    public void setPreTrained(boolean b) {
        this.m_preTrained = b;
    }

    public void setSeed(int seed) {
        this.m_seed = seed;
    }

    /**
     * Sets the model file. Any previously loaded network is discarded so the next use reloads it
     * from the new file.
     *
     * @param mf the serialized model file
     */
    public void setModelFile(PlaceholderFile mf) {
        this.m_modelFile = mf;
        this.m_model_loaded = false;
    }

    public String modelFileTipText() {
        return "Model file to load";
    }

    public String seedTipText() {
        return "Seed";
    }

    public String preTrainedTipText() {
        return "Is the model already built?";
    }

    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        WekaOptionUtils.addOption(result, this.modelFileTipText(), "" + this.getModelFile(), MODEL_FILE);
        WekaOptionUtils.addOption(result, this.preTrainedTipText(), "" + this.getPreTrained(), PRE_TRAINED);
        WekaOptionUtils.addOption(result, this.seedTipText(), "" + this.getSeed(), SEED);
        WekaOptionUtils.add(result, super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    public String[] getOptions() {
        ArrayList<String> result = new ArrayList<String>();
        WekaOptionUtils.add(result, SEED, this.getSeed());
        WekaOptionUtils.add(result, PRE_TRAINED, this.getPreTrained());
        WekaOptionUtils.add(result, MODEL_FILE, (File) this.getModelFile());
        WekaOptionUtils.add(result, super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    /**
     * Parses the options. Note that {@code -pre-trained} is a flag: when it is absent the
     * property is set to {@code false}, even though the field's initial value is {@code true}.
     */
    public void setOptions(String[] options) throws Exception {
        this.setSeed(WekaOptionUtils.parse(options, SEED, 1));
        this.setModelFile(WekaOptionUtils.parse(options, MODEL_FILE, new PlaceholderFile()));
        this.setPreTrained(Utils.getFlag(PRE_TRAINED, options));
        super.setOptions(options);
    }

    /**
     * Builds the output header: one numeric attribute per input unit of the network's last layer,
     * plus a numeric class attribute.
     *
     * <p>The width must match what {@link #process(Instances)} extracts, i.e. the activations fed
     * into the last layer, so it is taken from that layer's {@code nIn} (not from layer 1, which
     * is only correct for two-layer networks).
     */
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        this.requireModel();
        List<NeuralNetConfiguration> confs =
                ((MultiLayerNetwork) this.m_model).getLayerWiseConfigurations().getConfs();
        NeuralNetConfiguration lastConf = confs.get(confs.size() - 1);
        int numUnits = ((FeedForwardLayer) lastConf.getLayer()).getNIn();
        String prefix = "unit";
        ArrayList<Attribute> atts = new ArrayList<Attribute>(numUnits + 1);
        for (int i = 1; i <= numUnits; ++i) {
            atts.add(new Attribute(prefix + "_" + i));
        }
        atts.add(new Attribute("Class"));
        Instances result = new Instances(prefix, atts, 0);
        result.setClassIndex(result.numAttributes() - 1);
        return result;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        return result;
    }

    @OptionMetadata(description="The dataset iterator to use", displayName="dataset iterator", commandLineParamName="iterator", commandLineParamSynopsis="-iterator <string>", displayOrder=6)
    public AbstractDataSetIterator getDataSetIterator() {
        return this.m_iterator;
    }

    /**
     * Sets the dataset iterator used to turn Weka instances into DL4J data sets.
     *
     * @param iterator the iterator to use
     */
    public void setDataSetIterator(AbstractDataSetIterator iterator) {
        this.m_iterator = iterator;
    }

    /**
     * Transforms a batch. On the first batch, any preprocessing from an earlier run is discarded,
     * and the network is trained when it is not pre-trained. Later batches reuse the model and
     * preprocessing learned there instead of refitting on (possibly test) data.
     */
    protected Instances process(Instances instances) throws Exception {
        this.requireModel();
        if (!this.isFirstBatchDone()) {
            this.m_replaceMissing = null;
            this.m_nominalToBinary = null;
            this.m_normalize = null;
            if (!this.getPreTrained()) {
                this.trainModel(instances);
            }
        }
        return this.extractFeatures(instances);
    }

    /**
     * Re-initialises the network and fits it on a preprocessed, shuffled copy of the batch.
     *
     * <p>Rows with a missing class are dropped, then missing values are replaced, nominal
     * attributes binarised and the data standardised (or normalised). The fitted filters are stored
     * so the same preprocessing is applied at transformation time. If no usable rows remain, the
     * network is left as loaded and no preprocessing is stored.
     *
     * @param instances the first batch of input data
     * @throws Exception if all class values are identical, or if a filter or DL4J call fails
     */
    private void trainModel(Instances instances) throws Exception {
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();
        if (data.numInstances() == 0 || data.numAttributes() < 2) {
            return;
        }

        ReplaceMissingValues replaceMissing = new ReplaceMissingValues();
        replaceMissing.setInputFormat(data);
        data = Filter.useFilter(data, replaceMissing);

        // Two rows with distinct class values let us recover the class scaling after filtering.
        double y0 = data.instance(0).classValue();
        int index = 1;
        while (index < data.numInstances() && data.instance(index).classValue() == y0) {
            ++index;
        }
        if (index == data.numInstances()) {
            throw new Exception("All class values are the same. At least two class values should be different");
        }
        double y1 = data.instance(index).classValue();

        NominalToBinary nominalToBinary = new NominalToBinary();
        nominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, nominalToBinary);

        Filter normalize;
        if (this.m_standardizeInsteadOfNormalize) {
            normalize = new Standardize();
            normalize.setOptions(new String[]{"-unset-class-temporarily"});
        } else {
            normalize = new Normalize();
        }
        normalize.setInputFormat(data);
        data = Filter.useFilter(data, normalize);

        double z0 = data.instance(0).classValue();
        double z1 = data.instance(index).classValue();
        this.m_x1 = (y0 - y1) / (z0 - z1);
        this.m_x0 = y0 - this.m_x1 * z0;

        data.randomize(new Random(this.getSeed()));
        MultiLayerNetwork network = (MultiLayerNetwork) this.m_model;
        network.init();
        int numEpochs = this.m_model.conf().getNumIterations();
        DataSetIterator iter = this.getDataSetIterator().getIterator(data, this.getSeed());
        for (int epoch = 1; epoch <= numEpochs; ++epoch) {
            network.fit(iter);
            System.err.println("*** Completed epoch " + epoch + " ***");
            iter.reset();
        }

        this.m_replaceMissing = replaceMissing;
        this.m_nominalToBinary = nominalToBinary;
        this.m_normalize = normalize;
    }

    /**
     * Applies the preprocessing fitted during training, if any. The filters used are one output
     * row per input row, so row order is preserved.
     *
     * @param instances the raw batch
     * @return the preprocessed batch, or {@code instances} itself when nothing was fitted
     */
    private Instances preprocess(Instances instances) throws Exception {
        Instances data = instances;
        Filter[] chain = {this.m_replaceMissing, this.m_nominalToBinary, this.m_normalize};
        for (Filter filter : chain) {
            if (filter != null) {
                data = Filter.useFilter(data, filter);
            }
        }
        return data;
    }

    /**
     * Feeds each instance through the network and emits the activations that enter the last
     * layer, followed by the instance's original class value.
     *
     * @param instances the raw batch
     * @return the batch in the output format
     * @throws Exception if the activation width does not match the output format
     */
    private Instances extractFeatures(Instances instances) throws Exception {
        Instances header = this.getOutputFormat();
        int numUnits = header.numAttributes() - 1;
        Instances networkInput = this.preprocess(instances);
        MultiLayerNetwork network = (MultiLayerNetwork) this.m_model;
        Instances result = new Instances(header, instances.numInstances());
        for (int i = 0; i < networkInput.numInstances(); ++i) {
            Instances single = new Instances(networkInput, 0);
            single.add(networkInput.get(i));
            DataSet ds = (DataSet) this.getDataSetIterator().getIterator(single, this.getSeed(), 1).next();
            // DL4J's feedForward returns the input followed by each layer's activations, so the
            // second-to-last entry is the input to the last layer.
            List<INDArray> activations = network.feedForward(ds.getFeatureMatrix());
            INDArray hidden = activations.get(activations.size() - 2);
            if (hidden.length() != numUnits) {
                throw new Exception("Extracted layer has " + hidden.length()
                        + " units but the output format expects " + numUnits);
            }
            double[] values = new double[header.numAttributes()];
            for (int j = 0; j < numUnits; ++j) {
                values[j] = hidden.getDouble(j);
            }
            values[numUnits] = instances.get(i).classValue();
            result.add(new DenseInstance(1.0, values));
        }
        return result;
    }

    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10364 $");
    }

    public static void main(String[] args) {
        DLFilter.runFilter(new DLFilter(), args);
    }
}

