/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.zoo.cv.classification;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.ImageClassificationTranslator;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.Anchor;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.zoo.cv.classification.ResNetV1;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;

public class ResNetModelLoader
extends BaseModelLoader<BufferedImage, Classifications> {
    private static final Anchor BASE_ANCHOR = MRL.Model.CV.IMAGE_CLASSIFICATION;
    private static final String GROUP_ID = "ai.djl.zoo";
    private static final String ARTIFACT_ID = "resnet";
    private static final String VERSION = "0.0.1";

    public ResNetModelLoader(Repository repository) {
        super(repository, new MRL(BASE_ANCHOR, GROUP_ID, ARTIFACT_ID), VERSION);
    }

    public Translator<BufferedImage, Classifications> getTranslator(Artifact artifact) {
        Map arguments = artifact.getArguments();
        List shape = (List)arguments.get("imageShape");
        int width = ((Double)shape.get(2)).intValue();
        int height = ((Double)shape.get(1)).intValue();
        Pipeline pipeline = new Pipeline();
        pipeline.add((Transform)new CenterCrop()).add((Transform)new Resize(width, height)).add((Transform)new ToTensor());
        return ((ImageClassificationTranslator.Builder)new ImageClassificationTranslator.Builder().setPipeline(pipeline)).setSynsetArtifactName("synset.txt").build();
    }

    protected Model loadModel(Artifact artifact, Path modelPath, Device device) throws IOException, MalformedModelException {
        Map arguments = artifact.getArguments();
        Shape shape = new Shape(((List)arguments.get("imageShape")).stream().mapToLong(Double::longValue).toArray());
        ResNetV1.Builder blockBuilder = new ResNetV1.Builder().setNumLayers((int)((Double)arguments.get("numLayers")).doubleValue()).setOutSize((long)((Double)arguments.get("outSize")).doubleValue()).setImageShape(shape);
        if (arguments.containsKey("batchNormMomentum")) {
            float batchNormMomentum = (float)((Double)arguments.get("batchNormMomentum")).doubleValue();
            blockBuilder.optBatchNormMomemtum(batchNormMomentum);
        }
        Block block = blockBuilder.build();
        Model model = Model.newInstance((Device)device);
        model.setBlock(block);
        model.load(modelPath, artifact.getName());
        return model;
    }
}

