import numpy as np
import matplotlib.pyplot as plt
import os
import pdb
import random
import time
import types

# We can use direct parameterization xi[s,a,s']=p(s'|s,a) and theta[s,a]=pi(a|s). 


# Default values: sizes of the state and action spaces that env_setup falls
# back to when no explicit state_space / action_space is supplied.
num_states_default = 5
num_actions_default = 3


def set_hyps(a,a_default):
    """Return ``a``, falling back to ``a_default`` only when ``a`` is ``None``.

    Falsy values other than ``None`` (0, empty containers, ...) are returned as-is.
    """
    return a_default if a is None else a

def set_seed(seed=1):
    """Seed NumPy's global RNG and the stdlib ``random`` module with ``seed``.

    Passing ``seed=None`` leaves both generators untouched.
    """
    if seed is None:
        return
    # NumPy first, then the stdlib generator.
    for seeder in (np.random.seed, random.seed):
        seeder(seed)

def env_setup(seed_init=1,state_space=None,action_space=None,rho=None,Psi=None,xi0=None,xi_radius=0.01,cost=None,gamma=0.95):
    """Build the dictionary describing a robust MDP with an L2 ambiguity set.

    The ambiguity set contains all kernels P(s'|s,a) = <Psi[:,a,s'], xi[s,:]> with
        ||xi[s,:] - xi0[s,:]||_2 <= xi_radius[s]   for every state s, and
        sum_{s'} <Psi[:,a,s'], xi[s,:]> = 1.
    If ``Psi`` is None (tabular case) the kernel is parameterized directly,
    xi[s,a,s'] = P(s'|s,a).

    Parameters
    ----------
    seed_init : seed passed to ``set_seed`` before any random default is drawn.
    state_space, action_space : sized collections; default to
        ``range(num_states_default)`` / ``range(num_actions_default)``.
    rho : initial-state weights with ``num_states`` entries (array-like), or None
        for a random default. Absolute values are taken and normalized to sum to 1.
    Psi : None, or np.ndarray of shape (dp, num_actions, num_states).
    xi0 : nominal parameter. Tabular case: shape (num_states, num_actions, num_states),
        made non-negative and normalized over s'. Psi case: shape (num_states, dp);
        must be non-negative and satisfy the normalization constraint above.
    xi_radius : scalar (broadcast to all states) or np.ndarray with ``num_states`` entries.
    cost : None (uniform random in [0,1)) or np.ndarray of shape
        (num_states, num_actions, num_states).
    gamma : discount factor, stored unchanged.

    Returns
    -------
    dict with keys 'seed_init', 'state_space', 'action_space', 'num_states',
    'num_actions', 'pi' (uniform policy), 'rho', 'xi_radius', 'xi0', 'xi',
    'gamma', 'cost', and in the Psi case also 'dp', 'Psi', 'Psi_proj'.

    Raises
    ------
    AssertionError if an argument has the wrong type or shape, or if the
    Psi-case constraints on ``Psi`` / ``xi0`` are violated.
    """
    env_dict = {}
    set_seed(seed_init)
    env_dict['seed_init'] = seed_init

    env_dict['state_space'] = set_hyps(a=state_space,a_default=range(num_states_default))
    env_dict['action_space'] = set_hyps(a=action_space,a_default=range(num_actions_default))
    num_states = len(env_dict['state_space'])
    num_actions = len(env_dict['action_space'])
    env_dict['num_states'] = num_states
    env_dict['num_actions'] = num_actions

    # Default policy: uniform over actions in every state.
    env_dict['pi'] = np.ones((num_states,num_actions))/num_actions

    if rho is None:
        rho = 10 + np.random.normal(size=(num_states))
    else:
        # asarray accepts lists, tuples and arrays alike.
        rho = np.asarray(rho)
        assert rho.size == num_states, "rho should have "+str(num_states)+" entries."
    rho = np.abs(rho).reshape(num_states)
    env_dict['rho'] = rho/rho.sum()

    warn = "xi_radius should be a float, an int or an np.array with shape (num_states)"
    if isinstance(xi_radius, np.ndarray):
        xi_radius = xi_radius.reshape(-1)
        # Compare against the actual number of states, not the module default.
        assert xi_radius.shape == (num_states,), warn
        env_dict['xi_radius'] = xi_radius.copy()
    elif isinstance(xi_radius, (int, float, np.integer, np.floating)):
        env_dict['xi_radius'] = xi_radius * np.ones(num_states)
    else:
        assert False, warn

    if Psi is None:
        # Tabular case: xi[s,a,s'] is the transition probability itself.
        xi_shape = (num_states,num_actions,num_states)
        if xi0 is None:
            xi0 = np.abs(10+np.random.normal(size=xi_shape))   #xi0[s,a,s']
        else:
            assert isinstance(xi0, np.ndarray), "xi0 should be an np.ndarray with shape (num_states,num_actions,num_states)"
            assert xi0.shape == xi_shape, \
                "xi0 should have shape: (num_states,num_actions,num_states)"
            xi0 = np.abs(xi0)
        # Normalize over s' so that every xi0[s,a,:] is a probability vector.
        env_dict['xi0'] = xi0/np.sum(xi0,axis=2,keepdims=True)
    else:
        # Psi is a feature tensor: P(s'|s,a) = sum_j Psi[j,a,s'] * xi[s,j].
        warn = "Psi should be None or an np.ndarray with shape (dp,num_actions,num_states)"
        assert isinstance(Psi, np.ndarray), warn
        assert len(Psi.shape)==3, warn
        dp = Psi.shape[0]
        env_dict['dp'] = dp
        assert Psi.shape == (dp,num_actions,num_states), warn
        Psi_sums = Psi.sum(axis=2)   #Psi_sums[j,a] = sum_{s'} Psi[j,a,s']
        assert np.linalg.matrix_rank(Psi_sums) < dp, "The dp*|A| matrix sum_{s'} Psi(:,:,s') should have rank<dp."
        env_dict['Psi'] = Psi.copy()
        # Psi_proj = Psi_sums @ X, where X is the least-squares solution of
        # Psi_sums @ X = I (a dp x dp matrix).
        lstsq_sol = np.linalg.lstsq(Psi_sums, np.identity(dp), rcond=None)[0]
        env_dict['Psi_proj'] = Psi_sums.dot(lstsq_sol)

        warn="xi0 should be None or an np.ndarray with shape: (num_states,dp)"
        xi_shape=(num_states,dp)
        if xi0 is None:
            # xi0 is 2-D here, so normalize each row over the feature axis (axis=1).
            # The constraint check below still applies: this default passes only
            # when sum_j Psi_sums[j,a]*xi0[s,j] equals 1 (e.g. Psi_sums all ones).
            xi0=np.abs(10+np.random.normal(size=xi_shape))
            xi0=xi0/np.sum(xi0,axis=1,keepdims=True)
        else:
            assert isinstance(xi0, np.ndarray), warn
            assert xi0.shape == xi_shape, warn

        # tmp[s,a] = sum_j Psi_sums[j,a]*xi0[s,j] - 1, which must vanish.
        tmp = (Psi_sums.reshape(1,dp,num_actions) \
            *(xi0.reshape(num_states,dp,1))).sum(axis=1)-1
        assert xi0.min()>=0, "All entries of xi0 must be non-negative."
        assert np.abs(tmp).max()<1e-14, "sum_j Psi[j,a,s']*xi0[s,j] should be 1."
        env_dict['xi0']=xi0.copy()

    # Start the current parameter at the nominal one; copy so that in-place
    # changes to 'xi' can never corrupt the ambiguity-set center 'xi0'.
    env_dict['xi'] = env_dict['xi0'].copy()

    env_dict['gamma'] = gamma

    cost_shape = (num_states,num_actions,num_states)
    if cost is None:
        env_dict['cost'] = np.random.uniform(size=cost_shape,low=0,high=1)
    else:
        warn="cost should be either None or an np.array with shape (num_states,num_actions,num_states)"
        assert isinstance(cost, np.ndarray), warn
        assert cost.shape == cost_shape, warn
        env_dict['cost']=cost.copy()
    return env_dict

