/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.util;

import java.net.InterfaceAddress;
import java.net.NetworkInterface;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.commons.net.util.SubnetUtils;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.parameterserver.distributed.util.NetworkInformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Chooses IPv4 addresses for distributed parameter-server nodes out of a collection of
 * {@link NetworkInformation} entries.
 *
 * <p>When a network mask is configured, addresses are filtered with {@link SubnetUtils}. Without a
 * mask, every address is mapped (as binary octets) into a {@link VirtualTree}, and candidates are
 * picked from the most common first octet.
 *
 * <p>Instances mutate internal state: the information list is shuffled and the tree accumulates
 * mapped addresses. Nothing here is synchronized.
 */
public class NetworkOrganizer {
    private static final Logger log = LoggerFactory.getLogger(NetworkOrganizer.class);
    protected List<NetworkInformation> informationCollection;
    protected String networkMask;
    protected VirtualTree tree = new VirtualTree();

    /**
     * Creates an organizer over the given nodes without a network mask, so address selection uses
     * the octet-tree heuristic in {@link #getIntersections(int, Collection)}.
     *
     * @param infoSet network information of the participating nodes; must not be null
     */
    protected NetworkOrganizer(@NonNull Collection<NetworkInformation> infoSet) {
        // The delegated constructor performs the null check on infoSet.
        this(infoSet, null);
    }

    /**
     * Creates an organizer over the given nodes.
     *
     * @param infoSet network information of the participating nodes; copied defensively
     * @param mask subnet in CIDR notation accepted by {@link SubnetUtils}, or {@code null} to use the
     *     octet-tree heuristic instead
     * @throws NullPointerException if {@code infoSet} is null
     */
    public NetworkOrganizer(@NonNull Collection<NetworkInformation> infoSet, String mask) {
        if (infoSet == null) {
            throw new NullPointerException("infoSet");
        }
        this.informationCollection = new ArrayList<>(infoSet);
        this.networkMask = mask;
    }

    /**
     * Creates an organizer over this machine's own IPv4 addresses, filtered by the given mask.
     *
     * @param networkMask subnet in CIDR notation accepted by {@link SubnetUtils}
     * @throws NullPointerException if {@code networkMask} is null
     */
    public NetworkOrganizer(@NonNull String networkMask) {
        if (networkMask == null) {
            throw new NullPointerException("networkMask");
        }
        this.informationCollection = this.buildLocalInformation();
        this.networkMask = networkMask;
    }

