/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.preprocessing.supervised;

import adams.core.ObjectCopyHelper;
import adams.data.spreadsheet.DataRow;
import adams.data.spreadsheet.HeaderRow;
import adams.data.spreadsheet.Row;
import adams.data.spreadsheet.SpreadSheet;
import adams.ml.capabilities.Capabilities;
import adams.ml.capabilities.Capability;
import adams.ml.data.Dataset;
import adams.ml.data.DatasetUtils;
import adams.ml.data.DefaultDataset;
import adams.ml.preprocessing.AbstractColumnSubsetBatchFilter;
import com.github.waikatodatamining.matrix.algorithm.pls.AbstractPLS;
import com.github.waikatodatamining.matrix.algorithm.pls.SIMPLS;
import com.github.waikatodatamining.matrix.core.Matrix;

import java.util.Locale;

/**
 * Batch filter that applies the selected partial least squares (PLS) algorithm.
 * The data columns are used as predictors and the class columns as response.
 * The output consists of one column per PLS component, followed by the other
 * columns and the class columns of the input.
 */
public class PLS
extends AbstractColumnSubsetBatchFilter {
    private static final long serialVersionUID = 8479195394918205567L;
    /** the algorithm template as configured; only copies of it get initialized */
    protected AbstractPLS m_Algorithm;
    /** the initialized copy of the template, created in doInitFilter */
    protected AbstractPLS m_ActualAlgorithm;
    /** the loadings of the initialized algorithm, null if not available */
    protected SpreadSheet m_Loadings;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Applies the selected partial least squares (PLS) algorithm to the data.";
    }

    /**
     * Adds options to the internal list of options; the algorithm defaults to SIMPLS.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("algorithm", "algorithm", (Object)new SIMPLS());
    }

    /**
     * Resets the filter, discarding the initialized algorithm copy and its loadings,
     * so that nothing from a previous initialization survives a configuration change.
     */
    @Override
    protected void reset() {
        super.reset();
        this.m_ActualAlgorithm = null;
        this.m_Loadings = null;
    }

    /**
     * Sets the PLS algorithm to use and resets the filter.
     *
     * @param value the algorithm template
     */
    public void setAlgorithm(AbstractPLS value) {
        this.m_Algorithm = value;
        this.reset();
    }

    /**
     * Returns the PLS algorithm in use.
     *
     * @return the algorithm template
     */
    public AbstractPLS getAlgorithm() {
        return this.m_Algorithm;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String algorithmTipText() {
        return "The algorithm to use.";
    }

    /**
     * Returns the capabilities: numeric attributes and a numeric class.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = new Capabilities(this);
        result.enable(Capability.NUMERIC_ATTRIBUTE);
        result.enable(Capability.NUMERIC_CLASS);
        return result;
    }

    /**
     * Initializes a fresh copy of the configured algorithm on the data columns
     * (predictors) and the class columns (response), and stores the loadings if the
     * algorithm provides any.
     *
     * @param data the data to initialize with
     * @throws Exception if the algorithm reports an initialization error
     */
    @Override
    protected void doInitFilter(Dataset data) throws Exception {
        Matrix predictors = DatasetUtils.numericToMatrixAlgo(data, this.m_DataColumns.toArray());
        Matrix response = DatasetUtils.numericToMatrixAlgo(data, this.m_ClassColumns.toArray());
        // work on a copy so the configured template itself is never initialized
        this.m_ActualAlgorithm = (AbstractPLS)ObjectCopyHelper.copyObject((Object)this.m_Algorithm);
        String msg = this.m_ActualAlgorithm.initialize(predictors, response);
        if (msg != null) {
            throw new Exception(msg);
        }
        if (this.m_ActualAlgorithm.hasLoadings()) {
            this.m_Loadings = DatasetUtils.matrixAlgoToSpreadSheet(this.m_ActualAlgorithm.getLoadings(), "Loadings-");
        }
        else {
            // don't keep loadings from an earlier initialization
            this.m_Loadings = null;
        }
    }

    /**
     * Creates the output format: one column per component (named by
     * {@link #getComponentPrefix()} plus a 1-based index), followed by the other
     * columns and the class columns of the input.
     *
     * @param data the input data
     * @return the (empty) output dataset
     * @throws Exception if the header cannot be generated
     */
    @Override
    protected Dataset initOutputFormat(Dataset data) throws Exception {
        DefaultDataset result = new DefaultDataset();
        HeaderRow row = result.getHeaderRow();
        String prefix = this.getComponentPrefix();
        for (int i = 0; i < this.m_Algorithm.getNumComponents(); ++i) {
            String name = prefix + (i + 1);
            row.addCell(name).setContentAsString(name);
        }
        this.appendHeader(data, (Row)row, this.m_OtherColumns);
        this.appendHeader(data, (Row)row, this.m_ClassColumns);
        return result;
    }

    /**
     * Transforms the data columns with the initialized algorithm and copies the
     * other and class columns of each row into the output.
     *
     * @param data the data to filter
     * @return the filtered data
     * @throws Exception if the transformation fails
     */
    @Override
    protected Dataset doFilter(Dataset data) throws Exception {
        Matrix predictors = DatasetUtils.numericToMatrixAlgo(data, this.m_DataColumns.toArray());
        Matrix transformed = this.m_ActualAlgorithm.transform(predictors);
        Dataset result = this.getOutputFormat().getClone();
        // must match the naming used in initOutputFormat
        String prefix = this.getComponentPrefix();
        for (int y = 0; y < transformed.numRows(); ++y) {
            DataRow rowIn = data.getRow(y);
            DataRow rowOut = result.addRow();
            for (int x = 0; x < transformed.numColumns(); ++x) {
                rowOut.addCell(prefix + (x + 1)).setContent(Double.valueOf(transformed.get(y, x)));
            }
            this.appendData((Row)rowIn, (Row)rowOut, this.m_OtherColumns);
            this.appendData((Row)rowIn, (Row)rowOut, this.m_ClassColumns);
        }
        return result;
    }

    /**
     * Returns the prefix for the generated component columns: the simple class name
     * of the configured algorithm in upper case, followed by an underscore.
     * {@link Locale#ROOT} keeps the names independent of the default locale
     * (e.g. the Turkish dotted/dotless i).
     *
     * @return the column name prefix
     */
    private String getComponentPrefix() {
        return this.m_Algorithm.getClass().getSimpleName().toUpperCase(Locale.ROOT) + "_";
    }
}

