/*
 * Decompiled with CFR 0.152.
 */
package adams.data.conversion;

import adams.data.conversion.AbstractConversion;
import adams.data.spreadsheet.DataRow;
import adams.data.spreadsheet.SpreadSheet;
import adams.data.spreadsheet.SpreadSheetColumnRange;
import gnu.trove.set.hash.TIntHashSet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Converts a {@link SpreadSheet} into a DL4J {@link DataSet}.
 * <p>
 * The columns selected by the class-columns range become the labels of the
 * data set; all remaining columns, in their original order, become the
 * features. Every cell is read via {@code toDouble()}, so the spreadsheet is
 * expected to hold only numeric values; nominal columns must be binarized
 * beforehand.
 */
public class SpreadSheetToDL4JDataSet
extends AbstractConversion {
    private static final long serialVersionUID = 1970704417619148081L;

    /** The range of columns that form the label (class) part of the data set. */
    protected SpreadSheetColumnRange m_ClassColumns;

    /**
     * Returns a short description of what this conversion does.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Converts a spreadsheet into a DL4J DataSet.\nAssumes only numeric cells and no missing values to be present.\nNominal columns/classes need to be binarized first.";
    }

    /**
     * Registers the options of this conversion; the class columns default to
     * the last column.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("class-columns", "classColumns", (Object)new SpreadSheetColumnRange("last"));
    }

    /**
     * Sets the range of columns to use as class/label columns.
     *
     * @param value the column range
     */
    public void setClassColumns(SpreadSheetColumnRange value) {
        this.m_ClassColumns = value;
        this.reset();
    }

    /**
     * Returns the range of columns used as class/label columns.
     *
     * @return the column range
     */
    public SpreadSheetColumnRange getClassColumns() {
        return this.m_ClassColumns;
    }

    /**
     * Returns the tip text for the class columns option.
     *
     * @return the tip text
     */
    public String classColumnsTipText() {
        return "The range of columns to use as class/label columns in the DL4J DataSet; all other columns are used as features.";
    }

    /**
     * Returns the class of data this conversion accepts.
     *
     * @return {@link SpreadSheet}
     */
    public Class accepts() {
        return SpreadSheet.class;
    }

    /**
     * Returns the class of data this conversion generates.
     *
     * @return {@link DataSet}
     */
    public Class generates() {
        return DataSet.class;
    }

    /**
     * Performs the conversion: non-class columns go into the feature matrix,
     * class columns (in range order) into the label matrix.
     *
     * @return the generated {@link DataSet}
     * @throws Exception if a cell is absent or the conversion fails otherwise
     */
    protected Object doConvert() throws Exception {
        SpreadSheet sheet = (SpreadSheet)this.m_Input;
        this.m_ClassColumns.setData((Object)sheet);
        int[] classes = this.m_ClassColumns.getIntIndices();
        TIntHashSet classesSet = new TIntHashSet(classes);

        int numRows = sheet.getRowCount();
        int numCols = sheet.getColumnCount();
        // Size by the number of *distinct* class columns: the feature loop skips
        // columns by set membership, so a repeated index must not shrink the array.
        int numFeatures = numCols - classesSet.size();

        INDArray data = Nd4j.ones(numRows, numFeatures);
        double[][] outcomes = new double[numRows][classes.length];
        for (int i = 0; i < numRows; i++) {
            DataRow row = sheet.getRow(i);

            double[] features = new double[numFeatures];
            int index = 0;
            for (int j = 0; j < numCols; j++) {
                if (!classesSet.contains(j)) {
                    features[index++] = cellValue(row, i, j);
                }
            }

            for (int j = 0; j < classes.length; j++) {
                outcomes[i][j] = cellValue(row, i, classes[j]);
            }

            data.putRow(i, Nd4j.create(features));
        }
        return new DataSet(data, Nd4j.create(outcomes));
    }

    /**
     * Returns the numeric value of a cell, failing with a descriptive message
     * instead of a bare NullPointerException when the cell is absent.
     *
     * @param row      the row to read from
     * @param rowIndex the 0-based row index (used only for the error message)
     * @param colIndex the 0-based column index
     * @return the cell's value as double
     * @throws IllegalStateException if the row has no cell at that column
     */
    private double cellValue(DataRow row, int rowIndex, int colIndex) {
        if (row.getCell(colIndex) == null) {
            throw new IllegalStateException(
                "No cell at row " + (rowIndex + 1) + ", column " + (colIndex + 1) + " (1-based)");
        }
        return row.getCell(colIndex).toDouble();
    }
}

