/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import adams.core.ObjectCopyHelper;
import adams.core.base.BaseObject;
import adams.core.base.BaseRegExp;
import adams.core.option.OptionHandler;
import adams.data.instancesanalysis.pls.AbstractPLS;
import adams.data.instancesanalysis.pls.PLS1;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.GenericPLSMatrixAccess;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;
import weka.filters.SupervisedFilter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Supervised filter that applies one PLS (partial least squares) transformation per Y attribute.
 * <p>
 * Attribute roles are determined from the input format:
 * <ul>
 *   <li>X: numeric attributes whose name matches the X regular expression;</li>
 *   <li>Y: numeric attributes (not already taken as X) whose name matches the Y regular expression;</li>
 *   <li>other: all remaining attributes except the class attribute.</li>
 * </ul>
 * For every Y attribute, a copy of the configured PLS algorithm is initialized on the first batch
 * using the X attributes plus that Y (as class), and is then used to transform every batch.
 * <p>
 * Output layout: the other attributes; then, for each Y, the generated numeric component attributes
 * ({@code <Y name>-<algorithm class name>_<n>}); then the Y attributes themselves (only the class
 * Y when {@code drop-non-class-ys} is enabled); then the class attribute if it is not a Y.
 */
public class MultiPLS
extends SimpleBatchFilter
implements SupervisedFilter,
GenericPLSMatrixAccess {
    static final long serialVersionUID = -3335106965521265631L;

    /** Regular expression identifying the X (predictor) attributes. */
    protected BaseRegExp m_XRegExp = this.getDefaultXRegExp();

    /** Regular expression identifying the Y (response) attributes. */
    protected BaseRegExp m_YRegExp = this.getDefaultYRegExp();

    /** Template PLS algorithm; a copy is created for each Y attribute. */
    protected AbstractPLS m_Algorithm = this.getDefaultAlgorithm();

    /** Whether Y attributes other than the class attribute are left out of the output. */
    protected boolean m_DropNonClassYs;

    /** Input-format indices of the X attributes. */
    protected TIntList m_XIndices = new TIntArrayList();

    /** Input-format indices of the Y attributes. */
    protected TIntList m_YIndices = new TIntArrayList();

    /** Input-format indices of all attributes that are neither X, Y nor the class. */
    protected TIntList m_OtherIndices = new TIntArrayList();

    /** One PLS instance per Y attribute, keyed by the Y attribute's name. */
    protected Map<String, AbstractPLS> m_PLS = new HashMap<String, AbstractPLS>();

    /** Matrix names exposed via {@link #getMatrix(String)}, formatted {@code <Y name>-<matrix name>}. */
    protected String[] m_MatrixNames = new String[0];

    /**
     * Returns a short description of the filter.
     *
     * @return the description
     */
    public String globalInfo() {
        return "For each Y that gets identified by the regular expression for Y attributes, the specified PLS (partial least squares) algorithm gets applied to the X attributes identified by the corresponding regular expression.";
    }

    /**
     * Lists the filter's options, followed by the options of the superclass.
     *
     * @return the options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        WekaOptionUtils.addOption(result, this.XRegExpTipText(), this.getDefaultXRegExp().getValue(), "x-regexp");
        WekaOptionUtils.addOption(result, this.YRegExpTipText(), this.getDefaultYRegExp().getValue(), "y-regexp");
        WekaOptionUtils.addOption(result, this.algorithmTipText(), this.getDefaultAlgorithm().getClass().getName(), "algorithm");
        WekaOptionUtils.addOption(result, this.dropNonClassYsTipText(), "off", "drop-non-class-ys");
        WekaOptionUtils.add(result, super.listOptions());
        return result.elements();
    }

    /**
     * Returns the current option settings, followed by the superclass's options.
     *
     * @return the options as command-line array
     */
    public String[] getOptions() {
        ArrayList<String> result = new ArrayList<String>();
        // the casts select the BaseObject/OptionHandler overloads of WekaOptionUtils.add
        WekaOptionUtils.add(result, "x-regexp", (BaseObject) this.getXRegExp());
        WekaOptionUtils.add(result, "y-regexp", (BaseObject) this.getYRegExp());
        WekaOptionUtils.add(result, "algorithm", (OptionHandler) this.getAlgorithm());
        WekaOptionUtils.add(result, "drop-non-class-ys", this.getDropNonClassYs());
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses the options; missing options fall back to their defaults.
     *
     * @param options the options to parse
     * @throws Exception if parsing fails or unknown options remain
     */
    public void setOptions(String[] options) throws Exception {
        this.setXRegExp((BaseRegExp) WekaOptionUtils.parse(options, "x-regexp", (BaseObject) this.getDefaultXRegExp()));
        this.setYRegExp((BaseRegExp) WekaOptionUtils.parse(options, "y-regexp", (BaseObject) this.getDefaultYRegExp()));
        this.setAlgorithm((AbstractPLS) WekaOptionUtils.parse(options, "algorithm", (OptionHandler) this.getDefaultAlgorithm()));
        this.setDropNonClassYs(Utils.getFlag("drop-non-class-ys", options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the default regular expression for the X attributes.
     *
     * @return the default, {@code amplitude-.*}
     */
    protected BaseRegExp getDefaultXRegExp() {
        return new BaseRegExp("amplitude-.*");
    }

    /**
     * Sets the regular expression for identifying the X attributes.
     *
     * @param value the expression
     */
    public void setXRegExp(BaseRegExp value) {
        this.m_XRegExp = value;
    }

    /**
     * Returns the regular expression for identifying the X attributes.
     *
     * @return the expression
     */
    public BaseRegExp getXRegExp() {
        return this.m_XRegExp;
    }

    /**
     * Returns the tip text for the X regular expression property.
     *
     * @return the tip text
     */
    public String XRegExpTipText() {
        return "The regular expression to identify the X attributes for PLS.";
    }

    /**
     * Returns the default regular expression for the Y attributes.
     *
     * @return the default, {@code A_.*}
     */
    protected BaseRegExp getDefaultYRegExp() {
        return new BaseRegExp("A_.*");
    }

    /**
     * Sets the regular expression for identifying the Y attributes.
     *
     * @param value the expression
     */
    public void setYRegExp(BaseRegExp value) {
        this.m_YRegExp = value;
    }

    /**
     * Returns the regular expression for identifying the Y attributes.
     *
     * @return the expression
     */
    public BaseRegExp getYRegExp() {
        return this.m_YRegExp;
    }

    /**
     * Returns the tip text for the Y regular expression property.
     *
     * @return the tip text
     */
    public String YRegExpTipText() {
        return "The regular expression to identify the Y attributes for PLS.";
    }

    /**
     * Returns the default PLS algorithm.
     *
     * @return the default, {@link PLS1}
     */
    protected AbstractPLS getDefaultAlgorithm() {
        return new PLS1();
    }

    /**
     * Sets the PLS algorithm template (copied for each Y attribute).
     *
     * @param value the algorithm
     */
    public void setAlgorithm(AbstractPLS value) {
        this.m_Algorithm = value;
    }

    /**
     * Returns the PLS algorithm template.
     *
     * @return the algorithm
     */
    public AbstractPLS getAlgorithm() {
        return this.m_Algorithm;
    }

    /**
     * Returns the tip text for the algorithm property.
     *
     * @return the tip text
     */
    public String algorithmTipText() {
        return "The PLS algorithm to apply.";
    }

    /**
     * Sets whether Y attributes other than the class attribute are dropped from the output.
     *
     * @param value true to drop them
     */
    public void setDropNonClassYs(boolean value) {
        this.m_DropNonClassYs = value;
    }

    /**
     * Returns whether Y attributes other than the class attribute are dropped from the output.
     *
     * @return true if dropped
     */
    public boolean getDropNonClassYs() {
        return this.m_DropNonClassYs;
    }

    /**
     * Returns the tip text for the drop-non-class-Ys property.
     *
     * @return the tip text
     */
    public String dropNonClassYsTipText() {
        return "If enabled, Y attributes that aren't the class attribute are removed from the output.";
    }

    /**
     * Classifies the input attributes into X, Y and other attributes, builds the output header,
     * creates one PLS copy per Y attribute and collects the exposed matrix names.
     *
     * @param inputFormat the input header
     * @return the output header
     * @throws Exception if the output header cannot be created
     */
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        int classIndex = inputFormat.classIndex();
        String className = null;

        this.m_XIndices.clear();
        this.m_YIndices.clear();
        this.m_OtherIndices.clear();
        this.m_PLS.clear();

        // classify attributes; X takes precedence over Y, and the class is never an "other"
        for (int i = 0; i < inputFormat.numAttributes(); ++i) {
            Attribute att = inputFormat.attribute(i);
            if (this.m_XRegExp.isMatch(att.name()) && att.isNumeric()) {
                this.m_XIndices.add(i);
            } else if (this.m_YRegExp.isMatch(att.name()) && att.isNumeric()) {
                this.m_YIndices.add(i);
            } else if (i != classIndex) {
                this.m_OtherIndices.add(i);
            }
            if (i == classIndex) {
                className = att.name();
            }
        }
        if (this.getDebug()) {
            System.out.println("X: " + this.m_XIndices);
            System.out.println("Y: " + this.m_YIndices);
            System.out.println("Other: " + this.m_OtherIndices);
        }

        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        // 1. other attributes, copied unchanged
        for (int i = 0; i < this.m_OtherIndices.size(); ++i) {
            atts.add((Attribute) inputFormat.attribute(this.m_OtherIndices.get(i)).copy());
        }
        // 2. generated PLS components for each Y
        int numComponents = this.m_Algorithm.getNumComponents();
        String algName = this.m_Algorithm.getClass().getSimpleName();
        for (int i = 0; i < this.m_YIndices.size(); ++i) {
            String yName = inputFormat.attribute(this.m_YIndices.get(i)).name();
            for (int n = 0; n < numComponents; ++n) {
                atts.add(new Attribute(yName + "-" + algName + "_" + (n + 1)));
            }
        }
        // 3. the Y attributes themselves (possibly restricted to the class Y)
        for (int i = 0; i < this.m_YIndices.size(); ++i) {
            int yIndex = this.m_YIndices.get(i);
            if (!this.m_DropNonClassYs || yIndex == classIndex) {
                atts.add((Attribute) inputFormat.attribute(yIndex).copy());
            }
        }
        // 4. the class attribute, unless it was already added as a Y
        if (classIndex > -1 && !this.m_YIndices.contains(classIndex)) {
            atts.add((Attribute) inputFormat.classAttribute().copy());
        }

        Instances result = new Instances(inputFormat.relationName(), atts, 0);
        if (className != null) {
            result.setClassIndex(result.attribute(className).index());
        }

        // one PLS copy per Y, plus the "<Y name>-<matrix name>" names they expose
        String[] algMatrixNames = this.m_Algorithm.getMatrixNames();
        ArrayList<String> names = new ArrayList<String>();
        for (int i = 0; i < this.m_YIndices.size(); ++i) {
            String yName = inputFormat.attribute(this.m_YIndices.get(i)).name();
            this.m_PLS.put(yName, (AbstractPLS) ObjectCopyHelper.copyObject((Object) this.m_Algorithm));
            for (int n = 0; n < algMatrixNames.length; ++n) {
                names.add(yName + "-" + algMatrixNames[n]);
            }
        }
        this.m_MatrixNames = names.toArray(new String[names.size()]);
        if (this.getDebug()) {
            System.out.println("Matrix names: " + Utils.arrayToString((Object) this.m_MatrixNames));
        }
        return result;
    }

    /**
     * Transforms the batch: for each Y, the X attributes plus that Y are run through the
     * corresponding PLS (initialized on the first batch), and the output rows are assembled in
     * the layout defined by {@link #determineOutputFormat(Instances)}.
     * <p>
     * The class index of {@code instances} is temporarily cleared while the per-Y subsets are
     * built and is restored afterwards, even if an exception occurs.
     *
     * @param instances the data to transform
     * @return the transformed data
     * @throws Exception if filtering or a PLS transformation fails
     */
    protected Instances process(Instances instances) throws Exception {
        int classIndex = instances.classIndex();
        Instances[] transformed = new Instances[this.m_YIndices.size()];

        instances.setClassIndex(-1);
        try {
            TIntArrayList indices = new TIntArrayList();
            Remove remove = new Remove();
            for (int i = 0; i < this.m_YIndices.size(); ++i) {
                String name = instances.attribute(this.m_YIndices.get(i)).name();
                if (this.getDebug()) {
                    if (!this.isFirstBatchDone()) {
                        System.out.println("Initializing PLS #" + (i + 1) + ": " + name);
                    } else {
                        System.out.println("Applying PLS #" + (i + 1) + ": " + name);
                    }
                }
                // keep only the X attributes and the current Y
                indices.clear();
                indices.add(this.m_XIndices.toArray());
                indices.add(this.m_YIndices.get(i));
                remove.setAttributeIndicesArray(indices.toArray());
                remove.setInvertSelection(true);
                remove.setInputFormat(instances);
                Instances subset = Filter.useFilter(instances, remove);
                // locate Y by name rather than relying on Remove's output attribute order
                subset.setClassIndex(subset.attribute(name).index());
                AbstractPLS pls = this.m_PLS.get(name);
                if (!this.isFirstBatchDone()) {
                    pls.determineOutputFormat(subset);
                }
                transformed[i] = pls.transform(subset);
            }
        } finally {
            instances.setClassIndex(classIndex);
        }

        Instances result = this.getOutputFormat();
        for (int n = 0; n < instances.numInstances(); ++n) {
            Instance inst = instances.instance(n);
            double[] values = new double[result.numAttributes()];
            int index = 0;

            // other attributes (the class is never among them)
            for (int m = 0; m < this.m_OtherIndices.size(); ++m) {
                values[index] = copyValue(inst, this.m_OtherIndices.get(m), result.attribute(index));
                ++index;
            }
            // PLS components of every Y, skipping each subset's class (the Y itself)
            for (Instances subset : transformed) {
                Instance subInst = subset.instance(n);
                for (int m = 0; m < subset.numAttributes(); ++m) {
                    if (m == subset.classIndex()) {
                        continue;
                    }
                    values[index] = subInst.value(m);
                    ++index;
                }
            }
            // Y attributes
            for (int i = 0; i < this.m_YIndices.size(); ++i) {
                int yIndex = this.m_YIndices.get(i);
                if (this.m_DropNonClassYs && yIndex != classIndex) {
                    continue;
                }
                values[index] = inst.value(yIndex);
                ++index;
            }
            // class attribute, if not already output as a Y
            if (classIndex > -1 && !this.m_YIndices.contains(classIndex)) {
                values[index] = copyValue(inst, classIndex, result.attribute(index));
            }
            result.add(new DenseInstance(inst.weight(), values));
        }
        return result;
    }

    /**
     * Converts the value at {@code inIndex} of {@code inst} into the internal value for the
     * output attribute, registering string/relational values with that attribute.
     *
     * @param inst the input instance
     * @param inIndex the attribute index in the input
     * @param outAtt the corresponding output attribute
     * @return the internal value for the output instance
     */
    private static double copyValue(Instance inst, int inIndex, Attribute outAtt) {
        if (inst.isMissing(inIndex)) {
            return Utils.missingValue();
        }
        if (outAtt.isString()) {
            return outAtt.addStringValue(inst.stringValue(inIndex));
        }
        if (outAtt.isRelationValued()) {
            return outAtt.addRelation(inst.relationalValue(inIndex));
        }
        return inst.value(inIndex);
    }

    /**
     * Returns the names of all matrices available, as {@code <Y name>-<matrix name>}.
     *
     * @return the matrix names
     */
    @Override
    public String[] getMatrixNames() {
        return this.m_MatrixNames;
    }

    /**
     * Strips a known {@code -<matrix name>} suffix from the given name.
     *
     * @param matrixName the full matrix name
     * @return the Y attribute name (PLS key), or {@code matrixName} unchanged if no known
     *         matrix suffix matched
     */
    protected String extractPLSKey(String matrixName) {
        String[] algMatrixNames = this.m_Algorithm.getMatrixNames();
        for (int i = 0; i < algMatrixNames.length; ++i) {
            String search = "-" + algMatrixNames[i];
            if (matrixName.endsWith(search)) {
                return matrixName.substring(0, matrixName.length() - search.length());
            }
        }
        return matrixName;
    }

    /**
     * Returns the matrix with the given full name ({@code <Y name>-<matrix name>}).
     *
     * @param name the full matrix name
     * @return the matrix, or null if the name is unknown
     */
    @Override
    public Matrix getMatrix(String name) {
        String key = this.extractPLSKey(name);
        // no known "-<matrix>" suffix: nothing to look up
        if (key.length() >= name.length()) {
            return null;
        }
        AbstractPLS pls = this.m_PLS.get(key);
        if (pls == null) {
            return null;
        }
        // skip the "-" separator so the PLS receives the bare matrix name
        return pls.getMatrix(name.substring(key.length() + 1));
    }

    /**
     * Returns whether a loadings matrix is available; this filter offers none.
     *
     * @return always false
     */
    @Override
    public boolean hasLoadings() {
        return false;
    }

    /**
     * Returns the loadings matrix; this filter offers none.
     *
     * @return always null
     */
    @Override
    public Matrix getLoadings() {
        return null;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10364 $");
    }

    /**
     * Runs the filter from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        MultiPLS.runFilter(new MultiPLS(), args);
    }
}

