/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural.model;

import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import mulan.classifier.neural.model.ActivationFunction;
import mulan.classifier.neural.model.ActivationLinear;
import mulan.classifier.neural.model.NeuralNet;
import mulan.classifier.neural.model.Neuron;
import mulan.core.ArgumentNullException;

/**
 * A feed-forward neural network whose layer sizes are given by a topology array.
 * <p>
 * Layer 0 is an input layer made of single-input neurons using {@link ActivationLinear}. Their
 * first weight is fixed to 1 and their second weight to 0 (presumably the bias weight — confirm
 * against {@link Neuron}), so each input neuron forwards one component of the input pattern.
 * Every later layer uses a fresh instance of the configured activation function per neuron and
 * receives the complete output vector of the previous layer.
 */
public class BasicNeuralNet implements NeuralNet, Serializable {
    private static final long serialVersionUID = -8944873770650464701L;
    private static final String ACTIVATION_FAILURE_MESSAGE = "Failed to create activation function instance.";

    /** Neurons grouped by layer; index 0 is the input layer, the last entry is the output layer. */
    private final List<List<Neuron>> layers;
    /** Output of the most recent {@link #feedForward(double[])}; {@code null} initially and after {@link #reset()}. */
    private double[] currentNetOutput;
    private final int netInputDim;
    private final int netOutputDim;

    /**
     * Creates a network with the given topology and initializes all of its neurons.
     *
     * @param netTopology number of neurons per layer; element 0 is the input dimension and the last
     *        element the output dimension. Must contain at least two entries, each positive.
     * @param biasInput bias input value passed to every neuron
     * @param activationFunction activation function type for all non-input layers; a new instance is
     *        created reflectively, through its no-argument constructor, for each neuron
     * @param random random generator passed to every neuron (presumably for weight initialization)
     * @throws IllegalArgumentException if the topology is null, has fewer than two layers, contains a
     *         non-positive layer size, or the activation function cannot be instantiated
     * @throws ArgumentNullException if {@code activationFunction} is null
     */
    public BasicNeuralNet(int[] netTopology, double biasInput, Class<? extends ActivationFunction> activationFunction, Random random) {
        if (netTopology == null || netTopology.length < 2) {
            throw new IllegalArgumentException("The topology for neural network is not specified or is invalid. Please provide correct topology for the network.");
        }
        if (activationFunction == null) {
            throw new ArgumentNullException("activationFunction");
        }
        for (int index = 0; index < netTopology.length; ++index) {
            if (netTopology[index] <= 0) {
                throw new IllegalArgumentException("The topology for neural network is invalid: layer " + index
                        + " has " + netTopology[index] + " neurons, but every layer must have at least one neuron.");
            }
        }
        this.netInputDim = netTopology[0];
        this.netOutputDim = netTopology[netTopology.length - 1];
        this.layers = new ArrayList<List<Neuron>>(netTopology.length);
        this.layers.add(createInputLayer(netTopology[0], biasInput, random));
        for (int index = 1; index < netTopology.length; ++index) {
            ArrayList<Neuron> layer = new ArrayList<Neuron>(netTopology[index]);
            for (int n = 0; n < netTopology[index]; ++n) {
                // Each neuron takes one input per neuron of the previous layer.
                layer.add(new Neuron(newActivationFunction(activationFunction), netTopology[index - 1], biasInput, random));
            }
            this.layers.add(layer);
            // Register the whole new layer with every neuron of the previous layer
            // (presumably as its outgoing connections — see Neuron.addAllNeurons).
            for (Neuron neuron : this.layers.get(index - 1)) {
                neuron.addAllNeurons(layer);
            }
        }
    }

    /**
     * Builds the input layer: one single-input neuron with a linear activation per network input.
     * Neurons are created in order so the {@code random} generator is consumed deterministically.
     */
    private static ArrayList<Neuron> createInputLayer(int size, double biasInput, Random random) {
        ArrayList<Neuron> inputLayer = new ArrayList<Neuron>(size);
        for (int n = 0; n < size; ++n) {
            ActivationFunction inputActivation = new ActivationLinear();
            Neuron neuron = new Neuron(inputActivation, 1, biasInput, random);
            // Fix the weights so the neuron forwards its single input: weight 1 for the input and 0
            // for the second weight (presumably the bias weight — confirm against Neuron).
            double[] weights = neuron.getWeights();
            weights[0] = 1.0;
            weights[1] = 0.0;
            inputLayer.add(neuron);
        }
        return inputLayer;
    }

