/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import mulan.classifier.neural.DataPair;
import mulan.classifier.neural.MMPUpdateRuleBase;
import mulan.classifier.neural.model.Neuron;
import mulan.evaluation.loss.RankingLossFunction;

/**
 * Update rule that spreads the loss of an example over all wrongly ordered
 * (relevant, irrelevant) label pairs using randomly drawn penalty weights.
 *
 * <p>A pair is considered wrongly ordered when the confidence of the relevant
 * label is less than or equal to the confidence of the irrelevant label. Each
 * such pair receives a random weight; the weights are then normalized so that
 * they sum to one before being multiplied by the loss.
 */
public class MMPRandomizedUpdateRule extends MMPUpdateRuleBase {

    /**
     * Creates the update rule; both arguments are handed unchanged to the
     * {@link MMPUpdateRuleBase} constructor.
     *
     * @param perceptrons the perceptrons passed to the base class
     * @param lossMeasure the ranking loss function passed to the base class
     */
    public MMPRandomizedUpdateRule(List<Neuron> perceptrons, RankingLossFunction lossMeasure) {
        super(perceptrons, lossMeasure);
    }

    /**
     * Computes one update parameter per label for the given example.
     *
     * <p>For every wrongly ordered pair a weight is drawn from
     * {@link Random#nextDouble()}. The normalized weight times {@code loss} is
     * added to the parameter of the relevant label and subtracted from the
     * parameter of the irrelevant label.
     *
     * @param example     the example; {@code getOutputBoolean()} tells which labels are relevant
     * @param confidences confidence per label, indexed like the example output
     * @param loss        the loss value distributed over the wrongly ordered pairs
     * @return an array with one entry per label; all zeros when no pair is wrongly ordered
     */
    @Override
    protected double[] computeUpdateParameters(DataPair example, double[] confidences, double loss) {
        int numLabels = example.getOutput().length;
        boolean[] trueOutput = example.getOutputBoolean();
        double[] params = new double[numLabels];
        Random rnd = new Random();
        double weightsSum = 0.0;

        // First pass: accumulate the raw (unnormalized) signed weights per label.
        for (int rLabel = 0; rLabel < numLabels; ++rLabel) {
            if (!trueOutput[rLabel]) {
                continue;
            }
            for (int irLabel = 0; irLabel < numLabels; ++irLabel) {
                if (!trueOutput[irLabel] && confidences[rLabel] <= confidences[irLabel]) {
                    double weight = rnd.nextDouble();
                    weightsSum += weight;
                    params[rLabel] += weight;
                    params[irLabel] -= weight;
                }
            }
        }

        // No wrongly ordered pair (or every drawn weight was exactly 0.0):
        // params is all zeros already, and dividing by zero must be avoided.
        if (weightsSum <= 0.0) {
            return params;
        }

        // Second pass: normalize the weights to sum to one and apply the loss.
        double scale = loss / weightsSum;
        for (int label = 0; label < numLabels; ++label) {
            params[label] *= scale;
        }
        return params;
    }
}