def get_s(state_space,xi,s,a,num=1,Psi=None):
    """Draw ``num`` next-state samples from p_xi(.|s,a).

    With ``Psi=None`` the distribution is ``xi[s,a,:]``; otherwise it is
    ``xi[s].dot(Psi[:,a,:])``. A single sample (``num == 1``) is returned as a
    scalar element, several samples as an array.
    """
    probs = xi[s,a,:] if Psi is None else xi[s].dot(Psi[:,a,:])
    draws = np.random.choice(state_space, size=num, p=probs, replace=True)
    return draws[0] if num == 1 else draws

def get_a(action_space, pi, s, num=1, Psi=None):
    """Draw ``num`` actions a ~ pi(.|s), i.e. with probabilities ``pi[s]``.

    ``Psi`` is accepted for signature compatibility only and is ignored: the
    sampling distribution is ``pi[s]`` regardless. (Previously a non-None
    ``Psi`` left the result unassigned and raised UnboundLocalError.)

    A single sample (``num == 1``) is returned as a scalar element, several
    samples as an array.
    """
    a_next = np.random.choice(action_space, size=num, p=pi[s], replace=True)
    if num == 1:
        return a_next[0]
    return a_next

def get_transP_s2s(pi,xi,Psi=None):
    '''
    State-to-state transition matrix induced by policy pi:
    transP_s2s[s,s'] = sum_a p_xi(s'|s,a) * pi(a|s).
    With Psi=None, xi is used directly as the kernel p[s,a,s'];
    otherwise the kernel is obtained from get_p(xi,Psi).
    '''
    num_states,num_actions=pi.shape
    kernel = xi if Psi is None else get_p(xi,Psi)
    # Broadcast pi over the next-state axis, then marginalize the action axis.
    policy_weights = pi.reshape(num_states,num_actions,1)
    return (kernel*policy_weights).sum(axis=1)

