/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural;

import java.util.List;
import java.util.Map;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.DataPair;
import mulan.classifier.neural.ModelUpdateRule;
import mulan.classifier.neural.model.Neuron;
import mulan.core.ArgumentNullException;
import mulan.evaluation.loss.RankingLossFunction;

/**
 * Base {@link ModelUpdateRule} for a list of per-label perceptrons (presumably the MMP learner
 * of this package — confirm against callers).
 * <p>
 * For each training example the rule computes one confidence per label and builds a ranking
 * from those confidences. It then evaluates the configured ranking loss. If that loss is
 * non-zero, it adds a per-label multiple of the input vector to each perceptron's weights.
 * The per-label multipliers are supplied by subclasses through
 * {@link #computeUpdateParameters(DataPair, double[], double)}.
 */
public abstract class MMPUpdateRuleBase
implements ModelUpdateRule {
    /** Perceptrons, one per label at the label's index; stored as given, not copied. */
    private final List<Neuron> perceptrons;
    /** Loss that gates whether an update happens and is passed on to subclasses. */
    private final RankingLossFunction lossFunction;

    /**
     * Creates the update rule.
     *
     * @param perceptrons one perceptron per label, at the label's index; their weight arrays
     *                    are updated in place by {@link #process(DataPair, Map)}
     * @param loss        ranking loss function evaluated on every processed example
     * @throws ArgumentNullException if {@code perceptrons} or {@code loss} is {@code null}
     */
    public MMPUpdateRuleBase(List<Neuron> perceptrons, RankingLossFunction loss) {
        if (perceptrons == null) {
            throw new ArgumentNullException("perceptrons");
        }
        if (loss == null) {
            // Report the real parameter name; the previous "lossMeasure" named nothing.
            throw new ArgumentNullException("loss");
        }
        this.perceptrons = perceptrons;
        this.lossFunction = loss;
    }

    /**
     * Processes one training example and updates the perceptrons if its ranking loss is non-zero.
     * <p>
     * Perceptron {@code i} produces the confidence for label {@code i}. When the loss is
     * non-zero, the loop sets {@code weights[j] += updateParams[i] * input[j]} for every input
     * feature {@code j} of perceptron {@code i}. Only the first {@code input.length} weights
     * are touched.
     *
     * @param example the training example; its output length gives the number of labels and
     *                its input length the number of weights updated per perceptron
     * @param params  not used by this rule
     * @return the ranking loss of {@code example}, computed before any weight update
     */
    @Override
    public final double process(DataPair example, Map<String, Object> params) {
        int numLabels = example.getOutput().length;
        double[] dataInput = example.getInput();
        int numFeatures = dataInput.length;

        double[] confidences = new double[numLabels];
        for (int label = 0; label < numLabels; ++label) {
            confidences[label] = this.perceptrons.get(label).processInput(dataInput);
        }

        MultiLabelOutput mlOut = new MultiLabelOutput(confidences);
        double loss = this.lossFunction.computeLoss(mlOut.getRanking(), example.getOutputBoolean());
        if (loss != 0.0) {
            double[] updateParams = this.computeUpdateParameters(example, confidences, loss);
            for (int label = 0; label < numLabels; ++label) {
                // NOTE(review): this only changes the model if getWeights() returns the
                // neuron's live weight array rather than a copy — confirm in Neuron.
                double[] weights = this.perceptrons.get(label).getWeights();
                for (int feature = 0; feature < numFeatures; ++feature) {
                    weights[feature] += updateParams[label] * dataInput[feature];
                }
            }
        }
        return loss;
    }

    /**
     * Computes the per-label multipliers applied to the input vector during an update.
     * Called by {@link #process(DataPair, Map)} only when the loss is non-zero.
     *
     * @param example     the training example being processed
     * @param confidences confidence per label, indexed like the perceptrons
     * @param loss        the non-zero ranking loss of {@code example}
     * @return multipliers indexed by label; the array must have at least one entry per label.
     *         Entry {@code i} scales the input added to perceptron {@code i}'s weights.
     */
    protected abstract double[] computeUpdateParameters(DataPair example, double[] confidences, double loss);
}