    /**
     * Collects the textual IPv4 addresses of every network interface that is up on this machine.
     *
     * @return a list holding exactly one {@link NetworkInformation} with the collected addresses
     * @throws RuntimeException wrapping any failure while enumerating interfaces
     */
    protected List<NetworkInformation> buildLocalInformation() {
        List<NetworkInformation> list = new ArrayList<>();
        NetworkInformation netInfo = new NetworkInformation();
        try {
            List<NetworkInterface> interfaces = Collections.list(NetworkInterface.getNetworkInterfaces());
            for (NetworkInterface networkInterface : interfaces) {
                if (!networkInterface.isUp()) {
                    continue;
                }
                for (InterfaceAddress address : networkInterface.getInterfaceAddresses()) {
                    String addr = address.getAddress().getHostAddress();
                    // Textual IPv6 addresses contain ':'; only IPv4 is handled by the rest of this class.
                    if (addr == null || addr.isEmpty() || addr.contains(":")) {
                        continue;
                    }
                    netInfo.getIpAddresses().add(addr);
                }
            }
            list.add(netInfo);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return list;
    }

    /**
     * Returns one local address that satisfies the configured selection rule.
     *
     * @return the selected address
     * @throws ND4JIllegalStateException if no address matches
     */
    public String getMatchingAddress() {
        // buildLocalInformation() always yields a single entry, so a larger collection cannot describe
        // only this host (presumably it holds cluster-wide data); rebuild it from local interfaces.
        if (this.informationCollection.size() > 1) {
            this.informationCollection = this.buildLocalInformation();
        }
        List<String> list = this.getSubset(1);
        if (list.isEmpty()) {
            throw new ND4JIllegalStateException("Unable to find network interface matching requested mask: " + this.networkMask);
        }
        if (list.size() > 1) {
            log.warn("We have {} local IPs matching given netmask [{}]", list.size(), this.networkMask);
        }
        return list.get(0);
    }

    /**
     * Selects up to {@code numShards} addresses with no exclusions.
     *
     * @param numShards maximum number of addresses to return
     * @return the selected addresses, possibly fewer than requested
     */
    public List<String> getSubset(int numShards) {
        return this.getSubset(numShards, null);
    }

    /**
     * Selects up to {@code numShards} addresses, skipping any listed in {@code primary}.
     *
     * <p>With a mask, the node list is shuffled and addresses inside the subnet are taken in order
     * until enough are found. NOTE(review): commons-net {@code SubnetInfo.isInRange} excludes the
     * network and broadcast addresses unless inclusive host count is enabled; confirm for the version
     * in use. Without a mask, selection is delegated to {@link #getIntersections(int, Collection)}.
     *
     * @param numShards maximum number of addresses to return; values {@code <= 0} yield an empty list
     * @param primary addresses to exclude (for example, already chosen shards), or {@code null}
     * @return the selected addresses, possibly fewer than requested
     */
    public List<String> getSubset(int numShards, Collection<String> primary) {
        if (this.networkMask == null) {
            return this.getIntersections(numShards, primary);
        }
        List<String> addresses = new ArrayList<>();
        // Constructed before the size guard so an invalid mask is still reported.
        SubnetUtils utils = new SubnetUtils(this.networkMask);
        if (numShards <= 0) {
            return addresses;
        }
        Collections.shuffle(this.informationCollection);
        for (NetworkInformation information : this.informationCollection) {
            for (String ip : information.getIpAddresses()) {
                if (primary != null && primary.contains(ip)) {
                    continue;
                }
                if (!utils.getInfo().isInRange(ip)) {
                    continue;
                }
                log.debug("Picked {} as {}", ip, primary == null ? "Shard" : "Backup");
                addresses.add(ip);
                if (addresses.size() >= numShards) {
                    return addresses;
                }
            }
        }
        return addresses;
    }

    /**
     * Converts a dotted-quad IPv4 address into four dot-separated 8-digit binary octets.
     *
     * <p>For example, {@code "10.0.0.1"} becomes {@code "00001010.00000000.00000000.00000001"}.
     *
     * @param ip dotted-quad IPv4 address
     * @return the binary representation (35 characters)
     * @throws UnsupportedOperationException if {@code ip} does not split into four parts
     * @throws NumberFormatException if a part is not an integer
     * @throws ND4JIllegalStateException if a part is outside 0..255
     */
    protected static String convertIpToOctets(@NonNull String ip) {
        if (ip == null) {
            throw new NullPointerException("ip");
        }
        String[] octets = ip.split("\\.");
        if (octets.length != 4) {
            throw new UnsupportedOperationException("Only dotted-quad IPv4 addresses are supported, got: " + ip);
        }
        StringBuilder builder = new StringBuilder(35);
        for (int i = 0; i < octets.length; ++i) {
            if (i > 0) {
                builder.append('.');
            }
            builder.append(NetworkOrganizer.toBinaryOctet(octets[i]));
        }
        return builder.toString();
    }

    /**
     * Formats an octet value as an 8-digit, zero-padded binary string.
     *
     * @param value octet value in 0..255
     * @return the binary representation, e.g. {@code 5 -> "00000101"}
     * @throws ND4JIllegalStateException if {@code value} is out of range
     */
    protected static String toBinaryOctet(@NonNull Integer value) {
        if (value == null) {
            throw new NullPointerException("value");
        }
        if (value < 0 || value > 255) {
            throw new ND4JIllegalStateException("IP octets cant hold values below 0 or above 255");
        }
        String bits = Integer.toBinaryString(value);
        StringBuilder builder = new StringBuilder(8);
        for (int i = bits.length(); i < 8; ++i) {
            builder.append('0');
        }
        return builder.append(bits).toString();
    }

    /**
     * Parses a decimal octet and formats it as an 8-digit binary string.
     *
     * @param value decimal octet text
     * @return the binary representation
     * @throws NumberFormatException if {@code value} is not an integer
     * @throws ND4JIllegalStateException if the value is outside 0..255
     */
    protected static String toBinaryOctet(@NonNull String value) {
        if (value == null) {
            throw new NullPointerException("value");
        }
        return NetworkOrganizer.toBinaryOctet(Integer.parseInt(value));
    }

    /**
     * Mask-less selection based on the most common first octet across all nodes.
     *
     * <p>When {@code primary} is null, every known address is mapped into {@link #tree}. Repeated
     * calls map the addresses again, so tree counters accumulate. The hottest first octet is then
     * computed, and each node contributes its first address with that octet; every node must
     * contribute one. When {@code primary} is non-null, the tree is NOT re-populated and must already
     * hold data from an earlier call without {@code primary}.
     *
     * @param numShards maximum number of addresses to return; values {@code <= 0} yield an empty list
     * @param primary addresses to exclude, or {@code null} for the initial selection
     * @return a randomly ordered selection of at most {@code numShards} addresses
     * @throws ND4JIllegalStateException if some node has no address in the hottest first octet (only
     *     when {@code primary} is null), or if the tree holds no data
     */
    protected List<String> getIntersections(int numShards, Collection<String> primary) {
        if (primary == null) {
            for (NetworkInformation information : this.informationCollection) {
                for (String ip : information.getIpAddresses()) {
                    this.tree.map(NetworkOrganizer.convertIpToOctets(ip));
                }
            }
            String octetA = this.tree.getHottestNetworkA();
            List<String> candidates = this.collectFirstMatches(octetA, null);
            // collectFirstMatches yields at most one address per node, so any shortfall means a node
            // lives outside the dominant first octet.
            if (candidates.size() != this.informationCollection.size()) {
                throw new ND4JIllegalStateException("Mismatching A class");
            }
            return NetworkOrganizer.pickRandom(candidates, numShards);
        }
        String octetA = this.tree.getHottestNetworkA();
        return NetworkOrganizer.pickRandom(this.collectFirstMatches(octetA, primary), numShards);
    }

    /**
     * Collects, per node, the first address whose binary form starts with {@code prefix} and that is
     * not listed in {@code excluded}.
     */
    private List<String> collectFirstMatches(String prefix, Collection<String> excluded) {
        List<String> candidates = new ArrayList<>();
        for (NetworkInformation node : this.informationCollection) {
            for (String ip : node.getIpAddresses()) {
                String octets = NetworkOrganizer.convertIpToOctets(ip);
                if (!octets.startsWith(prefix) || (excluded != null && excluded.contains(ip))) {
                    continue;
                }
                candidates.add(ip);
                break;
            }
        }
        return candidates;
    }

    /**
     * Shuffles {@code candidates} in place and returns a copy of at most {@code count} of them.
     * Negative counts are treated as zero.
     */
    private static List<String> pickRandom(List<String> candidates, int count) {
        Collections.shuffle(candidates);
        int limit = Math.max(0, Math.min(count, candidates.size()));
        return new ArrayList<>(candidates.subList(0, limit));
    }

    /** One character position in a {@link VirtualTree}, counting how many strings passed through it. */
    public static class VirtualNode {
        protected Map<Character, VirtualNode> nodes = new HashMap<>();
        protected final Character ownChar;
        protected int counter = 0;
        protected VirtualNode parentNode;

        /**
         * @param character the character this node represents
         * @param parentNode the parent node, or {@code null} for a root
         */
        public VirtualNode(Character character, VirtualNode parentNode) {
            this.ownChar = character;
            this.parentNode = parentNode;
        }

        /**
         * Records a pass through this node and continues mapping {@code chars} from {@code position}.
         *
         * @param chars single-character strings (as produced by {@code String.split("")})
         * @param position index of the next character to map below this node
         */
        public void map(String[] chars, int position) {
            ++this.counter;
            if (position >= chars.length) {
                return;
            }
            Character ch = chars[position].charAt(0);
            VirtualNode child = this.nodes.get(ch);
            if (child == null) {
                child = new VirtualNode(ch, this);
                this.nodes.put(ch, child);
            }
            child.map(chars, position + 1);
        }

        /** Returns the number of extra branches (children beyond the first) in this subtree. */
        protected int getNumDivergents() {
            if (this.nodes.isEmpty()) {
                return 0;
            }
            int count = this.nodes.size() - 1;
            for (VirtualNode node : this.nodes.values()) {
                count += node.getNumDivergents();
            }
            return count;
        }

        /** Returns the sum of {@code max(0, counter - 1)} over this node and all its descendants. */
        protected int getDiscriminatedCount() {
            int count = Math.max(0, this.counter - 1);
            for (VirtualNode node : this.nodes.values()) {
                count += node.getDiscriminatedCount();
            }
            return count;
        }

        /** Returns how many mapped strings passed through this node. */
        protected int getCounter() {
            return this.counter;
        }

        /**
         * Descends through the first child (in map iteration order) whose counter is at least
         * {@code threshold}, and returns the deepest node reached that way.
         */
        protected VirtualNode getHottestNode(int threshold) {
            for (VirtualNode node : this.nodes.values()) {
                if (node.getCounter() >= threshold) {
                    return node.getHottestNode(threshold);
                }
            }
            return this;
        }

        /**
         * Returns the child with the highest counter (ties go to the first in map iteration order), or
         * {@code null} if this node has no children.
         */
        protected VirtualNode getHottestNode() {
            VirtualNode hottest = null;
            int max = 0;
            for (VirtualNode node : this.nodes.values()) {
                if (node.getCounter() > max) {
                    max = node.getCounter();
                    hottest = node;
                }
            }
            return hottest;
        }

        /**
         * Returns the characters of this node's ancestors, root first.
         *
         * <p>NOTE(review): this node's own character is NOT included. Confirm whether callers expect
         * the path to include it.
         */
        protected String rewind() {
            StringBuilder builder = new StringBuilder();
            for (VirtualNode node = this.parentNode; node != null; node = node.parentNode) {
                builder.append(node.ownChar);
            }
            return builder.reverse().toString();
        }
    }

    /** Prefix tree over binary-octet strings produced by {@link #convertIpToOctets(String)}. */
    public static class VirtualTree {
        /** Length of a first-octet prefix: 8 binary digits. */
        private static final int NETWORK_A_LENGTH = 8;
        /** Length of a first-two-octet prefix: 8 digits, the separating dot, 8 digits. */
        private static final int NETWORK_AB_LENGTH = 17;

        protected Map<Character, VirtualNode> nodes = new HashMap<>();

        /**
         * Adds one string to the tree, incrementing the counter of every node along its path.
         *
         * @param string binary-octet text; must be non-empty and start with '0' or '1'
         * @throws ND4JIllegalStateException if {@code string} is empty or does not start with a binary digit
         */
        public void map(@NonNull String string) {
            if (string == null) {
                throw new NullPointerException("string");
            }
            if (string.isEmpty()) {
                throw new ND4JIllegalStateException("VirtualTree expects binary octets as input");
            }
            String[] chars = string.split("");
            Character ch = chars[0].charAt(0);
            if (ch != '0' && ch != '1') {
                throw new ND4JIllegalStateException("VirtualTree expects binary octets as input");
            }
            VirtualNode root = this.nodes.get(ch);
            if (root == null) {
                root = new VirtualNode(ch, null);
                this.nodes.put(ch, root);
            }
            root.map(chars, 1);
        }

        /** Returns the number of roots plus the number of extra branches below them. */
        public int getUniqueBranches() {
            int count = this.nodes.size();
            for (VirtualNode node : this.nodes.values()) {
                count += node.getNumDivergents();
            }
            return count;
        }

        /** Returns the total number of strings mapped into the tree. */
        public int getTotalBranches() {
            int count = 0;
            for (VirtualNode node : this.nodes.values()) {
                count += node.getCounter();
            }
            return count;
        }

        /**
         * Follows the busiest root down while children keep that root's full count, and returns the
         * ancestor path of the node reached (see {@link VirtualNode#rewind()}).
         *
         * @throws ND4JIllegalStateException if the tree is empty
         */
        public String getHottestNetwork() {
            VirtualNode root = this.getHottestNode();
            if (root == null) {
                throw new ND4JIllegalStateException("VirtualTree wasn't properly initialized, and doesn't have any information within");
            }
            return root.getHottestNode(root.getCounter()).rewind();
        }

        /** Returns the root with the highest counter, or {@code null} if the tree is empty. */
        protected VirtualNode getHottestNode() {
            VirtualNode hottest = null;
            int max = 0;
            for (VirtualNode node : this.nodes.values()) {
                if (node.getCounter() > max) {
                    max = node.getCounter();
                    hottest = node;
                }
            }
            return hottest;
        }

        /**
         * Returns the most common first octet as 8 binary digits, built by greedily following the
         * busiest node at each level.
         *
         * @throws ND4JIllegalStateException if the tree is empty or too shallow
         */
        public String getHottestNetworkA() {
            return this.getHottestPrefix(NETWORK_A_LENGTH);
        }

        /**
         * Returns the most common first two octets as 17 characters (8 digits, '.', 8 digits), built by
         * greedily following the busiest node at each level.
         *
         * @throws ND4JIllegalStateException if the tree is empty or too shallow
         */
        public String getHottestNetworkAB() {
            return this.getHottestPrefix(NETWORK_AB_LENGTH);
        }

        /** Greedily follows the busiest node from the root for {@code length} characters. */
        private String getHottestPrefix(int length) {
            VirtualNode node = this.getHottestNode();
            if (node == null) {
                throw new ND4JIllegalStateException("VirtualTree wasn't properly initialized, and doesn't have any information within");
            }
            StringBuilder builder = new StringBuilder(length);
            builder.append(node.ownChar);
            for (int i = 1; i < length; ++i) {
                node = node.getHottestNode();
                if (node == null) {
                    throw new ND4JIllegalStateException("VirtualTree doesn't hold enough data to build a prefix of length " + length);
                }
                builder.append(node.ownChar);
            }
            return builder.toString();
        }
    }
}