def stationary_dist(transP_s2s):
    """Stationary state distribution of the row-stochastic matrix transP_s2s[s,s'].

    Takes the first eigenvector of ``transP_s2s.T`` whose eigenvalue is
    (numerically) 1, uses the entrywise magnitudes, and normalizes them to sum to 1.
    """
    eigenvalues, eigenvectors = np.linalg.eig(transP_s2s.T)
    unit_mask = np.isclose(eigenvalues, 1)
    # abs() removes any common sign/phase of the selected eigenvector.
    principal = np.abs(eigenvectors[:, unit_mask][:, 0])
    return (principal / principal.sum()).real

def occupation(pi, xi, rho, gamma, Psi=None):
    """Occupation measure lambda[s,a] = lambda(s) * pi(a|s).

    lambda(s) is computed as the stationary distribution of the state chain
    induced by pi under the mixed kernel gamma*P(s'|s,a) + (1-gamma)*rho(s'),
    where rho is the initial state distribution.
    Returns an array with the same shape as pi.
    """
    kernel = get_p(xi,Psi)
    restart = (1-gamma)*rho.reshape((1,1,-1))
    mixed_kernel = gamma*kernel + restart
    # mixed_kernel is already a full kernel, hence Psi=None here.
    state_chain = get_transP_s2s(pi,mixed_kernel,Psi=None)
    state_occupancy = stationary_dist(state_chain).reshape(-1,1)
    return state_occupancy * pi

