/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import mulan.classifier.neural.model.ActivationFunction;
import mulan.classifier.neural.model.NeuralNet;
import mulan.classifier.neural.model.Neuron;

/**
 * Back-propagation learning for multi-label data (BP-MLL) on top of an
 * externally supplied {@link NeuralNet}.
 * <p>
 * The error signal for the output layer is derived from a pairwise ranking
 * term: for every (relevant label, irrelevant label) pair it uses
 * {@code exp(-(out[relevant] - out[irrelevant]))}, normalized by the number
 * of such pairs. An L2 weight-decay term scaled by {@code weightsDecayCost}
 * is applied to every weight update and added to the reported error.
 * <p>
 * NOTE(review): {@link #learn} mutates the weight arrays returned by
 * {@code Neuron.getWeights()} in place; no synchronization is performed.
 */
public class BPMLLAlgorithm {
    private final NeuralNet neuralNet;
    private final double weightsDecayCost;

    /**
     * Creates a learning algorithm bound to the given network.
     *
     * @param neuralNet the network whose weights are trained; must not be null
     * @param weightsDecayCost weight-decay regularization cost, in the interval (0, 1]
     * @throws IllegalArgumentException if {@code neuralNet} is null or
     *         {@code weightsDecayCost} is outside (0, 1] (including NaN)
     */
    public BPMLLAlgorithm(NeuralNet neuralNet, double weightsDecayCost) {
        if (neuralNet == null) {
            throw new IllegalArgumentException("The passed neural network model is null.");
        }
        // Written as a negated range check so that NaN is rejected as well.
        if (!(weightsDecayCost > 0.0 && weightsDecayCost <= 1.0)) {
            throw new IllegalArgumentException("The weights decay regularization cost term must be greater than 0 and no more than 1. The passed value is : " + weightsDecayCost);
        }
        this.neuralNet = neuralNet;
        this.weightsDecayCost = weightsDecayCost;
    }

    /**
     * Returns the network trained by this algorithm.
     *
     * @return the underlying network model
     */
    public NeuralNet getNetwork() {
        return this.neuralNet;
    }

    /**
     * Returns the weight-decay regularization cost.
     *
     * @return the cost passed to the constructor
     */
    public double getWeightsDecayCost() {
        return this.weightsDecayCost;
    }

    /**
     * Performs one back-propagation step for a single training pattern.
     * <p>
     * Error terms for all layers are computed first, using the weights that
     * produced the forward pass. Only then are the weights updated.
     *
     * @param inputPattern input vector; length must equal the network input size
     * @param expectedLabels label vector; an entry equal to {@code 1.0} marks a
     *        relevant label, anything else an irrelevant one; length must equal
     *        the network output size
     * @param learningRate step size applied to the gradient part of each update
     * @return the error for this pattern, computed from the weights before the
     *         update, or {@link Double#NaN} (with no weights changed) when all
     *         labels are relevant or all are irrelevant
     * @throws IllegalArgumentException if a vector is null or has the wrong length
     */
    public double learn(double[] inputPattern, double[] expectedLabels, double learningRate) {
        this.validatePatterns(inputPattern, expectedLabels);
        double[] networkOutputs = this.neuralNet.feedForward(inputPattern);
        double[] outputErrors = this.computeErrorsForNeurons(networkOutputs, expectedLabels);
        if (outputErrors == null) {
            // The pairwise error is undefined without both relevant and irrelevant labels.
            return Double.NaN;
        }
        int layersCount = this.neuralNet.getLayersCount();

        // Phase 1: propagate error terms backwards. Hidden-layer terms read the
        // next layer's weights, so no weight may be modified yet.
        for (int layerIndex = layersCount - 1; layerIndex > 0; --layerIndex) {
            List<Neuron> layer = this.neuralNet.getLayerUnits(layerIndex);
            if (layerIndex == layersCount - 1) {
                this.computeOutputLayerErrorTerms(layer, outputErrors);
            } else {
                this.computeHiddenLayerErrorTerms(layer, this.neuralNet.getLayerUnits(layerIndex + 1));
            }
        }

        // The regularization part of the reported error uses the pre-update weights.
        double weightsSquareSum = this.computeWeightsSquareSum();

        // Phase 2: update every layer from its error terms and its inputs
        // (the outputs of the preceding layer recorded during the forward pass).
        for (int layerIndex = layersCount - 1; layerIndex > 0; --layerIndex) {
            List<Neuron> layer = this.neuralNet.getLayerUnits(layerIndex);
            double[] layerInputs = getLayerOutputs(this.neuralNet.getLayerUnits(layerIndex - 1));
            this.updateWeights(layer, layerInputs, learningRate);
        }
        return this.computeGlobalError(outputErrors, weightsSquareSum);
    }