    /**
     * Creates a new activation function instance through the type's no-argument constructor.
     * <p>
     * This replaces the deprecated {@code Class.newInstance()}, which rethrows any exception thrown by
     * the constructor (even checked ones) without wrapping. Here every failure is reported uniformly.
     *
     * @throws IllegalArgumentException if the type cannot be instantiated this way
     */
    private static ActivationFunction newActivationFunction(Class<? extends ActivationFunction> type) {
        try {
            return type.getDeclaredConstructor().newInstance();
        }
        catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(ACTIVATION_FAILURE_MESSAGE, e);
        }
        catch (InstantiationException e) {
            throw new IllegalArgumentException(ACTIVATION_FAILURE_MESSAGE, e);
        }
        catch (IllegalAccessException e) {
            throw new IllegalArgumentException(ACTIVATION_FAILURE_MESSAGE, e);
        }
        catch (InvocationTargetException e) {
            throw new IllegalArgumentException(ACTIVATION_FAILURE_MESSAGE, e);
        }
    }

    /**
     * Returns a read-only view of the neurons in the given layer.
     *
     * @param layerIndex index of the layer; 0 is the input layer
     * @throws IndexOutOfBoundsException if {@code layerIndex} is not a valid layer index
     */
    @Override
    public List<Neuron> getLayerUnits(int layerIndex) {
        return Collections.unmodifiableList(this.layers.get(layerIndex));
    }

    /** Returns the number of layers, including the input and output layers. */
    @Override
    public int getLayersCount() {
        return this.layers.size();
    }

    /**
     * Propagates an input pattern through all layers and returns the network output.
     * <p>
     * Input-layer neuron {@code n} receives only component {@code n} of the pattern; every other
     * neuron receives the full output vector of the previous layer.
     *
     * @param inputPattern input vector whose length equals {@link #getNetInputSize()}
     * @return a copy of the output-layer values; the same values are retained for {@link #getOutput()}
     * @throws IllegalArgumentException if {@code inputPattern} is null or has the wrong length
     */
    @Override
    public double[] feedForward(double[] inputPattern) {
        if (inputPattern == null || inputPattern.length != this.netInputDim) {
            throw new IllegalArgumentException("Specified input pattern vector is null or does not match network input dimension.");
        }
        double[] layerInput = inputPattern;
        for (int layerIndex = 0; layerIndex < this.layers.size(); ++layerIndex) {
            List<Neuron> layer = this.layers.get(layerIndex);
            double[] layerOutput = new double[layer.size()];
            for (int n = 0; n < layerOutput.length; ++n) {
                Neuron neuron = layer.get(n);
                layerOutput[n] = layerIndex == 0
                        ? neuron.processInput(new double[]{layerInput[n]})
                        : neuron.processInput(layerInput);
            }
            // layerOutput is freshly allocated per layer and never modified afterwards,
            // so it can be handed to the next layer without copying.
            layerInput = layerOutput;
        }
        this.currentNetOutput = layerInput;
        // Return a copy so callers cannot mutate the retained network output.
        return Arrays.copyOf(layerInput, layerInput.length);
    }

    /**
     * Returns a copy of the output computed by the last {@link #feedForward(double[])} call, or an
     * all-zero array of length {@link #getNetOutputSize()} if none was computed since construction
     * or the last {@link #reset()}.
     */
    @Override
    public double[] getOutput() {
        if (this.currentNetOutput == null) {
            return new double[this.netOutputDim];
        }
        return Arrays.copyOf(this.currentNetOutput, this.currentNetOutput.length);
    }

    /** Clears the stored network output and calls {@code reset()} on every neuron in every layer. */
    @Override
    public void reset() {
        this.currentNetOutput = null;
        for (List<Neuron> layer : this.layers) {
            for (Neuron neuron : layer) {
                neuron.reset();
            }
        }
    }

    /** Returns the input dimension, i.e. the size of the first topology layer. */
    @Override
    public int getNetInputSize() {
        return this.netInputDim;
    }

    /** Returns the output dimension, i.e. the size of the last topology layer. */
    @Override
    public int getNetOutputSize() {
        return this.netOutputDim;
    }
}