def Vrho_func(pi,xi,rho,cost,gamma,Psi=None):
    """Value J_rho(pi, p_xi) under policy pi, kernel p_xi and initial distribution rho.

    Computed from the occupation measure lambda[s,a] returned by ``occupation``:
        J = sum_{s,a} lambda[s,a] * c[s,a] / (1 - gamma),
    where c[s,a] = cost[s,a] for a 2-D cost, or
    c[s,a] = sum_{s'} p[s,a,s'] * cost[s,a,s'] for a 3-D cost.

    lambda[s,a] already contains the factor pi(a|s), so pi is not applied a
    second time here (the previous expression multiplied by pi again and
    combined a length-S vector with the (S,A) occupation measure).
    """
    p = get_p(xi,Psi)
    occup = occupation(pi,p,rho,gamma,Psi=None)   #occup[s,a]
    cost = np.asarray(cost)
    if cost.ndim == 3:
        cost_sa = (p*cost).sum(axis=2)   #expected one-step cost of (s,a)
    else:
        cost_sa = cost
    return (cost_sa*occup).sum()/(1-gamma)

def V_func(pi,xi,cost,gamma,Psi=None):
    """Return the vector V where V[s] is the value of state s under policy pi.

    Solves the Bellman equation (I - gamma*P_pi) V = c_pi, where
    P_pi[s,s'] = sum_a pi[s,a]*p[s,a,s'] and c_pi[s] = sum_a pi[s,a]*c[s,a].

    ``cost`` may be given as cost[s,a] (2-D) or as cost[s,a,s'] (3-D, the shape
    produced by env_setup); in the 3-D case c[s,a] = sum_{s'} p[s,a,s']*cost[s,a,s'].
    """
    p = get_p(xi,Psi)
    p_s2s = get_transP_s2s(pi,p,Psi=None)
    cost = np.asarray(cost)
    if cost.ndim == 3:
        # Average the transition-dependent cost over the next state.
        cost_sa = (p*cost).sum(axis=2)
    else:
        cost_sa = cost
    Ecost = np.sum(pi * cost_sa, axis=1)
    A = np.identity(p_s2s.shape[0]) - gamma*p_s2s
    return np.linalg.solve(A, Ecost)

def Q_func(pi,xi,cost,gamma,Psi=None,V=None):
    """Return the matrix Q where Q[s,a] is the Q-function value at (s,a).

    If ``V`` is None it is computed with V_func. ``cost`` may be cost[s,a]
    (2-D) or cost[s,a,s'] (3-D):
        2-D: Q[s,a] = cost[s,a] + gamma * sum_{s'} p[s,a,s'] * V[s']
        3-D: Q[s,a] = sum_{s'} p[s,a,s'] * (cost[s,a,s'] + gamma * V[s'])
    """
    p = get_p(xi,Psi)
    if V is None:
        V = V_func(pi,p,cost,gamma,Psi=None)
    cost = np.asarray(cost)
    # Put V on the next-state axis explicitly so the broadcast is unambiguous.
    V_next = np.asarray(V).reshape(1,1,-1)
    if cost.ndim == 3:
        return (p*(cost+gamma*V_next)).sum(axis=2)
    return cost + gamma * np.sum(p * V_next, axis=2)
    
def proj_L2_xi(xi,xi0,xi_radius,Psi=None,Psi_proj=None):
    '''
    Pull the transition parameter xi back towards the L2 ball around xi0.
    Reference: pg.21-24 in https://www.cs.unh.edu/~mpetrik/pub/tutorials/robustrl/dlrl-extended.pdf

    Case I (Psi is None): xi[s,a,s'] is the kernel itself. The offset xi-xi0 is
      centered over s' (so each xi[s,a,:] keeps the sum of xi0[s,a,:]) and, for
      every state s whose offset norm over (a,s') exceeds xi_radius[s], rescaled
      to have norm xi_radius[s].
    Case II (Psi given): xi[s,:] has dp entries. The part of the offset lying in
      the column space of Psi_sums = Psi.sum(axis=2) is removed (via Psi_proj if
      supplied, otherwise via least squares) and the remainder is rescaled per
      state in the same way.

    xi_radius is indexed with a boolean mask, so it must be an array with one
    entry per state. The inputs are not modified; a new array is returned.
    '''
    offset = xi - xi0
    if Psi is None:
        offset -= offset.mean(axis=2,keepdims=True)
        norms = np.sqrt((offset*offset).sum(axis=(1,2)))
        tail = (-1,1,1)
    else:
        if Psi_proj is not None:
            in_span = Psi_proj.dot(offset.T)
        else:
            Psi_sums = Psi.sum(axis=2)  #Psi_sums[j,a] for j-th feature dimension
            coeffs = np.linalg.lstsq(Psi_sums, offset.T, rcond=None)[0]
            in_span = Psi_sums.dot(coeffs)
        offset = offset - in_span.T
        norms = np.sqrt((offset*offset).sum(axis=1))
        tail = (-1,1)
    # Only states whose offset lies outside the ball are shrunk onto its surface.
    outside = (norms > xi_radius)
    offset[outside] *= (xi_radius[outside]/norms[outside]).reshape(tail)
    return xi0 + offset
    
