package adams.ml.preprocessing.supervised;

import adams.core.ObjectCopyHelper;
import adams.data.spreadsheet.DataRow;
import adams.data.spreadsheet.HeaderRow;
import adams.data.spreadsheet.SpreadSheet;
import adams.ml.capabilities.Capabilities;
import adams.ml.capabilities.Capability;
import adams.ml.data.Dataset;
import adams.ml.data.DatasetUtils;
import adams.ml.data.DefaultDataset;
import adams.ml.preprocessing.AbstractColumnSubsetBatchFilter;
import com.github.waikatodatamining.matrix.algorithm.pls.AbstractPLS;
import com.github.waikatodatamining.matrix.algorithm.pls.SIMPLS;
import com.github.waikatodatamining.matrix.core.Matrix;

import java.util.Locale;

/**
 * Supervised batch filter that applies a partial least squares (PLS) algorithm
 * to the selected data columns.
 * <p>
 * The configured algorithm acts as a template. {@link #doInitFilter(Dataset)}
 * initializes a copy of it using the data columns and the class columns.
 * {@link #doFilter(Dataset)} then uses that copy to transform the data columns.
 * <p>
 * The output has one numeric column per component first. Each is named after the
 * algorithm's simple class name (upper-cased), followed by "_" and a 1-based index.
 * The "other" columns come next, then the class columns of the input.
 */
public class PLS extends AbstractColumnSubsetBatchFilter {
    private static final long serialVersionUID = 8479195394918205567L;

    /** The PLS algorithm template, as configured via the "algorithm" option. */
    protected AbstractPLS m_Algorithm;

    /** The initialized copy of {@link #m_Algorithm}; null until the filter has been initialized. */
    protected AbstractPLS m_ActualAlgorithm;

    /** The loadings of the initialized algorithm; null if not initialized or if the algorithm has none. */
    protected SpreadSheet m_Loadings;

    /**
     * Returns a string describing the filter.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Applies the selected partial least squares (PLS) algorithm to the data.";
    }

    /**
     * Adds the "algorithm" option (default: {@link SIMPLS}) to the inherited options.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        m_OptionManager.add("algorithm", "algorithm", new SIMPLS());
    }

    /**
     * Resets the filter.
     * <p>
     * Discards the initialized algorithm copy and its loadings, so nothing
     * computed under a previous configuration survives.
     */
    @Override
    public void reset() {
        super.reset();
        m_ActualAlgorithm = null;
        m_Loadings = null;
    }

    /**
     * Sets the PLS algorithm to use and resets the filter.
     *
     * @param abstractPLS the algorithm template
     */
    public void setAlgorithm(AbstractPLS abstractPLS) {
        m_Algorithm = abstractPLS;
        reset();
    }

    /**
     * Returns the PLS algorithm template in use.
     *
     * @return the algorithm
     */
    public AbstractPLS getAlgorithm() {
        return m_Algorithm;
    }

    /**
     * Returns the tip text for the "algorithm" option.
     *
     * @return the tip text
     */
    public String algorithmTipText() {
        return "The algorithm to use.";
    }

    /**
     * Returns the loadings computed during the last initialization.
     *
     * @return the loadings, or null if the filter has not been initialized or
     *         the algorithm does not provide loadings
     */
    public SpreadSheet getLoadings() {
        return m_Loadings;
    }

    /**
     * Returns the capabilities of this filter: numeric attributes and a numeric class.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = new Capabilities(this);
        result.enable(Capability.NUMERIC_ATTRIBUTE);
        result.enable(Capability.NUMERIC_CLASS);
        return result;
    }

    /**
     * Returns the prefix used for the generated component column names.
     * <p>
     * The prefix is the simple class name of the configured algorithm, upper-cased,
     * followed by "_". Upper-casing uses {@link Locale#ROOT} so column names do not
     * depend on the JVM's default locale.
     *
     * @return the prefix
     */
    protected String getComponentPrefix() {
        return m_Algorithm.getClass().getSimpleName().toUpperCase(Locale.ROOT) + "_";
    }

    /**
     * Initializes a copy of the configured algorithm and stores its loadings, if any.
     * <p>
     * The data columns are used as predictors and the class columns as responses.
     * The copy is only stored once initialization succeeds, so a failure never
     * leaves a half-initialized algorithm behind.
     *
     * @param dataset the data to initialize with
     * @throws Exception if the algorithm's initialization reports an error
     */
    @Override
    protected void doInitFilter(Dataset dataset) throws Exception {
        // Drop state from any previous initialization up front.
        m_ActualAlgorithm = null;
        m_Loadings = null;

        Matrix predictors = DatasetUtils.numericToMatrixAlgo(dataset, m_DataColumns.toArray());
        Matrix responses = DatasetUtils.numericToMatrixAlgo(dataset, m_ClassColumns.toArray());
        AbstractPLS algorithm = (AbstractPLS) ObjectCopyHelper.copyObject(m_Algorithm);
        String msg = algorithm.initialize(predictors, responses);
        if (msg != null) {
            throw new Exception(msg);
        }
        m_ActualAlgorithm = algorithm;
        if (algorithm.hasLoadings()) {
            m_Loadings = DatasetUtils.matrixAlgoToSpreadSheet(algorithm.getLoadings(), "Loadings-");
        }
    }

    /**
     * Creates the output header.
     * <p>
     * The header has one column per configured component, then the "other"
     * columns, then the class columns of the input.
     *
     * @param dataset the input data whose header is used for the appended columns
     * @return the (empty) output dataset
     * @throws Exception if generating the header fails
     */
    @Override
    protected Dataset initOutputFormat(Dataset dataset) throws Exception {
        DefaultDataset result = new DefaultDataset();
        HeaderRow header = result.getHeaderRow();
        String prefix = getComponentPrefix();
        for (int i = 0; i < m_Algorithm.getNumComponents(); i++) {
            String name = prefix + (i + 1);
            header.addCell(name).setContentAsString(name);
        }
        appendHeader(dataset, header, m_OtherColumns);
        appendHeader(dataset, header, m_ClassColumns);
        return result;
    }

    /**
     * Transforms the data columns with the initialized algorithm.
     * <p>
     * The "other" and class columns of each row are copied over unchanged.
     *
     * @param dataset the data to filter
     * @return the filtered data
     * @throws IllegalStateException if the filter has not been initialized
     * @throws Exception if the transformation fails
     */
    @Override
    protected Dataset doFilter(Dataset dataset) throws Exception {
        if (m_ActualAlgorithm == null) {
            throw new IllegalStateException("PLS algorithm has not been initialized!");
        }
        Matrix transformed = m_ActualAlgorithm.transform(
            DatasetUtils.numericToMatrixAlgo(dataset, m_DataColumns.toArray()));
        Dataset result = getOutputFormat().mo18getClone();
        String prefix = getComponentPrefix();
        for (int i = 0; i < transformed.numRows(); i++) {
            DataRow inRow = dataset.getRow(i);
            DataRow outRow = result.addRow();
            for (int n = 0; n < transformed.numColumns(); n++) {
                outRow.addCell(prefix + (n + 1)).setContent(Double.valueOf(transformed.get(i, n)));
            }
            appendData(inRow, outRow, m_OtherColumns);
            appendData(inRow, outRow, m_ClassColumns);
        }
        return result;
    }
}
