/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.zoo.cv.classification;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.ImageClassificationTranslator;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.nn.Block;
import ai.djl.repository.Anchor;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.zoo.cv.classification.Mlp;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;

/**
 * Model loader for the {@code mlp} image-classification artifact.
 *
 * <p>The loader is registered under the {@code MRL.Model.CV.IMAGE_CLASSIFICATION} anchor with
 * group id {@code ai.djl.zoo}, artifact id {@code mlp} and version {@code 0.0.1}. It reads the
 * optional {@code width}, {@code height} and {@code flag} entries from the artifact's argument
 * map to configure both the pre-processing pipeline and the {@link Mlp} block.
 */
public class MlpModelLoader
extends BaseModelLoader<BufferedImage, Classifications> {
    private static final Anchor BASE_ANCHOR = MRL.Model.CV.IMAGE_CLASSIFICATION;
    private static final String GROUP_ID = "ai.djl.zoo";
    private static final String ARTIFACT_ID = "mlp";
    private static final String VERSION = "0.0.1";

    /** Width and height used when the artifact does not specify them. */
    private static final int DEFAULT_IMAGE_SIZE = 28;

    /**
     * Creates a loader bound to the given repository.
     *
     * @param repository the repository passed on to {@link BaseModelLoader}
     */
    public MlpModelLoader(Repository repository) {
        super(repository, new MRL(BASE_ANCHOR, GROUP_ID, ARTIFACT_ID), VERSION);
    }

    /**
     * Builds the image-classification translator for the given artifact.
     *
     * <p>The pipeline applies a center crop, a resize to the artifact's {@code width} x
     * {@code height} (default 28 x 28) and a tensor conversion, in that order. The image flag
     * comes from the {@code flag} argument and defaults to {@code COLOR}; class names are read
     * from the {@code synset.txt} artifact file.
     *
     * @param artifact the artifact whose arguments configure the translator
     * @return the configured translator
     * @throws IllegalArgumentException if {@code width}/{@code height} is not numeric, or if
     *     {@code flag} does not name a {@code NDImageUtils.Flag} constant
     */
    public Translator<BufferedImage, Classifications> getTranslator(Artifact artifact) {
        Map<?, ?> arguments = artifact.getArguments();
        int width = intArgument(arguments, "width", DEFAULT_IMAGE_SIZE);
        int height = intArgument(arguments, "height", DEFAULT_IMAGE_SIZE);
        String flag = stringArgument(arguments, "flag", NDImageUtils.Flag.COLOR.name());

        Pipeline pipeline = new Pipeline();
        pipeline.add(new CenterCrop())
                .add(new Resize(width, height))
                .add(new ToTensor());

        return new ImageClassificationTranslator.Builder()
                .optFlag(NDImageUtils.Flag.valueOf(flag))
                .setPipeline(pipeline)
                .setSynsetArtifactName("synset.txt")
                .build();
    }

    /**
     * Creates a model on the given device, installs an {@link Mlp} block sized from the
     * artifact's {@code width} and {@code height} arguments (default 28 x 28) and loads the
     * parameters found under {@code modelPath} using the artifact name.
     *
     * @param artifact the artifact describing the model
     * @param modelPath the location of the model files
     * @param device the device the model is created on
     * @return the loaded model
     * @throws IOException if the model files cannot be read
     * @throws MalformedModelException if the model files are not valid
     * @throws IllegalArgumentException if {@code width}/{@code height} is not numeric
     */
    protected Model loadModel(Artifact artifact, Path modelPath, Device device) throws IOException, MalformedModelException {
        Map<?, ?> arguments = artifact.getArguments();
        int width = intArgument(arguments, "width", DEFAULT_IMAGE_SIZE);
        int height = intArgument(arguments, "height", DEFAULT_IMAGE_SIZE);

        Model model = Model.newInstance(device);
        model.setBlock(new Mlp(width, height));
        model.load(modelPath, artifact.getName());
        return model;
    }

    /**
     * Reads a numeric argument as an {@code int}, truncating any fractional part.
     *
     * @param arguments the artifact argument map, may be {@code null}
     * @param key the argument name
     * @param defaultValue the value returned when the argument is absent or {@code null}
     * @return the argument value, or {@code defaultValue}
     * @throws IllegalArgumentException if the argument is present but not a {@link Number}
     */
    private static int intArgument(Map<?, ?> arguments, String key, int defaultValue) {
        Object value = arguments == null ? null : arguments.get(key);
        if (value == null) {
            return defaultValue;
        }
        if (value instanceof Number) {
            // Accept any numeric type, not only Double, so integer-valued metadata also works.
            return ((Number) value).intValue();
        }
        throw new IllegalArgumentException(
                "Argument '" + key + "' must be numeric but was: " + value);
    }

    /**
     * Reads an argument as a string.
     *
     * @param arguments the artifact argument map, may be {@code null}
     * @param key the argument name
     * @param defaultValue the value returned when the argument is absent or {@code null}
     * @return the argument's string form, or {@code defaultValue}
     */
    private static String stringArgument(Map<?, ?> arguments, String key, String defaultValue) {
        Object value = arguments == null ? null : arguments.get(key);
        return value == null ? defaultValue : value.toString();
    }
}