def proj_Pr(y):
    '''
    Project the (flattened) vector y onto the probability simplex.
    Algorithm from https://arxiv.org/pdf/1309.1541.pdf :
    sort the entries in decreasing order, find the shift taken from the last
    prefix for which (entry + shift) stays positive, add that shift to y and
    clip negative entries to zero.
    Returns a new 1-D array; y is not modified.
    '''
    flat = y.reshape(-1)
    descending = np.flip(np.sort(flat))
    budget = 1  # 1 minus the running sum of the largest entries seen so far
    for rank, value in enumerate(descending, start=1):
        budget -= value
        candidate = budget/rank
        if candidate + value > 0:
            shift = candidate
    projected = flat + shift
    projected[projected<0] = 0 # entries below the threshold become exactly zero
    return projected

def proj_L2_pi(pi):
    '''
    Project every row pi[s] of the policy parameter onto the probability
    simplex (https://en.wikipedia.org/wiki/Simplex) using proj_Pr.
    Returns a new array; pi is left untouched.
    '''
    n_rows, _ = pi.shape
    projected = pi.copy()
    for s in range(n_rows):
        projected[s] = proj_Pr(pi[s])
    return projected

def DRPG(env_dict,Tp,Tpi,eta,beta,pi0=None,p0=None,num_Viter=1000,is_save_p=False,is_save_pi=False,is_print=True):
    """Double-loop robust policy gradient (DRPG) for a tabular robust MDP.

    Algorithm from Wang, Qiuhao, Chin Pang Ho, and Marek Petrik.
    "Policy gradient in robust mdps with global convergence guarantee."
    International Conference on Machine Learning. PMLR, 2023.

    Each outer iteration t restarts the kernel at p0, runs Tp(t) projected
    gradient-ascent steps on p (step size beta/(1-gamma)), then takes one
    projected gradient-descent step on pi (step size eta/(1-gamma)).

    Parameters
    ----------
    env_dict : dict as returned by env_setup (tabular case). The p-update
        uses env_dict['cost'] as a 3-D array cost[s,a,s'].
    Tp : callable; Tp(t) is the number of inner p-steps at outer iteration t.
    Tpi : number of outer (policy) iterations.
    eta, beta : base step sizes for the pi- and p-updates.
    pi0 : initial policy of shape (num_states,num_actions), or None for uniform.
    p0 : initial kernel of shape (num_states,num_actions,num_states), or None
        to use env_dict['xi0'].
    num_Viter : value-iteration count passed to V_robust_iter and Vp_minpi.
    is_save_p, is_save_pi : also store every iterate in results['p'] / results['pi'].
    is_print : print the tracked quantities at the start of every outer step.

    Returns
    -------
    dict with arrays of length Tpi+1 under 'J' (V.dot(rho) at the current
    pi and p), 'J_max' (V_robust_iter value of pi), 'Fp' (Vp_minpi value of p),
    'p_iters', 'pi_iters', plus the hyper-parameters and optional iterates.
    """
    num_states=env_dict['num_states']
    num_actions=env_dict['num_actions']
    cost=env_dict['cost']
    gamma=env_dict['gamma']
    rho=env_dict['rho']
    xi0=env_dict['xi0']
    xi_radius=env_dict['xi_radius']

    def state_occupancy(pi_now,p_now):
        # occupation() returns lambda[s,a]; the gradients below need the state
        # marginal, so sum out the action axis before reshaping. A 1-D result is
        # accepted unchanged.
        occ=occupation(pi_now,p_now,rho,gamma,Psi=None)
        if occ.ndim==2:
            occ=occ.sum(axis=1)
        return occ.reshape(num_states)

    def robust_value(pi_now):
        return V_robust_iter(pi_now,xi0,cost,xi_radius,gamma,num_Viter=num_Viter,\
                             Psi=None,Psi_proj=None,V0=None,is_print=False)[0].dot(rho)

    def best_response_value(p_now):
        return Vp_minpi(p_now,cost,gamma,num_Viter=num_Viter,V0=None,is_print=False).dot(rho)

    shape = Tpi + 1
    results = {'J_max':np.zeros(shape),'J':np.zeros(shape),'p_iters':np.zeros(shape),'pi_iters':np.zeros(shape),\
             'Fp':np.zeros(shape),'Tp':Tp,'Tpi':Tpi,'eta':eta,'beta':beta,'num_Viter':num_Viter}
    if p0 is None:
        p0=xi0.copy()
    else:
        assert p0.shape==(num_states,num_actions,num_states), \
            "p0 should be None or an np array with shape (env_dict['num_states'],env_dict['num_actions'],env_dict['num_states'])"

    if pi0 is None:
        pi=np.ones((num_states,num_actions))/num_actions
    else:
        assert pi0.shape==(num_states,num_actions), \
            "pi0 should be None or an np array with shape (env_dict['num_states'],env_dict['num_actions'])"
        pi=pi0.copy()
    results['J'][0]=V_func(pi,p0,cost,gamma,Psi=None).dot(rho)
    results['J_max'][0]=robust_value(pi)
    results['Fp'][0]=best_response_value(p0)
    results['p_iters'][0]=0
    results['pi_iters'][0]=0
    if is_save_pi:
        results['pi']=np.zeros((shape,num_states,num_actions))
        results['pi'][0]=pi.copy()
    if is_save_p:
        results['p']=np.zeros((shape,num_states,num_actions,num_states,))
        results['p'][0]=p0.copy()
    eta2=eta/(1-gamma)
    beta2=beta/(1-gamma)
    for t in range(Tpi):
        Tp_now=Tp(t)
        pi2=pi.reshape(num_states,num_actions,1)
        p=p0.copy()   #every inner loop restarts from p0
        for k in range(Tp_now):
            # Ascent step on p along d(s)*pi(a|s)*(c(s,a,s')+gamma*V(s')), then projection.
            ds=state_occupancy(pi,p).reshape(num_states,1,1)
            V=V_func(pi,p,cost,gamma,Psi=None)
            p=proj_L2_xi(p+beta2*ds*pi2*(cost+gamma*V.reshape(1,1,-1)),\
                         xi0,xi_radius,Psi=None,Psi_proj=None)

        V=V_func(pi,p,cost,gamma,Psi=None)
        if is_print:
            print("Iteration "+str(t)+": J_max="+str(results['J_max'][t])+", J="+str(results['J'][t])\
                  +", Fp="+str(results['Fp'][t]))
        # Descent step on pi along d(s)*Q(s,a), then projection onto the simplex.
        ds=state_occupancy(pi,p).reshape(num_states,1)
        Q=Q_func(pi,p,cost,gamma,Psi=None,V=V)
        pi=proj_L2_pi(pi-eta2*ds*Q)

        V=V_func(pi,p,cost,gamma,Psi=None)
        results['J'][t+1]=V.dot(rho)
        results['J_max'][t+1]=robust_value(pi)
        results['p_iters'][t+1]=results['p_iters'][t]+Tp_now
        results['pi_iters'][t+1]=t+1
        results['Fp'][t+1]=best_response_value(p)
        if is_save_pi:
            results['pi'][t+1]=pi.copy()
        if is_save_p:
            results['p'][t+1]=p.copy()
    return results




