import collections

import numpy as np
from pycolab import ascii_art
from pycolab import things as plab_things
from pycolab.prefab_parts import sprites as prefab_sprites

# Action ids understood by PlayerSprite.
# The two negative ids (quit and delay) are not allowed for the agent; the
# four movement ids are consecutive integers starting at 0.
ACTION_QUIT, ACTION_DELAY = -2, -1
ACTION_NORTH, ACTION_SOUTH, ACTION_WEST, ACTION_EAST = range(4)


# Name -> ASCII character used for each placeable object in the maze art.
# Insertion order is preserved.
PLACEABLES = collections.OrderedDict([
    ('player', '+'),
    ('key', '0'),
])


def generate_maze(size):
    """Build the ASCII art for a square, wall-enclosed maze.

    The interior is empty except for the key ('0') and the player ('+').
    The key sits on the first interior row, in the second-to-last interior
    column. The player starts in the bottom-left interior cell.

    Args:
      size: side length of the maze in characters, walls included.

    Returns:
      A list of `size` strings, each `size` characters long.

    Raises:
      ValueError: if `size` is smaller than 5.
    """
    # An explicit check (rather than `assert`) so validation survives
    # `python -O`; too-small sizes would otherwise yield ragged art.
    if size < 5:
        raise ValueError('size must be at least 5, got {!r}.'.format(size))
    wall_row = '#' * size
    key_row = '#' + ' ' * (size - 4) + '0 #'
    empty_row = '#' + ' ' * (size - 2) + '#'
    player_row = '#+' + ' ' * (size - 3) + '#'
    return [wall_row, key_row] + [empty_row] * (size - 4) + [player_row, wall_row]


class KeyDrape(plab_things.Drape):
    """Drape for the key.

    When the player's cell is covered by the key, the key is picked up: the
    plot receives `pickup_reward` once and the key's cell is cleared from the
    curtain. Without clearing it, standing on or revisiting the key would pay
    the reward again on every such frame.
    """

    def __init__(self, curtain, character, pickup_reward):
        super(KeyDrape, self).__init__(curtain, character)
        # Reward added to the plot on the frame the key is picked up.
        self._pickup_reward = pickup_reward

    def update(self, actions, board, layers, backdrop, things, the_plot):
        # make_game schedules the key after the player, so this is the
        # player's position after its move on this frame.
        player_position = things[PLACEABLES['player']].position
        if self.curtain[player_position]:
            the_plot.add_reward(self._pickup_reward)
            # Consume the key so the reward is granted only once.
            self.curtain[player_position] = False


class TimerSprite(plab_things.Sprite):
    """Sprite for the timer.

    The timer is in charge of stopping the game: once the plot's frame
    counter reaches `max_frames`, it terminates the episode. The sprite is
    invisible. Timer sprite should be placed last in the update order to
    make sure everything is updated before the chapter terminates.

    Raises (from __init__):
      ValueError: if `max_frames` is not an integer (a Python int or a
        NumPy integer scalar).
    """

    def __init__(self, corner, position, character, max_frames):
        super(TimerSprite, self).__init__(corner, position, character)
        # NumPy integer scalars (e.g. from array arithmetic) are accepted too,
        # and normalised to a plain int.
        if not isinstance(max_frames, (int, np.integer)):
            raise ValueError('max_frames must be of type integer.')
        self._max_frames = int(max_frames)
        # The timer is never drawn on the board.
        self._visible = False

    def update(self, actions, board, layers, backdrop, things, the_plot):
        if the_plot.frame >= self._max_frames:
            the_plot.terminate_episode()


class PlayerSprite(prefab_sprites.MazeWalker):
    """Sprite for the actor.

    Moves one cell per movement action (north/south/west/east), cannot leave
    the board or enter cells holding an `impassable` character, and is
    charged a small penalty whenever a move is blocked.
    """

    def __init__(self, corner, position, character, impassable='#'):
        super(PlayerSprite, self).__init__(
            corner, position, character, impassable=impassable,
            confined_to_board=True)

    def update(self, actions, board, layers, backdrop, things, the_plot):
        """Apply one action to the player.

        Args:
          actions: one of the ACTION_* ids, or None when no action is given.

        Raises:
          KeyError: if `actions` is not a recognised action id.
        """
        if actions == ACTION_QUIT:
            the_plot.next_chapter = None
            the_plot.terminate_episode()
            # Quitting is not a move; previously this fell through to the
            # KeyError below.
            return
        if actions is None or actions == ACTION_DELAY:
            # No action, or an explicit wait: the player stays put.
            return

        if actions == ACTION_WEST:
            status = self._west(board, the_plot)
        elif actions == ACTION_EAST:
            status = self._east(board, the_plot)
        elif actions == ACTION_NORTH:
            status = self._north(board, the_plot)
        elif actions == ACTION_SOUTH:
            status = self._south(board, the_plot)
        else:
            raise KeyError('Unknown action: {!r}'.format(actions))
        # MazeWalker move helpers return None on success and a non-None
        # obstruction description when the move is blocked.
        if status:
            # A small penalty for bumping into an obstacle.
            the_plot.add_reward(-0.01)


def make_game(size, pickup_reward=1):
    """Factory method for generating a new game.

    Args:
      size: side length of the square maze, walls included; forwarded to
        `generate_maze`, which rejects sizes below 5.
      pickup_reward: value passed to `KeyDrape` as its `pickup_reward`.
        Defaults to 1, the value previously hard-coded here.

    Returns:
      The game object built by `ascii_art.ascii_art_to_game`.
    """
    art = generate_maze(size)
    player = PLACEABLES['player']
    key = PLACEABLES['key']
    return ascii_art.ascii_art_to_game(
        art=art,
        what_lies_beneath=' ',
        sprites={
            player: PlayerSprite,
        },
        drapes={
            key: ascii_art.Partial(
                KeyDrape,
                pickup_reward=pickup_reward),
        },
        # The player moves first so the key sees its post-move position.
        update_schedule=[
            player,
            key,
        ],
        # The player is listed after the key so it is drawn on top of it.
        z_order=[
            key,
            player,
        ],
        occlusion_in_layers=False,
    )


if __name__ == '__main__':
    # Print a sample 9x9 maze for quick visual inspection.
    for row in generate_maze(9):
        print(row)
