from register import *
import numpy as np
import tqdm 
import random
import time
import utils
import itertools

class MDP:

    """ The usual MDP object.

    Python wrapper around a native MDP handle created with ``libmdp.MDP_new``
    (``ct`` / ``libmdp`` come from ``register``). States are the integers
    ``0..n_states-1``, ``self.A[x]`` lists the actions of state ``x`` and
    ``self.Z`` is the set of valid state-action pairs ``(x, a)``.

    Optimal quantities (VI bias, gain, policy, Bellman gaps) are memoised and
    keyed on a model ID that every setter bumps, so they are only recomputed
    after the model changes or when a smaller ``eps`` is requested.
    """

    def __init__(self, n_states=None, n_actions=None, Z=None, prior=1.0):
        """
        Several modes of initialization:

        - Tabular MDP: specify n_states and n_actions as scalars;
        - Finite  MDP: specify n_states and n_actions, the second a vector;
        - Finite  MDP: specify state-action pairs collection Z (its states
          must be exactly 0..n-1).

        Rewards and kernels are then drawn at random with the given ``prior``
        (see ``randomize_rewards`` / ``randomize_kernels``).

        Raises:
            ValueError: if neither (n_states, n_actions) nor Z is given, if
                n_actions has an invalid shape/length, or if the states of Z
                are not 0..n-1.
        """

        # ID and memoisation of optimal quantities
        self._id       = 0
        self._gain     = None # optimal gain
        self._gaps     = None # Bellman gaps
        self._pi       = None # optimal policy
        self._eps      = 1e99 # accuracy of the memoised VI output
        self._mu       = None # invariant measure
        self._u        = None # bias
        self._u_vi     = None # bias array of value iteration
        self._vi_id    = -1   # model ID the memoised VI output belongs to
        self._gaps_id  = -1   # model ID the memoised gaps belong to
        self._gaps_eps = 1e99 # eps the memoised gaps were computed with

        # Initializing shape
        self.S, self.A, self.Z = self._build_structure(n_states, n_actions, Z)
        self.n_states  = len(self.S)
        self.n_actions = [len(self.A[x]) for x in self.S]

        # Initializing native (rust) data
        act_data   = (ct.c_size_t * self.n_states)(*self.n_actions)
        self._data = libmdp.MDP_new(self.n_states, act_data)

        # Initializing reward and transition structure
        self.randomize_rewards(prior)
        self.randomize_kernels(prior)

    @staticmethod
    def _build_structure(n_states, n_actions, Z):
        """ Return ``(S, A, Z)``: state list, per-state action lists, pair set.

        Raises:
            ValueError: on missing arguments, a badly shaped ``n_actions`` or
                states of ``Z`` that are not exactly ``0..n-1``.
        """
        if n_states is not None and n_actions is not None:
            counts = np.asarray(n_actions)
            if counts.ndim == 0:
                # A scalar: every state gets the same number of actions.
                counts = [int(counts)] * n_states
            elif counts.ndim == 1:
                if len(counts) != n_states:
                    raise ValueError(
                        f"'n_actions' has {len(counts)} entries, "
                        f"expected {n_states}")
            else:
                raise ValueError(f"Invalid 'n_actions': {n_actions}")
            states  = list(range(n_states))
            actions = [list(range(n_x)) for n_x in counts]
            pairs   = {(x, a) for x in states for a in actions[x]}
        elif Z is not None:
            pairs  = set(Z)
            states = sorted({x for x, _ in pairs})
            # States are used as indices (self.A[x], native arrays).
            if states != list(range(len(states))):
                raise ValueError(f"States of 'Z' must be 0..n-1, got {states}")
            actions = [[] for _ in states]
            for x, a in pairs:
                actions[x].append(a)
            # NOTE(review): actions are not checked to be 0..n_actions[x]-1;
            # confirm whether the native side requires that.
            for acts in actions:
                acts.sort()
        else:
            raise ValueError(
                "Specify either 'n_states' and 'n_actions', or 'Z'")
        return states, actions, pairs

    def __del__(self):
        """ Release the native handle (if one was ever allocated). """
        # _data is missing when __init__ failed before MDP_new was called.
        data = getattr(self, "_data", None)
        if data is not None:
            libmdp.MDP_free(data)
            self._data = None

    def __repr__(self):
        """ Multi-line summary: structure, rewards/kernels and optimal data. """
        tokens = [""]  # placeholder, replaced by the top rule at the end
        tokens.append("Structure:")
        tokens.append(f"| S: {self.n_states}")
        tokens.append(f"| A: {self.n_actions}")

        tokens.append("Reward and transitions:")
        rewards = self.rewards()
        kernels = self.kernels()
        for x in self.S:
            for a in self.A[x]:
                r = round(rewards[x,a], 2)
                k = np.round(np.array(kernels[x, a]), 3)
                tokens.append(f"| r({x},{a}): {r}, p(-|{x},{a}): {k}")

        tokens.append("Information:")
        u, g, pi = self.value_iteration(eps=1e-6)
        sp  = max(u) - min(u)
        mu  = self.invariant_measure(pi, eps=1e-6)
        h   = self._centered_bias(u, mu)
        tokens.append(f"| pi*   : {np.array(pi)}")
        tokens.append(f"| mu*   : {np.round(np.array(mu), 3)}")
        tokens.append(f"| g *   : {np.round(np.array(g), 3)}")
        tokens.append(f"| h *   : {np.round(np.array(h), 3)}")
        tokens.append(f"| hvi   : {np.round(np.array(u), 3)}")
        tokens.append(f"| sp(h*): {sp:.3f}")
        tokens.append(f"| gaps  : {self.gaps()}")

        max_len = max(len(token) for token in tokens)
        tokens[0] = "=" * max_len
        tokens.append("=" * max_len)
        return "\n".join(tokens)

    def update_id(self):
        """ Bump the model ID, invalidating every memoised quantity. """
        self._id += 1

    def get_id(self):
        """ Current model ID (changes whenever the model is modified). """
        return self._id

    def randomize_rewards(self, prior):
        """ Draw every mean reward independently from Beta(prior, prior). """
        for x, a in self.Z:
            libmdp.MDP_set_reward(self._data, x, a, np.random.beta(prior, prior))
        self.update_id()

    def randomize_kernels(self, prior):
        """ Draw every transition kernel from a symmetric Dirichlet(prior). """
        alpha = [prior] * self.n_states
        for x, a in self.Z:
            kernel = np.random.dirichlet(alpha)
            libmdp.MDP_set_kernel(self._data, x, a, self._c_doubles(kernel))
        self.update_id()

    def _c_doubles(self, values):
        """ Copy the first ``n_states`` entries of ``values`` into a C array. """
        buf = (ct.c_double * self.n_states)()
        for y in range(self.n_states):
            buf[y] = values[y]
        return buf

    ### Setters

    def set_reward(self, x, a, mean):
        """ Set the mean reward of the pair ``(x, a)``. """
        assert((x, a) in self.Z)
        libmdp.MDP_set_reward(self._data, x, a, mean)
        self.update_id()

    def set_kernel(self, x, a, kernel):
        """ Set ``p(.|x, a)``; ``kernel`` must have at least n_states entries. """
        assert((x, a) in self.Z)
        libmdp.MDP_set_kernel(self._data, x, a, self._c_doubles(kernel))
        self.update_id()

    def set_rewards(self, means):
        """ Set all mean rewards from a mapping ``(x, a) -> mean``. """
        for x, a in self.Z:
            self.set_reward(x, a, means[x, a])
        self.update_id()

    def set_kernels(self, kernels):
        """ Set all kernels from a mapping ``(x, a) -> kernel``. """
        for x, a in self.Z:
            self.set_kernel(x, a, kernels[x, a])
        self.update_id()

    ### Getters

    def reward(self, x, a):
        """ Mean reward of the pair ``(x, a)``. """
        r = ct.c_double()
        libmdp.MDP_get_reward_into(self._data, x, a, ct.byref(r))
        return r.value

    def rewards(self):
        """ Dict ``(x, a) -> mean reward`` over all pairs of Z. """
        return { (x, a): self.reward(x, a) for (x, a) in self.Z }

    def kernel(self, x, a):
        """ Transition kernel ``p(.|x, a)`` as a list of length n_states. """
        kernel_data = (ct.c_double * self.n_states)()
        libmdp.MDP_get_kernel_into(self._data, x, a, kernel_data)
        return [kernel_data[i] for i in range(self.n_states)]

    def kernels(self):
        """ Dict ``(x, a) -> kernel`` over all pairs of Z. """
        return { (x, a): self.kernel(x, a) for (x, a) in self.Z }

    ### Sampling

    def sample(self, x, a):
        """ Sample a transition; returns ``(reward, next_state)``. """
        y = ct.c_size_t()
        r = ct.c_double()
        libmdp.MDP_sample(self._data, x, a, ct.byref(r), ct.byref(y))
        return r.value, y.value

    def sample_path(self, x_init, pi_vec, length):
        """ Sample a path from a policy vector.

        ``pi_vec[x]`` is the action played in state ``x``. Returns a list of
        ``length`` tuples ``(x, a, reward, next_state)`` (with a tqdm bar).
        """
        path = []
        x    = x_init
        for _ in tqdm.tqdm(range(length)):
            a      = pi_vec[x]
            rew, y = self.sample(x, a)
            path.append((x, a, rew, y))
            x = y
        return path

    ### Info

    def gain(self, eps=1e-6):
        """ Returns the optimal gain (per-state list from VI). """
        _, g, _ = self.value_iteration(eps=eps)
        return g

    def gain_of(self, pi, eps=1e-6):
        """ Returns the gain of a given policy.

        Builds an auxiliary single-action MDP whose rewards and kernels are
        those of ``pi`` and returns its optimal gain. Note that constructing
        it first draws random rewards/kernels (consuming numpy's global RNG)
        before they are overwritten.
        """
        n = self.n_states
        pi_mdp = MDP(n_states=n, n_actions=[1]*n)
        for x in range(n):
            pi_mdp.set_reward(x, 0, self.reward(x, pi[x]))
            pi_mdp.set_kernel(x, 0, self.kernel(x, pi[x]))
        return pi_mdp.gain(eps)

    def bias(self, eps=1e-6):
        """ Returns the (non-centred) bias vector produced by VI. """
        u, _, _ = self.value_iteration(eps=eps)
        return u

    def invariant_measure(self, pi, eps=1e-6):
        """ Return the invariant measure of a policy in vector format """
        n_states = self.n_states
        pi_data = (ct.c_size_t * n_states)()
        mu_data = (ct.c_double * n_states)()
        for x in range(n_states):
            pi_data[x] = pi[x]
        libmdp.MDP_invariant_measure(self._data, pi_data, eps, mu_data)
        return [mu_data[x] for x in range(n_states)]

    @staticmethod
    def _centered_bias(u, mu):
        """ Return ``u - <u, mu>`` (zero mean under ``mu`` if it sums to 1). """
        dot = sum(u_x * mu_x for u_x, mu_x in zip(u, mu))
        return [u_x - dot for u_x in u]

    def value_iteration(self, u_init=None, eps=1e-6):
        """ Run Value Iteration (VI); returns ``(u, g, pi)`` as fresh lists.

        ``u`` is the bias produced by VI, ``g`` the per-state gain and ``pi``
        the policy returned by VI. ``u_init`` is currently ignored.
        """

        # Memorisation - Only run VI when accuracy is increased *or* if the
        # model has been modified (i.e., VI data is out-of-date).
        # Copies are returned so callers cannot corrupt the cache.
        if eps >= self._eps and self._vi_id == self.get_id():
            return list(self._u_vi), list(self._gain), list(self._pi)

        # Value Iteration Algorithm
        n_states = self.n_states
        u_data  = (ct.c_double * n_states)()
        g_data  = (ct.c_double * n_states)()
        pi_data = (ct.c_size_t * n_states)()
        libmdp.MDP_value_iteration(self._data, eps, u_data, g_data, pi_data)

        # Record the cache key only after VI produced data, so a failing
        # native call cannot leave a stale cache marked as valid.
        self._u_vi  = [u_data[i] for i in range(n_states)]
        self._gain  = [g_data[i] for i in range(n_states)]
        self._pi    = [pi_data[i] for i in range(n_states)]
        self._eps   = eps
        self._vi_id = self.get_id()

        return list(self._u_vi), list(self._gain), list(self._pi)

    def gaps(self, eps=1e-6):
        """ Compute the Bellman gaps (bias computed with VI).

        For each pair: ``max(0, g[x] + h[x] - r(x,a) - p(.|x,a).h - eps)``
        where ``h`` is the VI bias centred by the invariant measure of the VI
        policy. Returns a fresh dict ``(x, a) -> gap``.
        """

        # Memorisation - recompute when the model changed or a smaller eps is
        # requested, comparing against the eps *these gaps* were computed
        # with (not the VI cache's eps, which may have been refined since).
        if eps >= self._gaps_eps and self._gaps_id == self.get_id():
            return dict(self._gaps)

        # Computing gaps
        u, g, pi = self.value_iteration(eps=eps)
        mu = self.invariant_measure(pi, eps=eps)
        h  = self._centered_bias(u, mu)
        gaps = {}
        for x, a in self.Z:
            r, p = self.reward(x, a), self.kernel(x, a)
            q_xa = r + sum(p_y * h_y for p_y, h_y in zip(p, h))
            gaps[x, a] = max(0.0, g[x] + h[x] - q_xa - eps)

        self._gaps     = dict(gaps)
        self._gaps_eps = eps
        self._gaps_id  = self.get_id()
        return gaps

### Example code

if __name__ == "__main__":

    # Small random tabular MDP: 5 states, 2 actions each.
    n_states, n_actions = 5, 2
    horizon = 100000
    model = MDP(n_states=n_states, n_actions=n_actions)
    print(model)

    _, g_opt, pi_opt = model.value_iteration()

    # Roll out the VI policy from state 0 and compare the empirical average
    # reward with the gain reported by value iteration.
    path  = model.sample_path(model.S[0], pi_opt, length=horizon)
    g_emp = sum(step[2] for step in path) / len(path)

    print(pi_opt, g_emp, g_opt)
    print(model.gaps())

