/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import adams.core.io.PlaceholderFile;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import org.nd4j.linalg.api.ndarray.INDArray;
import weka.classifiers.functions.Dl4jMlpClassifierExtended;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;
import weka.filters.SupervisedFilter;

public class DLFilterExtended
extends SimpleBatchFilter
implements SupervisedFilter,
TechnicalInformationHandler {
    static final long serialVersionUID = -3335106965521265631L;
    protected static String SEED = "seed";
    protected static String PRE_TRAINED = "pre-trained";
    protected static String MODEL_FILE = "model-file";
    protected static String LOG_FILE = "log-file";
    protected static String EPOCHS = "epochs";
    protected Dl4jMlpClassifierExtended m_model;
    protected boolean m_model_loaded = false;
    protected int m_epochs = 500;
    protected int m_seed = 1;
    protected PlaceholderFile m_modelFile = new PlaceholderFile();
    protected PlaceholderFile m_logFile = new PlaceholderFile();
    protected boolean m_preTrained = true;

    protected boolean loadModel() {
        if (!this.m_model_loaded) {
            try {
                FileInputStream inputFileStream = new FileInputStream((File)this.getModelFile());
                ObjectInputStream objectInputStream = new ObjectInputStream(inputFileStream);
                this.m_model = (Dl4jMlpClassifierExtended)((Object)objectInputStream.readObject());
                this.m_model.setLogFile(new File(this.getLogFile().getAbsoluteFile().toString()));
                objectInputStream.close();
                inputFileStream.close();
                this.m_model_loaded = true;
            }
            catch (Exception e) {
                System.err.println(e.toString());
                return false;
            }
            return true;
        }
        return true;
    }

    public String globalInfo() {
        return "\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.BOOK);
        return result;
    }

    public boolean getPreTrained() {
        return this.m_preTrained;
    }

    public int getSeed() {
        return this.m_seed;
    }

    public PlaceholderFile getModelFile() {
        return this.m_modelFile;
    }

    public PlaceholderFile getLogFile() {
        return this.m_logFile;
    }

    public int getEpochs() {
        return this.m_epochs;
    }

    public void setPreTrained(boolean b) {
        this.m_preTrained = b;
    }

    public void setSeed(int seed) {
        this.m_seed = seed;
    }

    public void setModelFile(PlaceholderFile mf) {
        this.m_modelFile = mf;
    }

    public void setLogFile(PlaceholderFile mf) {
        this.m_logFile = mf;
    }

    public void setEpochs(int e) {
        this.m_epochs = e;
    }

    public String modelFileTipText() {
        return "Model file to load";
    }

    public String logFileTipText() {
        return "Log file for Weka Classifier";
    }

    public String seedTipText() {
        return "Seed";
    }

    public String preTrainedTipText() {
        return "Is the model already built?";
    }

    public String epochsTipText() {
        return "epochs to train with if not pre-trained";
    }

    public Enumeration<Option> listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, (String)this.logFileTipText(), (String)("" + this.getLogFile()), (String)LOG_FILE);
        WekaOptionUtils.addOption(result, (String)this.modelFileTipText(), (String)("" + this.getModelFile()), (String)MODEL_FILE);
        WekaOptionUtils.addOption(result, (String)this.preTrainedTipText(), (String)("" + this.getPreTrained()), (String)PRE_TRAINED);
        WekaOptionUtils.addOption(result, (String)this.seedTipText(), (String)("" + this.getSeed()), (String)SEED);
        WekaOptionUtils.addOption(result, (String)this.epochsTipText(), (String)("" + this.getEpochs()), (String)EPOCHS);
        WekaOptionUtils.add(result, (Enumeration)super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    public String[] getOptions() {
        ArrayList result = new ArrayList();
        WekaOptionUtils.add(result, (String)SEED, (int)this.getSeed());
        WekaOptionUtils.add(result, (String)EPOCHS, (int)this.getEpochs());
        WekaOptionUtils.add(result, (String)PRE_TRAINED, (boolean)this.getPreTrained());
        WekaOptionUtils.add(result, (String)MODEL_FILE, (File)this.getModelFile());
        WekaOptionUtils.add(result, (String)LOG_FILE, (File)this.getLogFile());
        WekaOptionUtils.add(result, (String[])super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    public void setOptions(String[] options) throws Exception {
        this.setSeed(WekaOptionUtils.parse((String[])options, (String)SEED, (int)1));
        this.setEpochs(WekaOptionUtils.parse((String[])options, (String)EPOCHS, (int)500));
        this.setModelFile(new PlaceholderFile((File)WekaOptionUtils.parse((String[])options, (String)MODEL_FILE, (PlaceholderFile)new PlaceholderFile())));
        this.setLogFile(new PlaceholderFile((File)WekaOptionUtils.parse((String[])options, (String)LOG_FILE, (PlaceholderFile)new PlaceholderFile())));
        this.setPreTrained(Utils.getFlag((String)PRE_TRAINED, (String[])options));
        super.setOptions(options);
    }

    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        this.loadModel();
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        int numAtts = this.m_model.getUnitsFinalLayer();
        String prefix = "unit";
        for (int i = 0; i < numAtts; ++i) {
            atts.add(new Attribute(prefix + "_" + (i + 1)));
        }
        atts.add(new Attribute("Class"));
        Instances result = new Instances(prefix, atts, 0);
        result.setClassIndex(result.numAttributes() - 1);
        return result;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        return result;
    }

    protected Instances process(Instances instances) throws Exception {
        Instances result = null;
        this.loadModel();
        if (!this.getPreTrained() && !this.isFirstBatchDone()) {
            this.m_model.setNumEpochs(this.getEpochs());
            this.m_model.buildClassifier(instances);
        }
        Instances header = this.getOutputFormat();
        result = new Instances(header, 0);
        for (int i = 0; i < instances.numInstances(); ++i) {
            double[] values = new double[header.numAttributes()];
            List<INDArray> list_of_vals = this.m_model.getUnitScores(instances.get(i));
            INDArray ia = list_of_vals.get(list_of_vals.size() - 2);
            for (int j = 0; j < ia.length(); ++j) {
                values[j] = ia.getDouble(j);
            }
            values[values.length - 1] = instances.get(i).classValue();
            result.add((Instance)new DenseInstance(1.0, values));
        }
        return result;
    }

    public String getRevision() {
        return RevisionUtils.extract((String)"$Revision: 10364 $");
    }

    public static void main(String[] args) {
        DLFilterExtended.runFilter((Filter)new DLFilterExtended(), (String[])args);
    }
}

