/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.VisionLanguageInput;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BaseImagePreProcessor;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.modality.cv.translator.YoloTranslator;
import ai.djl.modality.nlp.NlpUtils;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.TextCleaner;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import ai.djl.util.Utils;
import com.google.gson.reflect.TypeToken;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A translator that turns a {@link VisionLanguageInput} (an image plus a list of candidate class
 * names) into {@link DetectedObjects}.
 *
 * <p>The candidate names are tokenized with a BPE tokenizer and passed through a secondary model
 * (named {@code "clip"}, loaded from {@code clipModelPath}); its first output is fed to the main
 * model together with the pre-processed image.
 */
public class YoloWorldTranslator
        implements NoBatchifyTranslator<VisionLanguageInput, DetectedObjects> {

    /** Upper bound on the number of detections returned after NMS. */
    private static final int MAX_DETECTION = 300;

    private static final int[] AXIS_0 = {0};
    private static final int[] AXIS_1 = {1};

    private SimpleBpeTokenizer tokenizer;
    private BaseImageTranslator<?> imageProcessor;
    private Predictor<NDList, NDList> predictor;
    private String clipModelPath;
    private float threshold;
    private float nmsThreshold;

    /**
     * Constructs the translator from its builder.
     *
     * @param builder the builder carrying image pre-processing and post-processing settings
     */
    YoloWorldTranslator(Builder builder) {
        this.imageProcessor = new BaseImagePreProcessor(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
        this.clipModelPath = builder.clipModelPath;
    }

    /**
     * Loads the text model and the tokenizer.
     *
     * <p>{@code clipModelPath} is used as-is when it is absolute or exists relative to the working
     * directory; otherwise it is resolved against the main model's directory. The tokenizer files
     * are read from the main model's directory.
     *
     * @param ctx the translator context
     * @throws Exception if the text model file is missing or any resource fails to load
     */
    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        Model model = ctx.getModel();
        Path modelPath = model.getModelPath();
        Path path = Paths.get(clipModelPath);
        if (!path.isAbsolute() && Files.notExists(path)) {
            path = modelPath.resolve(clipModelPath);
        }
        if (!Files.exists(path)) {
            throw new IOException("clip model not found: " + clipModelPath);
        }

        NDManager manager = ctx.getNDManager();
        Model clip = manager.getEngine().newModel("clip", manager.getDevice());
        clip.load(path);
        predictor = clip.newPredictor(new NoopTranslator(null));
        // Attach both to the main model's manager; presumably this ties their lifetime to the
        // main model so they are released with it.
        model.getNDManager().attachInternal(NDManager.nextUid(), predictor);
        model.getNDManager().attachInternal(NDManager.nextUid(), clip);
        tokenizer = SimpleBpeTokenizer.newInstance(modelPath);
    }

    /**
     * Builds the model input: the text feature of the candidates followed by the batched image.
     *
     * @param ctx the translator context; the candidates are stored under the {@code "candidates"}
     *     attachment for {@link #processOutput}
     * @param input the image and candidate class names
     * @return a list holding the text feature and the image array with a leading batch dimension
     * @throws TranslateException if no candidates are supplied or text encoding fails
     */
    @Override
    public NDList processInput(TranslatorContext ctx, VisionLanguageInput input)
            throws TranslateException {
        NDManager manager = ctx.getNDManager();
        String[] candidates = input.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }

        int[][] tokenIds = tokenizer.batchEncode(candidates);
        NDArray textFeature = predictor.predict(new NDList(manager.create(tokenIds))).get(0);

        Image img = input.getImage();
        NDList imageFeatures = imageProcessor.processInput(ctx, img);
        NDArray array = imageFeatures.get(0).expandDims(0);

        ctx.setAttachment("candidates", candidates);
        return new NDList(textFeature, array);
    }

    /**
     * Converts the raw prediction into detected objects.
     *
     * <p>Predictions whose best class score does not exceed the threshold are dropped, boxes are
     * converted from xywh to corner form, NMS is applied, at most {@code MAX_DETECTION} results
     * are kept, and box coordinates are divided by the {@code "width"} / {@code "height"}
     * attachments (expected to have been stored in the context during input processing).
     *
     * @param ctx the translator context
     * @param list the model output; only the first array is used
     * @return the detected objects with class name, probability and normalized bounding box
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        List<String> classes = Arrays.asList((String[]) ctx.getAttachment("candidates"));
        int width = (Integer) ctx.getAttachment("width");
        int height = (Integer) ctx.getAttachment("height");
        int boxIndex = classes.size() + 4;

        NDArray pred = list.get(0).squeeze(0);
        // Boolean mask over predictions whose best class score exceeds the threshold.
        NDArray candidates = pred.get("4:" + boxIndex).max(AXIS_0).gt(threshold);
        pred = pred.transpose();
        NDArray sub = YoloTranslator.xywh2xyxy(pred.get("..., :4"));
        pred = sub.concat(pred.get("..., 4:"), -1);
        pred = pred.get(candidates);

        NDList split = pred.split(new long[] {4L, boxIndex}, 1);
        NDArray box = split.get(0);
        NDArray classScores = split.get(1);
        int numBox = Math.toIntExact(box.getShape().get(0));

        float[] buf = box.toFloatArray();
        // One confidence per box (its best class score). The un-reduced score matrix flattens to
        // numBox * numClasses values and must not be indexed by box or by class id directly.
        float[] confidences = classScores.max(AXIS_1).toFloatArray();
        long[] ids = classScores.argMax(1).toLongArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);
        for (int i = 0; i < numBox; ++i) {
            float xPos = buf[i * 4];
            float yPos = buf[i * 4 + 1];
            float w = buf[i * 4 + 2] - xPos;
            float h = buf[i * 4 + 3] - yPos;
            boxes.add(new Rectangle(xPos, yPos, w, h));
            scores.add((double) confidences[i]);
        }

        List<Integer> nms = Rectangle.nms(boxes, scores, nmsThreshold);
        if (nms.size() > MAX_DETECTION) {
            nms = nms.subList(0, MAX_DETECTION);
        }

        List<String> retClasses = new ArrayList<>(nms.size());
        List<Double> retProbs = new ArrayList<>(nms.size());
        List<BoundingBox> retBB = new ArrayList<>(nms.size());
        for (int index : nms) {
            int id = (int) ids[index];
            retClasses.add(classes.get(id));
            // The probability belongs to the box (index), not to the class id.
            retProbs.add((double) confidences[index]);
            Rectangle rect = boxes.get(index);
            retBB.add(
                    new Rectangle(
                            rect.getX() / (double) width,
                            rect.getY() / (double) height,
                            rect.getWidth() / (double) width,
                            rect.getHeight() / (double) height));
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * @param arguments the pre-processing and post-processing arguments
     * @return a new builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /** The builder for {@link YoloWorldTranslator}. */
    public static class Builder extends BaseImageTranslator.BaseBuilder<Builder> {

        float threshold = 0.25f;
        float nmsThreshold = 0.7f;
        String clipModelPath = "clip.pt";

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the score threshold a prediction must exceed to be kept.
         *
         * @param threshold the score threshold
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /**
         * Sets the threshold passed to non-maximum suppression.
         *
         * @param nmsThreshold the NMS threshold
         * @return this builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return self();
        }

        /**
         * Sets the path of the text model, absolute or relative to the main model directory.
         *
         * @param clipModelPath the text model path
         * @return this builder
         */
        public Builder optClipModelPath(String clipModelPath) {
            this.clipModelPath = clipModelPath;
            return self();
        }

        /**
         * Reads {@code threshold}, {@code nmsThreshold} and {@code clipModelPath} from the
         * arguments; a missing key keeps the builder's current value.
         *
         * @param arguments the post-processing arguments
         */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            super.configPostProcess(arguments);
            optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold));
            optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold));
            optClipModelPath(
                    ArgumentsUtil.stringValue(arguments, "clipModelPath", clipModelPath));
        }

        /**
         * Builds the translator.
         *
         * @return a new translator
         */
        public YoloWorldTranslator build() {
            return new YoloWorldTranslator(this);
        }
    }

    /** A minimal byte-pair-encoding tokenizer driven by {@code vocab.json} and {@code merges.txt}. */
    static final class SimpleBpeTokenizer {

        /** Smallest padded row width produced by {@link #batchEncode}. */
        private static final int MIN_CONTEXT_LENGTH = 77;

        /** Largest number of token ids kept per input. */
        private static final int MAX_CONTEXT_LENGTH = 512;

        private static final Type MAP_TYPE = new TypeToken<Map<String, Integer>>() {}.getType();

        private Map<String, Integer> vocabulary;
        private Map<Pair<String, String>, Integer> ranks;
        private int sot;
        private int eot;

        /**
         * Constructs a tokenizer.
         *
         * @param vocabulary maps each token to its id; must contain {@code <|startoftext|>} and
         *     {@code <|endoftext|>}
         * @param ranks maps each mergeable pair to its priority (lower merges first)
         */
        SimpleBpeTokenizer(
                Map<String, Integer> vocabulary, Map<Pair<String, String>, Integer> ranks) {
            this.vocabulary = vocabulary;
            this.ranks = ranks;
            this.sot = vocabulary.get("<|startoftext|>");
            this.eot = vocabulary.get("<|endoftext|>");
        }

        /**
         * Loads a tokenizer from {@code vocab.json} and {@code merges.txt} in the given directory.
         *
         * @param modelPath the directory containing the tokenizer files
         * @return a new tokenizer
         * @throws IOException if either file cannot be read
         */
        static SimpleBpeTokenizer newInstance(Path modelPath) throws IOException {
            Path vocab = modelPath.resolve("vocab.json");
            Path merges = modelPath.resolve("merges.txt");

            Map<Pair<String, String>, Integer> ranks = new ConcurrentHashMap<>();
            List<String> lines = Utils.readLines(merges);
            // The first line is skipped (treated as a header); starting at index 1 also keeps an
            // empty file from failing with an opaque subList error.
            for (int i = 1; i < lines.size(); ++i) {
                String[] tok = lines.get(i).split(" ");
                ranks.put(new Pair<>(tok[0], tok[1]), i - 1);
            }

            try (BufferedReader reader = Files.newBufferedReader(vocab)) {
                Map<String, Integer> vocabulary = JsonUtils.GSON.fromJson(reader, MAP_TYPE);
                return new SimpleBpeTokenizer(vocabulary, ranks);
            }
        }

        /**
         * Encodes every input and packs the ids into a zero-padded rectangular array.
         *
         * <p>Each row is truncated to {@code MAX_CONTEXT_LENGTH} ids. The row width is the longest
         * (truncated) row, but never less than {@code MIN_CONTEXT_LENGTH}.
         *
         * @param inputs the texts to encode
         * @return one row of token ids per input
         */
        int[][] batchEncode(String[] inputs) {
            List<List<Integer>> encoded = new ArrayList<>(inputs.length);
            int contextLength = MIN_CONTEXT_LENGTH;
            for (String input : inputs) {
                List<Integer> ids = encode(input);
                if (ids.size() > MAX_CONTEXT_LENGTH) {
                    ids = ids.subList(0, MAX_CONTEXT_LENGTH);
                }
                // Measure after truncation so the padded width never exceeds the cap.
                contextLength = Math.max(contextLength, ids.size());
                encoded.add(ids);
            }

            int[][] tokenIds = new int[inputs.length][contextLength];
            for (int row = 0; row < encoded.size(); ++row) {
                List<Integer> ids = encoded.get(row);
                for (int col = 0; col < ids.size(); ++col) {
                    tokenIds[row][col] = ids.get(col);
                }
            }
            return tokenIds;
        }

        /**
         * Encodes a text into token ids, wrapped in the start-of-text and end-of-text ids.
         *
         * @param text the text to encode
         * @return the token ids
         * @throws IllegalArgumentException if a BPE piece is not present in the vocabulary
         */
        List<Integer> encode(String text) {
            List<String> tokens = new ArrayList<>(Collections.singletonList(text));
            List<TextProcessor> processors = new ArrayList<>();
            processors.add(new LowerCaseConvertor());
            processors.add(new TextCleaner(NlpUtils::isWhiteSpace, ' '));
            processors.add(new PunctuationSeparator());
            for (TextProcessor processor : processors) {
                tokens = processor.preprocess(tokens);
            }

            List<Integer> idx = new ArrayList<>();
            idx.add(sot);
            for (String token : tokens) {
                // bpe() returns the sub-word pieces joined by a space; each piece has its own id.
                for (String piece : bpe(token).split(" ")) {
                    Integer id = vocabulary.get(piece);
                    if (id == null) {
                        throw new IllegalArgumentException(
                                "Token not found in vocabulary: " + piece);
                    }
                    idx.add(id);
                }
            }
            idx.add(eot);
            return idx;
        }

        /**
         * Applies the BPE merges to a single token.
         *
         * @param token a non-empty token
         * @return the resulting pieces joined by a single space; the last piece ends in
         *     {@code </w>}
         */
        private String bpe(String token) {
            char[] chars = token.toCharArray();
            List<String> word = new ArrayList<>(chars.length);
            for (char c : chars) {
                word.add(String.valueOf(c));
            }
            int last = word.size() - 1;
            word.set(last, word.get(last) + "</w>");

            Set<Pair<String, String>> pairs = getPairs(word);
            if (pairs.isEmpty()) {
                return token + "</w>";
            }

            while (true) {
                // Pick the adjacent pair with the lowest rank; unknown pairs sort last.
                Pair<String, String> best =
                        Collections.min(
                                pairs,
                                (o1, o2) ->
                                        Integer.compare(
                                                ranks.getOrDefault(o1, Integer.MAX_VALUE),
                                                ranks.getOrDefault(o2, Integer.MAX_VALUE)));
                if (!ranks.containsKey(best)) {
                    break;
                }
                String first = best.getKey();
                String second = best.getValue();

                List<String> merged = new ArrayList<>();
                int i = 0;
                while (i < word.size()) {
                    int j = word.subList(i, word.size()).indexOf(first);
                    if (j < 0) {
                        merged.addAll(word.subList(i, word.size()));
                        break;
                    }
                    j += i;
                    merged.addAll(word.subList(i, j));
                    i = j;
                    if (word.get(i).equals(first)
                            && i < word.size() - 1
                            && word.get(i + 1).equals(second)) {
                        merged.add(first + second);
                        i += 2;
                    } else {
                        merged.add(word.get(i));
                        ++i;
                    }
                }

                word = merged;
                if (word.size() == 1) {
                    break;
                }
                pairs = getPairs(word);
            }
            return String.join(" ", word);
        }

        /**
         * Collects the set of adjacent symbol pairs in a word.
         *
         * @param word the symbols of the word
         * @return the adjacent pairs, or an empty set when the word has fewer than two symbols
         */
        private Set<Pair<String, String>> getPairs(List<String> word) {
            if (word.size() < 2) {
                return Collections.emptySet();
            }
            Set<Pair<String, String>> pairs = new HashSet<>();
            String prev = word.get(0);
            for (int i = 1; i < word.size(); ++i) {
                String current = word.get(i);
                pairs.add(new Pair<>(prev, current));
                prev = current;
            }
            return pairs;
        }
    }
}