    /**
     * Computes the error of the network for one pattern without changing any weights.
     * <p>
     * NOTE(review): this calls {@code feedForward}, which may refresh cached
     * neuron state inside the network; confirm against {@code NeuralNet}.
     *
     * @param inputPattern input vector; length must equal the network input size
     * @param expectedLabels label vector ({@code 1.0} = relevant); length must
     *        equal the network output size
     * @return the error, or {@link Double#NaN} when all labels are relevant or
     *         all are irrelevant
     * @throws IllegalArgumentException if a vector is null or has the wrong length
     */
    public double getNetworkError(double[] inputPattern, double[] expectedLabels) {
        this.validatePatterns(inputPattern, expectedLabels);
        double[] networkOutputs = this.neuralNet.feedForward(inputPattern);
        double[] outputErrors = this.computeErrorsForNeurons(networkOutputs, expectedLabels);
        if (outputErrors == null) {
            return Double.NaN;
        }
        return this.computeGlobalError(outputErrors, this.computeWeightsSquareSum());
    }

    /**
     * Checks that both vectors are non-null and match the network dimensions.
     *
     * @throws IllegalArgumentException if either check fails
     */
    private void validatePatterns(double[] inputPattern, double[] expectedLabels) {
        if (inputPattern == null || inputPattern.length != this.neuralNet.getNetInputSize()) {
            throw new IllegalArgumentException("Specified input pattern vector is null or does not match the input dimension of underlying neural network model.");
        }
        if (expectedLabels == null || expectedLabels.length != this.neuralNet.getNetOutputSize()) {
            throw new IllegalArgumentException("Specified expected labels vector is null or does not match the output dimension of underlying neural network model.");
        }
    }

    /** Sums the squares of all weights (bias weights included) in layers 1..n-1. */
    private double computeWeightsSquareSum() {
        double sum = 0.0;
        int layersCount = this.neuralNet.getLayersCount();
        for (int layerIndex = 1; layerIndex < layersCount; ++layerIndex) {
            for (Neuron neuron : this.neuralNet.getLayerUnits(layerIndex)) {
                for (double weight : neuron.getWeights()) {
                    sum += weight * weight;
                }
            }
        }
        return sum;
    }

    /**
     * Combines the output error terms with the weight-decay penalty.
     * <p>
     * NOTE(review): the first part is the sum of absolute per-output error
     * terms, not the pairwise exponential loss itself.
     */
    private double computeGlobalError(double[] outputErrors, double weightsSquareSum) {
        double globalError = 0.0;
        for (double error : outputErrors) {
            globalError += Math.abs(error);
        }
        return globalError + this.weightsDecayCost * 0.5 * weightsSquareSum;
    }

    /** Copies the current outputs of the given layer into a new array. */
    private static double[] getLayerOutputs(List<Neuron> layer) {
        double[] outputs = new double[layer.size()];
        for (int n = 0; n < outputs.length; ++n) {
            outputs[n] = layer.get(n).getOutput();
        }
        return outputs;
    }

