package weka.classifiers.functions;

import java.util.List;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:weka/classifiers/functions/Dl4jMlpClassifierExtended.class */
/**
 * {@link Dl4jMlpClassifier} variant that exposes internals of the underlying network: the input
 * width of its final layer, and the activations produced when a single instance is fed forward.
 */
public class Dl4jMlpClassifierExtended extends Dl4jMlpClassifier {

    /**
     * Returns {@code nIn} of the last layer configuration of the network, i.e. the number of
     * inputs that feed into the final layer.
     *
     * <p>NOTE(review): despite the method name, this is the final layer's <em>input</em> width
     * (the size of the layer preceding it), not the final layer's own output size. Existing
     * callers presumably depend on this, so it is documented rather than changed — confirm intent.
     *
     * @return the {@code nIn} value of the last layer configuration
     * @throws IllegalStateException if no model is available or it has no layer configurations
     */
    public int getUnitsFinalLayer() {
        checkModelBuilt();
        List<NeuralNetConfiguration> confs = this.m_model.getLayerWiseConfigurations().getConfs();
        if (confs == null || confs.isEmpty()) {
            throw new IllegalStateException("Network configuration contains no layers.");
        }
        return confs.get(confs.size() - 1).getLayer().getNIn();
    }

    /**
     * Passes {@code instance} through the classifier's {@code m_normalize} filter, wraps the result
     * in a one-row dataset, and feeds its feature matrix forward through the model.
     *
     * <p>The method is {@code synchronized} because it pushes through the shared filter and the
     * shared model; presumably {@code feedForward} also updates the model's stored input — confirm.
     *
     * @param instance the instance to score; assumed to match the header the classifier was
     *     trained on — TODO confirm against callers
     * @return the list returned by the model's {@code feedForward} for the single-row feature
     *     matrix (presumably one activation array per layer — verify for the DL4J version in use)
     * @throws IllegalStateException if no model is available or the filter produced no output
     * @throws Exception if filtering or dataset conversion fails
     */
    public synchronized List<INDArray> getUnitScores(Instance instance) throws Exception {
        checkModelBuilt();
        this.m_normalize.input(instance);
        Instance normalized = this.m_normalize.output();
        if (normalized == null) {
            throw new IllegalStateException("Normalization filter produced no output for the instance.");
        }
        // The normalized instance is what gets added, so prefer the header it carries;
        // fall back to the caller's header if the filter left it detached.
        Instances header = normalized.dataset() != null ? normalized.dataset() : instance.dataset();
        Instances single = new Instances(header, 0);
        single.add(normalized);
        DataSet batch = (DataSet) getDataSetIterator().getIterator(single, getSeed(), 1).next();
        return this.m_model.feedForward(batch.getFeatureMatrix());
    }

    /** Fails fast with a descriptive message when no network model has been created yet. */
    private void checkModelBuilt() {
        if (this.m_model == null) {
            throw new IllegalStateException("No network model available; build the classifier first.");
        }
    }
}