if __name__ == '__main__':
    # Script entry point: currently only prints a start marker.
    print('start')

################ Below is not useful for robust RL with general utility ##################





def V_robust_iter(pi,xi0,cost,xi_radius,gamma,num_Viter=1000,xi_norm_cutoff=1e-8,Psi=None,Psi_proj=None,V0=None,is_print=False):
    """Compute max_p V(pi,p) over the L2 ambiguity set for a fixed policy pi by value iteration.

    Every iteration forms vec = pi * (cost + gamma*V) (contracted with Psi in
    the feature case), removes the component that would break normalization,
    moves xi0 by xi_radius[s] along the remaining direction and sets
    V[s] = <xi[s], vec[s]>.

    Parameters
    ----------
    pi : policy, shape (num_states,num_actions).
    xi0 : center of the ambiguity set; (S,A,S) if Psi is None, else (S,dp).
    cost : cost[s,a,s'], shape (S,A,S).
    xi_radius : per-state radius (array with S entries, or a scalar).
    gamma : discount factor.
    num_Viter : number of value-iteration sweeps.
    xi_norm_cutoff : directions with norm <= this value are treated as zero,
        i.e. xi[s] stays at xi0[s] for that state.
    Psi, Psi_proj : optional feature tensor (dp,A,S) and matching projection
        matrix; if Psi_proj is None the projection is done by least squares.
    V0 : optional initial value vector (copied); zeros by default.
    is_print : print the sup-norm change of V after every sweep.

    Returns
    -------
    (V, xi) : the value vector and the parameter from the last sweep. With
    num_Viter=0 this is the initial V and a copy of xi0.
    """
    num_states,num_actions=pi.shape
    V = np.zeros(num_states) if V0 is None else V0.copy()
    xi = xi0.copy()   #returned unchanged when no sweep is performed
    if Psi is not None:
        dp=Psi.shape[0]
        if Psi_proj is None:
            # Loop-invariant, so computed once up front.
            Psi_sums=Psi.sum(axis=2)  #Psi_sums[j,a] for j-th feature dimension
    for t in range(num_Viter):
        V_past=V
        target=cost+gamma*V.reshape(1,1,-1)   #c(s,a,s')+gamma*V(s')
        if Psi is None:
            vec=pi.reshape(num_states,num_actions,1)*target  #vec[s,a,s']=pi(a|s)*[c(s,a,s')+gamma*V(s')]
            weights=vec
            direction=vec-(vec.mean(axis=2,keepdims=True))
            norms=np.sqrt((direction*direction).sum(axis=(1,2)))
            tail=(num_states,1,1)
            sum_axes=(1,2)
        else:
            vec=(target.reshape(1,num_states,num_actions,num_states)*Psi.reshape(dp,1,num_actions,num_states)).sum(axis=3)
                #vec[j,s,a] for j-th feature dimension
            vec=(pi.reshape(1,num_states,num_actions)*vec).sum(axis=2)  #vec[:,s]=sum_{a,s'} pi(a|s)*Psi(:,a,s')*[c(s,a,s')+gamma*V(s')]
            # Remove from vec[:,s] its projection onto the column space of Psi_sums.
            if Psi_proj is None:
                in_span=Psi_sums.dot(np.linalg.lstsq(Psi_sums, vec, rcond=None)[0])
            else:
                in_span=Psi_proj.dot(vec)
            weights=vec.T
            direction=(vec-in_span).T
            norms=np.sqrt((direction*direction).sum(axis=1))
            tail=(num_states,1)
            sum_axes=1
        # Divide only where the norm exceeds the cutoff; elsewhere the
        # coefficient is 0. This avoids dividing by a zero norm.
        coeff=np.divide(xi_radius,norms,out=np.zeros(norms.shape),where=~(norms<=xi_norm_cutoff))
        xi=xi0+(direction*coeff.reshape(tail))
        V=(xi*weights).sum(axis=sum_axes)
        if is_print:
            print('Value iteration '+str(t)+': ||V_{t+1}-V_t||_{infty}='+str(np.abs(V_past-V).max()))
    return V, xi

