/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.mdp.ale;

import java.beans.ConstructorProperties;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.ale;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MDP} adapter around the Arcade Learning Environment's {@code ale.ALEInterface}.
 *
 * <p>Each instance creates its own native {@code ALEInterface}, applies its {@link Configuration}
 * and loads {@link #getRomFile() the ROM} (see {@link #setupGame()}). Observations are RGB screen
 * captures wrapped in {@link GameScreen}. Agent actions are indices into the ALE action set captured
 * at construction time: the minimal set or the legal set, depending on
 * {@link Configuration#isMinimalActionSet()}. Rewards returned by {@link #step(Integer)} are
 * multiplied by the scale factor set via {@link #setScaleFactor(double)} (default {@code 1.0}).
 *
 * <p>Call {@link #close()} to release the native emulator.
 */
public class ALEMDP
implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
    private static final Logger log = LoggerFactory.getLogger(ALEMDP.class);
    /** Native emulator handle; deallocated by {@link #close()}. */
    protected ale.ALEInterface ale;
    /** ALE action codes; agent action {@code i} is sent to the emulator as {@code actions[i]}. */
    protected final int[] actions;
    protected final DiscreteSpace discreteSpace;
    protected final ObservationSpace<GameScreen> observationSpace;
    protected final String romFile;
    protected final boolean render;
    protected final Configuration configuration;
    /** Multiplier applied to every raw reward returned by {@code ALEInterface.act}. */
    protected double scaleFactor = 1.0;
    /**
     * Reusable RGB capture buffer. It is safe to reuse because {@link GameScreen} copies its
     * contents into a fresh array.
     */
    private final byte[] screenBuffer;

    /**
     * Creates a non-rendering MDP for {@code romFile} using the default configuration.
     *
     * @param romFile ROM path passed to {@code ALEInterface.loadROM}
     */
    public ALEMDP(String romFile) {
        this(romFile, false);
    }

    /**
     * Creates an MDP using the default configuration.
     *
     * <p>The defaults are: random seed 123, repeat-action probability 0, {@code max_num_frames} 0,
     * {@code max_num_frames_per_episode} 0, and the minimal action set.
     *
     * @param romFile ROM path passed to {@code ALEInterface.loadROM}
     * @param render  value used for the ALE {@code display_screen} and {@code sound} settings
     */
    public ALEMDP(String romFile, boolean render) {
        this(romFile, render, new Configuration(123, 0.0f, 0, 0, true));
    }

    /**
     * Creates an MDP, configures a new emulator, loads the ROM, and captures the action set
     * and observation shape.
     *
     * @param romFile       ROM path passed to {@code ALEInterface.loadROM}
     * @param render        value used for the ALE {@code display_screen} and {@code sound} settings
     * @param configuration emulator settings applied by {@link #setupGame()}
     */
    public ALEMDP(String romFile, boolean render, Configuration configuration) {
        this.romFile = romFile;
        this.configuration = configuration;
        this.render = render;
        this.ale = new ale.ALEInterface();
        // Apply settings and load the ROM before querying the action set and screen size below.
        this.setupGame();
        IntPointer actionSet = this.getConfiguration().isMinimalActionSet()
                ? this.ale.getMinimalActionSet()
                : this.ale.getLegalActionSet();
        this.actions = new int[(int) actionSet.limit()];
        actionSet.get(this.actions);
        this.discreteSpace = new DiscreteSpace(this.actions.length);
        // Observation shape is {height, width, 3}: one byte per RGB channel per pixel.
        int[] shape = new int[] {(int) this.ale.getScreen().height(), (int) this.ale.getScreen().width(), 3};
        this.observationSpace = new ArrayObservationSpace<>(shape);
        this.screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
    }

    /**
     * Pushes the configuration values into the emulator settings and loads {@link #romFile}.
     *
     * <p>{@code display_screen} and {@code sound} both follow the {@link #render} flag.
     */
    public void setupGame() {
        Configuration conf = this.getConfiguration();
        this.ale.setInt("random_seed", conf.randomSeed);
        this.ale.setFloat("repeat_action_probability", conf.repeatActionProbability);
        this.ale.setBool("display_screen", this.render);
        this.ale.setBool("sound", this.render);
        this.ale.setInt("max_num_frames", conf.maxNumFrames);
        this.ale.setInt("max_num_frames_per_episode", conf.maxNumFramesPerEpisode);
        this.ale.loadROM(this.romFile);
    }

    /**
     * @return {@code ALEInterface.game_over()} for the current episode
     */
    public boolean isDone() {
        return this.ale.game_over();
    }

    /**
     * Resets the game and returns the first screen of the new episode.
     *
     * @return the RGB screen captured after {@code reset_game()}
     */
    public GameScreen reset() {
        this.ale.reset_game();
        this.ale.getScreenRGB(this.screenBuffer);
        return new GameScreen(this.screenBuffer);
    }

    /** Deallocates the native emulator. */
    public void close() {
        this.ale.deallocate();
    }

    /**
     * Performs one agent action.
     *
     * @param action index into the action set captured at construction
     *               (must be in {@code [0, actions.length)})
     * @return the next screen, the scaled reward, and whether the game is over
     */
    public StepReply<GameScreen> step(Integer action) {
        double reward = this.ale.act(this.actions[action]) * this.scaleFactor;
        // Logged on every emulator step, so keep it at DEBUG and let SLF4J format lazily.
        log.debug("{} {} {} ", this.ale.getEpisodeFrameNumber(), reward, action);
        this.ale.getScreenRGB(this.screenBuffer);
        return new StepReply<>(new GameScreen(this.screenBuffer), reward, this.ale.game_over(), null);
    }

    public ObservationSpace<GameScreen> getObservationSpace() {
        return this.observationSpace;
    }

    public DiscreteSpace getActionSpace() {
        return this.discreteSpace;
    }

    /**
     * Creates an independent MDP for the same ROM, render flag and configuration.
     *
     * <p>The copy also inherits this instance's reward scale factor.
     *
     * @return a new MDP backed by its own native emulator
     */
    public ALEMDP newInstance() {
        ALEMDP copy = new ALEMDP(this.romFile, this.render, this.configuration);
        // Preserve reward scaling; without this the copy would silently revert to 1.0.
        copy.setScaleFactor(this.scaleFactor);
        return copy;
    }

    public String getRomFile() {
        return this.romFile;
    }

    public boolean isRender() {
        return this.render;
    }

    public Configuration getConfiguration() {
        return this.configuration;
    }

    /**
     * Sets the multiplier applied to every reward returned by {@link #step(Integer)}.
     *
     * @param scaleFactor reward multiplier
     */
    public void setScaleFactor(double scaleFactor) {
        this.scaleFactor = scaleFactor;
    }

    /**
     * An RGB screen capture converted to doubles in {@code [0, 1]}.
     *
     * <p>Each byte is read as unsigned (0-255) and divided by 255.
     */
    public static class GameScreen
    implements Encodable {
        double[] array;

        /**
         * Copies and normalizes {@code screen}; later changes to {@code screen} do not affect
         * this object.
         *
         * @param screen raw screen bytes
         */
        public GameScreen(byte[] screen) {
            this.array = new double[screen.length];
            for (int i = 0; i < screen.length; ++i) {
                this.array[i] = (double)(screen[i] & 0xFF) / 255.0;
            }
        }

        public double[] toArray() {
            return this.array;
        }
    }

    /** Immutable emulator settings applied by {@link ALEMDP#setupGame()}. */
    public static final class Configuration {
        /** Value for the ALE {@code random_seed} setting. */
        private final int randomSeed;
        /** Value for the ALE {@code repeat_action_probability} setting. */
        private final float repeatActionProbability;
        /** Value for the ALE {@code max_num_frames} setting. */
        private final int maxNumFrames;
        /** Value for the ALE {@code max_num_frames_per_episode} setting. */
        private final int maxNumFramesPerEpisode;
        /** {@code true} to use ALE's minimal action set, {@code false} for the legal action set. */
        private final boolean minimalActionSet;

        @ConstructorProperties(value={"randomSeed", "repeatActionProbability", "maxNumFrames", "maxNumFramesPerEpisode", "minimalActionSet"})
        public Configuration(int randomSeed, float repeatActionProbability, int maxNumFrames, int maxNumFramesPerEpisode, boolean minimalActionSet) {
            this.randomSeed = randomSeed;
            this.repeatActionProbability = repeatActionProbability;
            this.maxNumFrames = maxNumFrames;
            this.maxNumFramesPerEpisode = maxNumFramesPerEpisode;
            this.minimalActionSet = minimalActionSet;
        }

        public int getRandomSeed() {
            return this.randomSeed;
        }

        public float getRepeatActionProbability() {
            return this.repeatActionProbability;
        }

        public int getMaxNumFrames() {
            return this.maxNumFrames;
        }

        public int getMaxNumFramesPerEpisode() {
            return this.maxNumFramesPerEpisode;
        }

        public boolean isMinimalActionSet() {
            return this.minimalActionSet;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) o;
            return this.getRandomSeed() == other.getRandomSeed()
                    // Float.compare gives consistent equality for NaN and -0.0f, matching hashCode.
                    && Float.compare(this.getRepeatActionProbability(), other.getRepeatActionProbability()) == 0
                    && this.getMaxNumFrames() == other.getMaxNumFrames()
                    && this.getMaxNumFramesPerEpisode() == other.getMaxNumFramesPerEpisode()
                    && this.isMinimalActionSet() == other.isMinimalActionSet();
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + this.getRandomSeed();
            result = result * prime + Float.floatToIntBits(this.getRepeatActionProbability());
            result = result * prime + this.getMaxNumFrames();
            result = result * prime + this.getMaxNumFramesPerEpisode();
            result = result * prime + (this.isMinimalActionSet() ? 79 : 97);
            return result;
        }

        @Override
        public String toString() {
            return "ALEMDP.Configuration(randomSeed=" + this.getRandomSeed() + ", repeatActionProbability=" + this.getRepeatActionProbability() + ", maxNumFrames=" + this.getMaxNumFrames() + ", maxNumFramesPerEpisode=" + this.getMaxNumFramesPerEpisode() + ", minimalActionSet=" + this.isMinimalActionSet() + ")";
        }
    }
}