    /**
     * Applies the gradient step plus weight decay to every neuron in the layer.
     * <p>
     * Weight {@code i} (for {@code i < layerInputs.length}) connects input
     * {@code i}. The weight at index {@code layerInputs.length} is the bias
     * weight, fed by {@code neuron.getBiasInput()}. Relies on
     * {@code getWeights()} returning the live weight array.
     * <p>
     * NOTE(review): the decay term {@code weightsDecayCost * w} is not scaled
     * by {@code learningRate}; preserved as-is.
     */
    private void updateWeights(List<Neuron> layer, double[] layerInputs, double learningRate) {
        int inputsCount = layerInputs.length;
        for (Neuron neuron : layer) {
            double[] weights = neuron.getWeights();
            double error = neuron.getError();
            for (int i = 0; i < inputsCount; ++i) {
                double delta = learningRate * error * layerInputs[i];
                weights[i] += delta - this.weightsDecayCost * weights[i];
            }
            double biasDelta = learningRate * error * neuron.getBiasInput();
            weights[inputsCount] += biasDelta - this.weightsDecayCost * weights[inputsCount];
        }
    }

    /**
     * Sets each output neuron's error term to its output error multiplied by
     * its activation-function derivative at its net input.
     */
    private void computeOutputLayerErrorTerms(List<Neuron> outLayer, double[] outputErrors) {
        int neuronsInLayer = outLayer.size();
        for (int n = 0; n < neuronsInLayer; ++n) {
            Neuron neuron = outLayer.get(n);
            ActivationFunction function = neuron.getActivationFunction();
            neuron.setError(outputErrors[n] * function.derivative(neuron.getNeuronInput()));
        }
    }

    /**
     * Sets each hidden neuron's error term to the weighted sum of the next
     * layer's error terms, multiplied by its activation derivative.
     * <p>
     * Weight {@code n} of a next-layer neuron is taken as the connection from
     * neuron {@code n} of this layer. It must still hold the value used in the
     * forward pass.
     */
    private void computeHiddenLayerErrorTerms(List<Neuron> layer, List<Neuron> nextLayer) {
        int neuronsInLayer = layer.size();
        for (int n = 0; n < neuronsInLayer; ++n) {
            Neuron neuron = layer.get(n);
            double sum = 0.0;
            for (Neuron nextNeuron : nextLayer) {
                sum += nextNeuron.getError() * nextNeuron.getWeights()[n];
            }
            ActivationFunction function = neuron.getActivationFunction();
            neuron.setError(sum * function.derivative(neuron.getNeuronInput()));
        }
    }

    /**
     * Computes the per-output error signal of the pairwise ranking error.
     * <p>
     * For a relevant label k the value is
     * {@code sum over irrelevant l of exp(-(out[k] - out[l]))}. For an
     * irrelevant label l it is
     * {@code -sum over relevant k of exp(-(out[k] - out[l]))}. Both are divided
     * by (#relevant * #irrelevant).
     *
     * @return the error vector, or {@code null} if there are no relevant or no
     *         irrelevant labels
     */
    private double[] computeErrorsForNeurons(double[] networkOutputs, double[] expectedLabels) {
        int labelsCount = expectedLabels.length;
        int[] relevant = new int[labelsCount];
        int[] irrelevant = new int[labelsCount];
        int relevantCount = 0;
        int irrelevantCount = 0;
        for (int index = 0; index < labelsCount; ++index) {
            if (expectedLabels[index] == 1.0) {
                relevant[relevantCount++] = index;
            } else {
                irrelevant[irrelevantCount++] = index;
            }
        }
        if (relevantCount == 0 || irrelevantCount == 0) {
            return null;
        }
        // Computed in double to avoid int overflow of the pair count.
        double normalization = 1.0 / ((double) relevantCount * irrelevantCount);
        double[] neuronsErrors = new double[labelsCount];
        for (int index = 0; index < labelsCount; ++index) {
            double error = 0.0;
            if (expectedLabels[index] == 1.0) {
                for (int j = 0; j < irrelevantCount; ++j) {
                    error += Math.exp(-(networkOutputs[index] - networkOutputs[irrelevant[j]]));
                }
            } else {
                for (int j = 0; j < relevantCount; ++j) {
                    error -= Math.exp(-(networkOutputs[relevant[j]] - networkOutputs[index]));
                }
            }
            neuronsErrors[index] = error * normalization;
        }
        return neuronsErrors;
    }
}

