/*
 * Decompiled with CFR 0.152.
 */
package org.canova.api.io.filters;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.canova.api.io.filters.RandomPathFilter;
import org.canova.api.io.labels.PathLabelGenerator;
import org.canova.api.writable.Writable;

public class BalancedPathFilter
extends RandomPathFilter {
    protected PathLabelGenerator labelGenerator;
    protected int maxLabels = 0;
    protected int maxPathsPerLabel = 0;

    public BalancedPathFilter(Random random, String[] extensions, PathLabelGenerator labelGenerator) {
        this(random, extensions, labelGenerator, 0, 0, 0);
    }

    public BalancedPathFilter(Random random, String[] extensions, PathLabelGenerator labelGenerator, int maxPaths, int maxLabels, int maxPathsPerLabel) {
        super(random, extensions, maxPaths);
        this.labelGenerator = labelGenerator;
        this.maxLabels = maxLabels;
        this.maxPathsPerLabel = maxPathsPerLabel;
    }

    @Override
    public URI[] filter(URI[] paths) {
        paths = super.filter(paths);
        HashMap<Writable, Integer> labelsCount = new HashMap<Writable, Integer>();
        for (int i = 0; i < paths.length; ++i) {
            URI path = paths[i];
            Writable label = this.labelGenerator.getLabelForPath(path);
            Integer count = (Integer)labelsCount.get(label);
            if (count == null) {
                if (this.maxLabels > 0 && labelsCount.size() >= this.maxLabels) continue;
                count = 0;
            }
            labelsCount.put(label, count + 1);
        }
        int minCount = Integer.MAX_VALUE;
        for (Integer count : labelsCount.values()) {
            if (minCount <= count) continue;
            minCount = count;
        }
        if (this.maxPathsPerLabel > 0 && minCount > this.maxPathsPerLabel) {
            minCount = this.maxPathsPerLabel;
        }
        labelsCount.clear();
        ArrayList<URI> newpaths = new ArrayList<URI>();
        for (int i = 0; i < paths.length; ++i) {
            URI path = paths[i];
            Writable label = this.labelGenerator.getLabelForPath(path);
            Integer count = (Integer)labelsCount.get(label);
            if (count == null) {
                if (this.maxLabels > 0 && labelsCount.size() >= this.maxLabels) continue;
                count = 0;
            }
            labelsCount.put(label, count + 1);
            if (count >= minCount) continue;
            newpaths.add(path);
        }
        Collections.shuffle(newpaths, this.random);
        return newpaths.toArray(new URI[newpaths.size()]);
    }
}

