import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import itertools
import utils
import numpy as np

import mdp

def iterate_algorithm(model, algorithm, history, episodes=None):
    """
    Run one interaction step of a reinforcement learning algorithm.

    The last entry of ``history`` is taken as the current state ``x``. The
    algorithm picks an action for it, the model samples a reward and a next
    state, and the algorithm is shown the resulting transition.

    Parameters
    ----------
    model
        Object exposing ``sample(x, a)`` returning a ``(reward, next_state)``
        pair.
    algorithm
        Object exposing ``play(x)`` and ``observe(x, a, r, y)``. When
        ``episodes`` is given, its ``k``, ``tk`` and ``pi`` attributes are
        read as well (presumably episode count, episode start time and
        current policy -- confirm against the algorithm implementations).
    history : list
        Mutated in place: the trailing state ``x`` is replaced by the tuple
        ``(x, a, r, y)`` and the next state ``y`` is appended, so the list
        always ends with the current state.
    episodes : list, optional
        When not ``None``, ``(algorithm.tk, algorithm.pi)`` is appended
        whenever the list holds fewer than ``algorithm.k`` entries.

    Returns
    -------
    None
    """

    x = history[-1]
    a = algorithm.play(x)
    r, y = model.sample(x, a)
    algorithm.observe(x, a, r, y)
    history[-1] = (x, a, r, y)
    history.append(y)

    # Identity test: `!= None` would dispatch to the container's own
    # comparison operator, which is wrong for array-like containers.
    if episodes is not None and len(episodes) < algorithm.k:
        episodes.append((algorithm.tk, algorithm.pi))

def parse_history(model, history):
    """ Turn a list of transitions into cumulative regret curves.

    Each entry of ``history`` is a tuple ``(x, a, r, y)``; the last field is
    ignored. Every returned curve starts at 0 and gains one point per entry.
    The returned dictionary has the following entries:

    regret         running sum of   gain[x] - r
    pseudo regret  running sum of   gain[x] - rewards[x, a]
    gap regret     running sum of   gaps[x, a]

    where ``gain``, ``rewards`` and ``gaps`` are obtained from
    ``model.gain()``, ``model.rewards()`` and ``model.gaps()``.
    """

    gain = model.gain()
    rewards = model.rewards()
    gaps = model.gaps()

    regret = [0]
    pseudo_regret = [0]
    gap_regret = [0]
    for state, action, reward, _next_state in history:
        mean_reward = rewards[state, action]
        state_gain = gain[state]
        regret.append(state_gain - reward + regret[-1])
        pseudo_regret.append(state_gain - mean_reward + pseudo_regret[-1])
        gap_regret.append(gaps[state, action] + gap_regret[-1])

    return {
        "regret": regret,
        "pseudo regret": pseudo_regret,
        "gap regret": gap_regret,
    }

################################################################################

class Agent:
    """ Base class for learning agents.

    The constructor records the state-action structure of the model;
    ``reset``, ``name``, ``observe`` and ``play`` are placeholders that
    delegate to ``utils.unimplemented`` (presumably signalling a missing
    override -- confirm in ``utils``) and are meant to be overridden.

    Attributes set by ``__init__``:
        Z          set of (state, action) pairs copied from ``model.Z``
        S          sorted list of the distinct states appearing in ``Z``
        A          ``A[s]`` is the list of actions paired with state ``s``
        n_states   number of distinct states
        n_actions  ``n_actions[i]`` is the number of actions of ``S[i]``
    """

    def __init__(self, model):
        """ The model is given as a parameter but stays unknown to the agent.
        It is mostly for convenience - e.g. to ease the writing of the code
        when rewards are known/unknown.

        Only ``model.Z`` is read here. States are used as list indices into
        ``A``, so they are expected to be the integers 0..n_states-1.
        """

        # Initializing shape
        Z = model.Z
        self.Z = set(Z.copy())
        # Sorted rather than raw set order: set iteration order is
        # arbitrary, and S / n_actions must line up with the state-indexed
        # list A.
        self.S = sorted({s for s, _ in self.Z})
        self.A = [[] for _ in self.S]
        for s, a in self.Z:
            self.A[s].append(a)

        self.n_states = len(self.S)
        self.n_actions = [len(self.A[x]) for x in self.S]

    def reset(self, model):
        """ Placeholder; subclasses are expected to override. """
        utils.unimplemented(self.reset)

    def name(self):
        """ Placeholder; subclasses are expected to override. """
        utils.unimplemented(self.name)

    def observe(self, x, a, r, y):
        """ Placeholder for receiving a transition (x, a, r, y). """
        utils.unimplemented(self.observe)

    def play(self, x):
        """ Pick an action """
        utils.unimplemented(self.play)

################################################################################

class RandomLearner(Agent):
    """ Agent that ignores all feedback and draws its actions at random. """

    def name(self):
        """ Return the display name of this learner. """
        return "RandomLearner"

    def play(self, x):
        """ Draw one of the actions listed for state ``x`` with
        ``np.random.choice`` (uniform by default). """
        candidates = self.A[x]
        return np.random.choice(candidates)

    def observe(self, x, a, r, y):
        """ Discard the transition; nothing is learned. """
        return None

################################################################################

if __name__ == "__main__":
    # Demo: run a RandomLearner on a small MDP and plot its regret curves.
    # Plotting / progress-bar packages are only needed when run as a script.
    from matplotlib import pyplot as plt
    import tqdm

    n_states, n_actions = 2, 2
    T = 100000

    model = mdp.MDP(n_states=n_states, n_actions=n_actions)
    print(model)
    print(model.gaps())

    learner = RandomLearner(model)
    history = [0]  # trajectory starts in state 0
    for _ in tqdm.tqdm(range(T)):
        iterate_algorithm(model, learner, history)
    # The trailing element is the bare final state, not a transition tuple.
    history.pop()

    curves = parse_history(model, history)
    steps = list(range(T + 1))

    # Dashed: pseudo regret (blue) and gap regret (red); solid black: regret.
    plt.plot(steps, curves["pseudo regret"], linestyle="dashed", color="blue")
    plt.plot(steps, curves["gap regret"], linestyle="dashed", color="red")
    plt.plot(
        steps,
        curves["regret"],
        linestyle="solid",
        color="black",
        label=learner.name(),
    )
    plt.legend()
    plt.show()