def Vp_minpi(p,cost,gamma,num_Viter=1000,V0=None,is_print=False):
    """Value iteration for a fixed kernel p[s,a,s'], minimizing over actions.

    Each sweep computes Q[s,a] = sum_{s'} p[s,a,s'] * (cost[s,a,s'] + gamma*V[s'])
    and sets V[s] = min_a Q[s,a]. Starts from V0 (copied) or zeros and runs
    exactly num_Viter sweeps; with is_print the sup-norm change of V is printed
    after every sweep. Returns the final V.
    """
    num_states = p.shape[0]
    V = np.zeros(num_states) if V0 is None else V0.copy()
    for t in range(num_Viter):
        previous = V
        backup = cost + gamma*V.reshape(1,1,-1)
        V = (p*backup).sum(axis=2).min(axis=1)
        if is_print:
            print('Value iteration %s: ||V_{t+1}-V_t||_{infty}=%s' % (t, np.abs(previous-V).max()))
    return V

def get_p(xi,Psi=None):
    """Return the transition kernel p[s,a,s'] described by xi.

    With Psi=None, xi is returned unchanged (direct parameterization).
    Otherwise p[s,a,s'] = sum_j xi[s,j] * Psi[j,a,s'] for Psi of shape
    (dp,num_actions,num_states).
    """
    if Psi is not None:
        dp, num_actions, num_states = Psi.shape
        weights = xi.reshape(num_states,dp,1,1)
        features = Psi.reshape(1,dp,num_actions,num_states)
        # Contract the feature axis (axis 1 of the broadcast product).
        return (features*weights).sum(axis=1)
    return xi

