/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.cntk.modelapplier;

import adams.core.License;
import adams.core.Utils;
import adams.core.annotation.MixedCopyright;
import adams.core.logging.Logger;
import adams.core.logging.LoggingHelper;
import adams.data.image.AbstractImageContainer;
import adams.ml.cntk.modelapplier.AbstractModelApplier;
import adams.ml.cntk.predictionpostprocessor.PassThrough;
import adams.ml.cntk.predictionpostprocessor.PredictionPostProcessor;
import com.microsoft.CNTK.NDShape;
import com.microsoft.CNTK.Variable;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.util.logging.Level;

/**
 * Model applier that feeds images into a CNTK model and returns the model's scores,
 * passed through the configured {@link PredictionPostProcessor}.
 * <p>
 * Each image is scaled to the width/height declared by the model's first argument
 * (dimensions 0 and 1 of its shape). It is then converted into a planar float array:
 * all values of channel 0 first, then channel 1, and so on. Within a channel, pixels
 * are stored row by row. Channel 0 holds blue, channel 1 green and every further
 * channel red, following the CNTK Java evaluation example.
 */
@MixedCopyright(author="CNTK", copyright="Microsoft", license=License.MIT, url="https://github.com/Microsoft/CNTK/blob/v2.0/Tests/EndToEndTests/EvalClientTests/JavaEvalTest/src/Main.java", note="Original code based on CNTK example")
public class DefaultImageApplier
extends AbstractModelApplier<AbstractImageContainer, float[]> {
    private static final long serialVersionUID = 7933924670965842681L;

    /** The post-processor applied to the raw predictions. */
    protected PredictionPostProcessor m_PostProcessor;

    /**
     * Returns a description of this applier for display in the GUI / help.
     *
     * @return the description, including the model loader's automatic-order info
     */
    public String globalInfo() {
        return "Applies the model to images and returns the score. Images get scaled according to the model inputs.\n" + this.m_ModelLoader.automaticOrderInfo();
    }

    /**
     * Adds the "post-processor" option on top of the options of the superclass.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("post-processor", "postProcessor", (Object)this.getDefaultPostProcessor());
    }

    /**
     * Returns the default post-processor.
     *
     * @return a {@link PassThrough} instance
     */
    protected PredictionPostProcessor getDefaultPostProcessor() {
        return new PassThrough();
    }

    /**
     * Sets the post-processor to apply to the predictions and resets the applier.
     *
     * @param value the post-processor
     */
    public void setPostProcessor(PredictionPostProcessor value) {
        this.m_PostProcessor = value;
        this.reset();
    }

    /**
     * Returns the post-processor applied to the predictions.
     *
     * @return the post-processor
     */
    public PredictionPostProcessor getPostProcessor() {
        return this.m_PostProcessor;
    }

    /**
     * Returns the tip text for the post-processor option.
     *
     * @return the tip text
     */
    public String postProcessorTipText() {
        return "The post-processor to apply to the predictions.";
    }

    /**
     * Returns the type of input this applier accepts.
     *
     * @return {@link AbstractImageContainer}
     */
    @Override
    public Class accepts() {
        return AbstractImageContainer.class;
    }

    /**
     * Returns the type of output this applier generates.
     *
     * @return {@code float[]}
     */
    @Override
    public Class generates() {
        return float[].class;
    }

    /**
     * Scales the image to the model's input size, converts it into a planar
     * (channel, row, column) float array and runs the prediction.
     *
     * @param input the image to apply the model to
     * @return the post-processed predictions, or an empty array if the model
     *         input shape is unsuitable or the prediction fails (the failure
     *         is logged at SEVERE level)
     */
    @Override
    protected float[] doApplyModel(AbstractImageContainer input) {
        Variable inputVar = (Variable)this.m_Wrapper.getModel().getArguments().get(0);
        NDShape inputShape = inputVar.getShape();
        long[] dims = inputShape.getDimensions();
        // width, height and channels are read from the first three dimensions
        if (dims.length < 3) {
            this.getLogger().severe("Model input must have at least 3 dimensions (width, height, channels), but has: " + dims.length);
            return new float[0];
        }
        int imageWidth = (int)dims[0];
        int imageHeight = (int)dims[1];
        int imageChannels = (int)dims[2];
        int imageSize = (int)inputShape.getTotalSize();
        if (this.isLoggingEnabled()) {
            this.getLogger().info("imageWidth=" + imageWidth);
            this.getLogger().info("imageHeight=" + imageHeight);
            this.getLogger().info("imageChannels=" + imageChannels);
            this.getLogger().info("imageSize=" + imageSize);
        }
        BufferedImage scaled = scaleImage(input.toBufferedImage(), imageWidth, imageHeight);
        float[] resizedCHW = toCHW(scaled, imageChannels, imageSize);
        if (LoggingHelper.isAtLeast((Logger)this.getLogger(), Level.FINE)) {
            this.getLogger().fine("resizedCHW=" + Utils.arrayToString((Object)resizedCHW));
        }
        try {
            float[] preds = this.m_Wrapper.predict(resizedCHW);
            return this.m_PostProcessor.postProcessPrediction(preds);
        }
        catch (Exception e) {
            this.getLogger().log(Level.SEVERE, "Failed to make prediction!", e);
            return new float[0];
        }
    }

    /**
     * Scales the image to the given size and renders it into a new RGB image.
     *
     * @param image the image to scale
     * @param width the target width
     * @param height the target height
     * @return the scaled image, exactly {@code width} x {@code height} pixels
     */
    private static BufferedImage scaleImage(BufferedImage image, int width, int height) {
        Image resized = image.getScaledInstance(width, height, Image.SCALE_DEFAULT);
        // Size the buffer from the requested dimensions: Image.getWidth/getHeight(null)
        // may return -1 while the scaled image's size is still unknown.
        BufferedImage result = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics g = result.getGraphics();
        try {
            g.drawImage(resized, 0, 0, null);
        }
        finally {
            // release the graphics context, otherwise it is leaked on every prediction
            g.dispose();
        }
        return result;
    }

    /**
     * Converts an RGB image into a planar float array: all pixels of channel 0,
     * then channel 1, and so on. Within a channel, pixels are stored row by row.
     * Channel 0 holds blue, channel 1 green and every further channel red.
     *
     * @param image the image to convert
     * @param channels the number of channels to produce
     * @param size the length of the returned array
     * @return the planar pixel values in the range 0-255
     */
    private static float[] toCHW(BufferedImage image, int channels, int size) {
        int width = image.getWidth();
        int height = image.getHeight();
        int plane = width * height;
        // read every pixel once; pixel (x, y) ends up at index y * width + x
        int[] rgb = image.getRGB(0, 0, width, height, null, 0, width);
        float[] result = new float[size];
        for (int c = 0; c < channels; ++c) {
            // packed ARGB: blue in bits 0-7, green in 8-15, red in 16-23
            int shift = c == 0 ? 0 : (c == 1 ? 8 : 16);
            int offset = c * plane;
            for (int p = 0; p < plane; ++p) {
                result[offset + p] = (float)((rgb[p] >> shift) & 0xFF);
            }
        }
        return result;
    }
}

