/*
 * Decompiled with CFR 0.152.
 */
package weka.filters;

import adams.core.option.OptionUtils;
import adams.data.weka.columnfinder.Class;
import adams.data.weka.columnfinder.ColumnFinder;
import adams.data.weka.datasetsplitter.ColumnSplitter;
import adams.flow.transformer.wekadatasetsmerge.Simple;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import nz.ac.waikato.cms.adams.multiway.algorithm.api.AbstractAlgorithm;
import nz.ac.waikato.cms.adams.multiway.algorithm.api.Filter;
import nz.ac.waikato.cms.adams.multiway.algorithm.api.SupervisedAlgorithm;
import nz.ac.waikato.cms.adams.multiway.algorithm.api.UnsupervisedAlgorithm;
import nz.ac.waikato.cms.adams.multiway.algorithm.twoway.PLS2;
import nz.ac.waikato.cms.adams.multiway.data.tensor.Tensor;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.filters.SimpleBatchFilter;

public class MultiwayFilter
extends SimpleBatchFilter {
    private static final long serialVersionUID = -5490573675185624414L;
    protected AbstractAlgorithm m_Filter = this.getDefaultFilter();
    protected ColumnSplitter m_ClassSplitter = null;
    protected Instances m_FilteredTemplate = null;
    protected Simple m_ClassMerger = null;

    public Enumeration<Option> listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, (String)this.filterTipText(), (String)OptionUtils.getCommandLine((Object)this.getDefaultFilter()), (String)"filter");
        WekaOptionUtils.add(result, (Enumeration)super.listOptions());
        return result.elements();
    }

    public String[] getOptions() {
        ArrayList result = new ArrayList();
        WekaOptionUtils.add(result, (String)"filter", (String)OptionUtils.getCommandLine((Object)this.getFilter()));
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[0]);
    }

    public void setOptions(String[] options) throws Exception {
        String filterCmd = OptionUtils.removeOption((String[])options, (String)"-filter");
        if (filterCmd != null) {
            this.setFilter((AbstractAlgorithm)OptionUtils.forAnyCommandLine(AbstractAlgorithm.class, (String)filterCmd));
        } else {
            this.setFilter(this.getDefaultFilter());
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions((String[])options);
    }

    public AbstractAlgorithm getDefaultFilter() {
        return new PLS2();
    }

    public void setFilter(AbstractAlgorithm value) {
        if (!(value instanceof Filter)) {
            return;
        }
        if (!(value instanceof SupervisedAlgorithm) && !(value instanceof UnsupervisedAlgorithm)) {
            return;
        }
        this.m_Filter = value;
    }

    public AbstractAlgorithm getFilter() {
        return this.m_Filter;
    }

    public String filterTipText() {
        return "The multiway filtering algorithm to use.";
    }

    public String globalInfo() {
        return "Wrapper treating a multiway filter as a WEKA filter.";
    }

    protected Filter getFilterAsFilter() {
        return (Filter)this.getFilter();
    }

    public boolean allowAccessToFullInputFormat() {
        return true;
    }

    protected void reset() {
        String[] currentOptions = this.getOptions();
        try {
            this.setOptions(currentOptions);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        if (this.isFirstBatchDone()) {
            return this.getOutputFormat();
        }
        if (this.m_Filter instanceof UnsupervisedAlgorithm) {
            Tensor input = this.instancesToTensor(inputFormat);
            ((UnsupervisedAlgorithm)this.m_Filter).build(input);
            return this.emptyInstancesForTensor((int)this.getFilterAsFilter().filter(input).size(1));
        }
        this.m_ClassSplitter = new ColumnSplitter();
        this.m_ClassSplitter.setColumnFinder((ColumnFinder)new Class());
        Instances[] split = this.m_ClassSplitter.split(inputFormat);
        Tensor x = this.instancesToTensor(split[1]);
        Tensor y = this.instancesToTensor(split[0]);
        ((SupervisedAlgorithm)this.m_Filter).build(x, y);
        Instances filtered = this.emptyInstancesForTensor((int)this.getFilterAsFilter().filter(x).size(1));
        Instances classSet = new Instances(split[0], 0);
        return this.remergeClassAttribute(filtered, classSet);
    }

    protected Instances process(Instances instances) throws Exception {
        if (this.m_Filter instanceof UnsupervisedAlgorithm) {
            Tensor input = this.instancesToTensor(instances);
            return this.tensorToInstances(this.getFilterAsFilter().filter(input));
        }
        Instances[] split = this.m_ClassSplitter.split(instances);
        Tensor input = this.instancesToTensor(split[1]);
        Instances filtered = this.tensorToInstances(this.getFilterAsFilter().filter(input));
        return this.remergeClassAttribute(filtered, split[0]);
    }

    protected Instances remergeClassAttribute(Instances filtered, Instances classSet) {
        if (this.m_ClassMerger == null) {
            this.m_ClassMerger = new Simple();
        }
        return this.m_ClassMerger.merge(new Instances[]{filtered, classSet});
    }

    protected Tensor instancesToTensor(Instances instances) {
        double[][] data = new double[instances.numInstances()][instances.numAttributes()];
        for (int instanceIndex = 0; instanceIndex < instances.numInstances(); ++instanceIndex) {
            Instance instance = instances.get(instanceIndex);
            for (int attributeIndex = 0; attributeIndex < instances.numAttributes(); ++attributeIndex) {
                data[instanceIndex][attributeIndex] = instance.value(attributeIndex);
            }
        }
        return Tensor.create((double[][])data);
    }

    protected Instances tensorToInstances(Tensor tensor) {
        Instances output = this.emptyInstancesForTensor((int)tensor.size(1));
        double[][] data = tensor.toArray2d();
        int i = 0;
        while ((long)i < tensor.size(0)) {
            double[] row = data[i];
            DenseInstance instance = new DenseInstance(1.0, row);
            output.add((Instance)instance);
            ++i;
        }
        return output;
    }

    protected Instances emptyInstancesForTensor(int size) {
        if (this.m_FilteredTemplate == null) {
            ArrayList<Attribute> attributes = new ArrayList<Attribute>();
            for (int i = 1; i <= size; ++i) {
                Attribute attribute = new Attribute(this.m_Filter.getClass().getSimpleName() + i);
                attributes.add(attribute);
            }
            this.m_FilteredTemplate = new Instances("output", attributes, 0);
        }
        return new Instances(this.m_FilteredTemplate, 0);
    }
}