def findP_CPI(pi,Pprime,p0,p_radius,cost,rho,gamma,p_norm_cutoff=1e-8):
    """Return the kernel in the L2 ball around p0 that maximizes the linearized value at Pprime.

    Algorithm 2 of
    Li, M., Sutter, T., and Kuhn, D. (2023b). Policy gradient algorithms for
    robust mdps with non-rectangular uncertainty sets. ArXiv:2305.19004.

    The direction is g[s,a,s'] = d(s)/(1-gamma) * pi(a|s) * (cost[s,a,s'] + gamma*V(s')),
    with d the state occupancy and V the value of pi under Pprime. g is centered
    over s' and, per state s, scaled to norm p_radius[s] before being added to p0.
    States whose centered direction has norm <= p_norm_cutoff keep p0[s].

    Parameters
    ----------
    pi : policy, shape (num_states,num_actions).
    Pprime : kernel at which the direction is evaluated, shape (S,A,S).
    p0 : center of the ambiguity set, shape (S,A,S).
    p_radius : per-state radius (array with S entries, or a scalar).
    cost : cost[s,a,s'], shape (S,A,S).
    rho : initial state distribution.
    gamma : discount factor.
    """
    num_states,num_actions=pi.shape
    # occupation() returns lambda[s,a]; sum out the action axis to get the
    # state occupancy d(s). A 1-D result is accepted unchanged.
    occ=occupation(pi,Pprime,rho,gamma,Psi=None)
    if occ.ndim==2:
        occ=occ.sum(axis=1)
    ds=occ.reshape(num_states,1,1)
    V=V_func(pi,Pprime,cost,gamma,Psi=None)
    g=(ds/(1-gamma))*pi.reshape(num_states,num_actions,1)*(cost+gamma*V.reshape(1,1,-1))
    p=g-(g.mean(axis=2,keepdims=True))
    p_norm=np.sqrt((p*p).sum(axis=(1,2)))
    # Divide only where the norm exceeds the cutoff; elsewhere the
    # coefficient is 0. This avoids dividing by a zero norm.
    coeff=np.divide(p_radius,p_norm,out=np.zeros(p_norm.shape),where=~(p_norm<=p_norm_cutoff))
    return p0+(p*coeff.reshape(num_states,1,1))



