/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.cv.poseestimation;

import ai.djl.modality.cv.Joints;
import ai.djl.modality.cv.SimplePoseTranslator;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.repository.Anchor;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import java.awt.image.BufferedImage;
import java.util.Map;

public class SimplePoseModelLoader
extends BaseModelLoader<BufferedImage, Joints> {
    private static final Anchor BASE_ANCHOR = MRL.Model.CV.POSE_ESTIMATION;
    private static final String GROUP_ID = "ai.djl.mxnet";
    private static final String ARTIFACT_ID = "simple_pose";
    private static final String VERSION = "0.0.1";

    public SimplePoseModelLoader(Repository repository) {
        super(repository, new MRL(BASE_ANCHOR, GROUP_ID, ARTIFACT_ID), VERSION);
    }

    public Translator<BufferedImage, Joints> getTranslator(Artifact artifact) {
        Map arguments = artifact.getArguments();
        int width = arguments.getOrDefault("width", 192.0).intValue();
        int height = arguments.getOrDefault("height", 256.0).intValue();
        double threshold = arguments.getOrDefault("threshold", 0.2);
        Pipeline pipeline = new Pipeline();
        pipeline.add((Transform)new Resize(width, height)).add((Transform)new ToTensor()).add((Transform)new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f}));
        return ((SimplePoseTranslator.Builder)new SimplePoseTranslator.Builder().setPipeline(pipeline)).optThreshold((float)threshold).build();
    }
}

