/*
 * Decompiled with CFR 0.152.
 */
package adams.data.conversion;

import adams.data.conversion.AbstractConversion;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import weka.core.Instance;
import weka.core.Instances;

public class WekaInstancesToDL4JDataSet
extends AbstractConversion {
    private static final long serialVersionUID = -7278857064645982416L;

    public String globalInfo() {
        return "Converts a Weka Instances object to a DL4J DataSet.\nAssumes missing values to be imputed and nominal attributes to be binarized.";
    }

    public Class accepts() {
        return Instances.class;
    }

    public Class generates() {
        return DataSet.class;
    }

    protected Object doConvert() throws Exception {
        Instances insts = (Instances)this.m_Input;
        INDArray data = Nd4j.ones((int)insts.numInstances(), (int)(insts.numAttributes() - 1));
        double[][] outcomes = new double[insts.numInstances()][insts.classAttribute().numValues() == 0 ? 1 : insts.classAttribute().numValues()];
        for (int i = 0; i < insts.numInstances(); ++i) {
            double[] independent = new double[insts.numAttributes() - 1];
            int index = 0;
            Instance current = insts.instance(i);
            for (int j = 0; j < insts.numAttributes(); ++j) {
                if (j != insts.classIndex()) {
                    independent[index++] = current.value(j);
                    continue;
                }
                if (insts.numClasses() > 1) {
                    outcomes[i][(int)current.classValue()] = 1.0;
                    continue;
                }
                outcomes[i][0] = current.classValue();
            }
            data.putRow(i, Nd4j.create((double[])independent));
        }
        return new DataSet(data, Nd4j.create((double[][])outcomes));
    }
}

