from RobustRL_utils import *
from utils import *
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm

def test_bc_pseg_plus(num_states, num_actions, K=50, alpha=1/48, beta=0.2, showfigs=('steps', 'complexity'), save_fig=True):
    '''
    Test environment for Algorithm 6 for project gradient descent algorithm (updated!)

    Builds a random environment with ``env_setup``, runs ``obtain_bc_pseg_plus``
    for ``K`` iterations with a constant step size ``alpha``, and returns the
    norms it reports together with a per-iteration sample-complexity axis.

    Args:
        num_states: number of states; the state space is ``range(num_states)``.
        num_actions: number of actions; the action space is ``range(num_actions)``.
        K: number of iterations passed to ``obtain_bc_pseg_plus``.
        alpha: constant step size, replicated into a ``(K, 1)`` array.
        beta: forwarded unchanged to ``obtain_bc_pseg_plus``.
        showfigs: accepted for interface compatibility; not used here.
        save_fig: accepted for interface compatibility; not used here.

    Returns:
        ``(norm1s, complexs)``: the second output of ``obtain_bc_pseg_plus`` and
        an integer array of length ``K`` equal to
        ``3 * sum(m * H over lambda/theta/xi) * k`` for ``k = 0..K-1``.
    '''
    # The environment seed is itself drawn from the global NumPy RNG.
    seed_init = np.random.randint(0, 5)

    state_space = range(0, num_states)
    action_space = range(0, num_actions)
    env_dict = env_setup(seed_init, state_space, action_space, rho=None, Psi=None, xi0=None, xi_radius=0.01, cost=None, gamma=0.95)

    trial_nums = {'lambda': 100, 'theta': 100, 'xi': 100}  # m_lambda, m_theta, m_xi
    trial_lens = {'lambda': 50, 'theta': 50, 'xi': 50}  # H_lambda, H_theta, H_xi

    alphas = np.ones(shape=(K, 1)) * alpha

    # z stacks the flattened policy and the flattened xi (Theta x Xi); the
    # environment's own parameters serve as the starting point.
    z_init = np.concatenate((env_dict['pi'].flatten(), env_dict['xi'].flatten()))
    # h uses the same layout but is drawn from a standard normal. The policy
    # part is sampled before the xi part; keep that order for reproducibility.
    h_theta = np.random.randn(num_states, num_actions).flatten()
    h_xi = np.random.randn(num_states, num_actions, num_states).flatten()
    h_init = np.concatenate((h_theta, h_xi))

    _, norm1s = obtain_bc_pseg_plus(env_dict, num_states, num_actions, K, alphas, beta, z_init, h_init, trial_nums, trial_lens)

    # Samples per batch summed over the three estimators; the factor 3 is the
    # original author's per-iteration multiplier.
    sample_size = np.sum([trial_nums[k] * trial_lens[k] for k in trial_nums])
    complexs = 3 * sample_size * np.arange(K)
    return norm1s, complexs

def compute_exact_norm(pi, xi, rho, gamma, state_space):
    '''
    Build the flattened exact gradient field at ``(pi, xi)``.

    Despite its name this function does not take a norm: it returns the vector
    ``[V_theta, -V_xi]`` (each part flattened), where the two gradients come
    from ``get_exact_gradient`` evaluated with the exact occupancy measure and
    the cost derived from it.

    Args:
        pi: policy parameters.
        xi: transition-kernel parameters.
        rho: forwarded to ``occupation`` (presumably the initial state
            distribution; confirm against that helper).
        gamma: discount factor.
        state_space: forwarded to ``get_cost_from_grad``.

    Returns:
        1-D array: flattened ``V_theta`` followed by the negated flattened ``V_xi``.
    '''
    occupancy = occupation(pi, xi, rho, gamma, Psi=None)
    exact_cost = get_cost_from_grad(gamma, state_space, occupancy)
    grad_theta, grad_xi = get_exact_gradient(pi, xi, exact_cost, gamma, occupancy, Psi=None, V=None)

    theta_part = grad_theta.flatten()
    # The xi block enters with the opposite sign to the theta block.
    xi_part = -grad_xi.flatten()
    return np.concatenate((theta_part, xi_part))

def project_policy_gradient_decent(env_dict, gamma, pi, xi, alpha, iter_num,  trial_nums, trial_lens):
    '''Algorithm 3 Projected Policy Gradient Descent Algorithm

    Starting from ``pi``, performs ``iter_num - 1`` projected stochastic
    gradient steps ``theta <- proj_L2_pi(theta - alpha * grad_theta)``.

    Each pass re-estimates the occupancy measure, derives a cost from it and
    estimates the policy gradient at the *current* iterate. To keep the sampler
    and the gradient estimator at the same point, ``env_dict['pi']`` is
    overwritten with the current iterate before sampling. On return it equals
    the returned policy.

    Args:
        env_dict: environment dictionary; reads ``num_states``, ``num_actions``
            and ``state_space``, and writes ``pi``.
        gamma: discount factor.
        pi: initial policy, used as ``(num_states, num_actions)``.
        xi: unused; kept for interface compatibility.
        alpha: step size.
        iter_num: number of sampling passes; the last pass applies no update.
        trial_nums: dict with batch sizes under ``'lambda'`` and ``'theta'``.
        trial_lens: dict with horizon lengths under ``'lambda'`` and ``'theta'``.

    Returns:
        The final policy iterate (S x A). If ``iter_num <= 1`` this is ``pi``.
    '''
    num_states, num_actions = env_dict['num_states'], env_dict['num_actions']
    state_space = env_dict['state_space']

    m_lambda, m_theta = trial_nums['lambda'], trial_nums['theta']
    H_lambda, H_theta = trial_lens['lambda'], trial_lens['theta']

    theta = pi

    for t in range(iter_num):
        # Sample under the current iterate. The environment's policy was
        # previously left at its initial value for the whole loop.
        # (Assumes generate_trials reads the policy from env_dict['pi'], as
        # the other routines in this file do.)
        env_dict['pi'] = theta

        lambda_trials = generate_trials(env_dict, m_lambda, H_lambda, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)
        costs_hat = get_cost_from_grad(gamma, state_space, lambda_hat)

        theta_trials = generate_trials(env_dict, m_theta, H_theta, state_init=None)
        # Evaluate the gradient at the current iterate `theta`, not at the
        # initial `pi`.
        grad_theta = get_stochastic_grad_theta(num_states, num_actions, gamma, theta, theta_trials, costs_hat)
        grad_theta = grad_theta.reshape(num_states, num_actions)

        # The final pass still draws its samples but applies no update; this
        # keeps the per-call sample count at iter_num batches, as before.
        if t >= iter_num - 1:
            break
        theta = proj_L2_pi(theta - alpha * grad_theta)

    return theta  # size in SxA

def algo1_old(num_states, num_actions, K=50, K_prime=50, T=50, L=10, alpha=1/48, beta=0.2):
    '''
    Algorithm 1 Approximate Gradient Projection Ascent Algorithm for Convex Utility

    Legacy two-phase driver, run on a freshly generated random environment.

    Phase I (K iterations): the policy is refined with
    ``project_policy_gradient_decent`` (a fixed 20 inner passes, ``T_dash``),
    then ``xi`` takes one projected stochastic ascent step of size ``beta``.
    Phase II (K_prime iterations): ``xi`` takes ``T`` projected steps of size
    0.02 on the objective regularised by ``2*L*(xi - xi_tilde)``, where
    ``xi_tilde`` is the Phase I output; then the policy takes one projected
    descent step of size 0.02.

    Args:
        num_states, num_actions: sizes of the state and action spaces.
        K: number of Phase I iterations.
        K_prime: number of Phase II iterations.
        T: number of inner xi steps per Phase II iteration.
        L: coefficient of the proximal term in Phase II.
        alpha: policy step size used in Phase I.
        beta: xi step size used in Phase I (also scales the Phase I norm).

    Returns:
        ``(theta, xi, norms)``: the Phase II iterate pair stored at an index
        drawn uniformly from ``range(K_prime)``, and the array of length
        ``K + K_prime`` of per-iteration residual norms.
    '''
    seed_init = np.random.randint(0,5)
    state_space = range(0,num_states)
    action_space = range(0,num_actions)
    env_dict = env_setup(seed_init,state_space,action_space,rho=None,Psi=None,xi0=None,xi_radius=0.01,cost=None,gamma=0.95)
    
    trial_nums = {'lambda':10, 'theta':10, 'xi': 10} # m_lambda, m_theta, m_xi
    trial_lens = {'lambda':100, 'theta':100, 'xi':100} # H_lambda, H_theta, H_xi
    
    m_lambda, m_theta, m_xi = trial_nums['lambda'], trial_nums['theta'], trial_nums['xi']
    H_lambda, H_theta, H_xi = trial_lens['lambda'], trial_lens['theta'], trial_lens['xi']

    pi = env_dict['pi']
    xi = env_dict['xi']
    rho = env_dict['rho']  # not used below
    gamma = env_dict['gamma']
    xi_radius = env_dict['xi_radius']
    xi0 = env_dict['xi0']
    # Number of inner policy-gradient passes in Phase I. This is hard-coded;
    # the T argument only controls Phase II.
    T_dash = 20

    # Route messages through tqdm so they do not break the progress bar.
    # This shadows the builtin print for the rest of the function.
    print = tqdm.write
    norms = np.zeros(K+K_prime)
    # PHASE I (begin)
    print('Start Phase I....')
    # NOTE(review): complexs is filled in but never returned from this function.
    complexs = np.zeros(K+K_prime)
    
    sample_cnt = 0
    for k in tqdm(range(K)):
        pi = project_policy_gradient_decent(env_dict, gamma, pi, xi, alpha, T_dash,  trial_nums, trial_lens)
        # NOTE(review): this counts one lambda+theta batch, although the call
        # above samples on each of its T_dash passes. Looks like an undercount.
        sample_cnt += m_lambda*H_lambda + m_theta*H_theta
        env_dict['pi'] = pi
        
        xi_trials = generate_trials(env_dict, m_xi, H_xi, state_init=None)

        lambda_trials = generate_trials(env_dict, m_lambda, H_lambda, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma,  lambda_trials)
        cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
        
        
        grad_xi = get_stochastic_grad_xi(num_states, num_actions, gamma, xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat)
        sample_cnt += m_lambda*H_lambda + m_xi*H_xi
        # Projected ascent step on xi.
        xi_prev = xi + beta*grad_xi
        xi = proj_L2_xi(xi_prev,xi0,xi_radius,Psi=None,Psi_proj=None)
        env_dict['xi'] = xi
        
        # Residual of the projected exact-gradient step at the current point,
        # scaled by 1/beta.
        barz = np.concatenate([pi.flatten(), xi.flatten()])
        F_barz = obtain_exact_gradient(env_dict)
        norm = np.linalg.norm(z_proj(env_dict, barz + beta * F_barz, num_states, num_actions, check=False) - barz)/beta
        norms[k] = norm
        complexs[k] = sample_cnt
        print(f'Phase 1: Iter {k}/{K}, norm={norm:.6f}')
    # Phase I output; it anchors the proximal term in Phase II.
    xi_tilde = xi
    
    # PHASE II
    trial_nums2 = {'lambda':10, 'theta':10, 'xi': 10} # m_lambda, m_theta, m_xi
    trial_lens2 = {'lambda':100, 'theta':100, 'xi':100} # H_lambda, H_theta, H_xi
    
    m_lambda2, m_theta2, m_xi2 = trial_nums2['lambda'], trial_nums2['theta'], trial_nums2['xi']
    H_lambda2, H_theta2, H_xi2 = trial_lens2['lambda'], trial_lens2['theta'], trial_lens2['xi']
    
    trial_nums3 = {'lambda':10, 'theta':10, 'xi': 10} # m_lambda, m_theta, m_xi
    trial_lens3 = {'lambda':100, 'theta':100, 'xi':100} # H_lambda, H_theta, H_xi
    
    m_lambda3, m_theta3, m_xi3 = trial_nums3['lambda'], trial_nums3['theta'], trial_nums3['xi']
    H_lambda3, H_theta3, H_xi3 = trial_lens3['lambda'], trial_lens3['theta'], trial_lens3['xi']
    # Per-iteration snapshots, from which the returned pair is drawn.
    thetas_tilde = np.zeros(shape=(K_prime, num_states, num_actions))
    xis_tilde = np.zeros(shape=(K_prime, num_states, num_actions, num_states))
    
    print('Start Phase II....')
    for k in tqdm(range(K_prime)):
        # Inner loop: T projected steps on xi for the regularised objective
        # (budget 2).
        for t in range(T):
            lambda_trials = generate_trials(env_dict, m_lambda2, H_lambda2, state_init=None)
            lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma,  lambda_trials)
            cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
            
            xi_trials = generate_trials(env_dict, m_xi2, H_xi2, state_init=None)
            grad_xi = get_stochastic_grad_xi(num_states, num_actions, gamma, xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat)
            
            xi = proj_L2_xi(xi + 0.02 * (grad_xi - 2*L*(xi - xi_tilde)),xi0,xi_radius,Psi=None,Psi_proj=None) # a[k, t]
            env_dict['xi'] = xi
            
        # NOTE(review): only one lambda batch is counted per outer iteration.
        # The T inner batches and the theta batch below are not included.
        sample_cnt += m_lambda2*H_lambda2
        
        # The snapshot is taken before the policy update below, so
        # thetas_tilde[k] holds the policy that entered this iteration.
        xis_tilde[k] = xi
        thetas_tilde[k] = pi
        lambda_trials = generate_trials(env_dict, m_lambda3, H_lambda3, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma,  lambda_trials)
        cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
        theta_trials = generate_trials(env_dict, m_theta3, H_theta3, state_init=None)
        grad_theta = get_stochastic_grad_theta(num_states, num_actions, gamma, pi, theta_trials, cost_hat)
        # NOTE(review): project_policy_gradient_decent reshapes this gradient
        # to (num_states, num_actions) before using it; here it is used as
        # returned. Confirm the shape returned by get_stochastic_grad_theta.
        pi = proj_L2_pi(pi - 0.02*grad_theta) # b[k]=0.02
        env_dict['pi'] = pi
        
        barz = np.concatenate([pi.flatten(), xi.flatten()])
        F_barz = obtain_exact_gradient(env_dict)
        # NOTE(review): Phase I uses barz + beta * F_barz, while this uses
        # barz - 0.02 * F_barz. Confirm that the sign flip is intended.
        norm = np.linalg.norm(z_proj(env_dict, barz - 0.02 * F_barz, num_states, num_actions, check=False) - barz)/0.02
        norms[k + K] = norm
        complexs[k + K] = sample_cnt
        print(f'Phase 2: Iter {k}/{K_prime}, norm={norm:.6f}')
        
    # Return a randomly selected Phase II snapshot.
    k_prime = np.random.randint(0, K_prime)
    return thetas_tilde[k_prime], xis_tilde[k_prime], norms


def ps_gda(env_dict, theta_init, xi_init, K, T, trial_nums1, trial_lens1, trial_nums2, trial_lens2, alpha, beta):
    '''part 1 of algorithm 1: Projected Stochastic Gradient Descent Ascent (PS-GDA) Algorithm

    For each of the ``K`` outer iterations the policy is reset to
    ``theta_init`` and refined by ``T`` projected stochastic descent steps of
    size ``alpha``; then ``xi`` takes one projected stochastic ascent step of
    size ``beta``. ``env_dict['pi']`` and ``env_dict['xi']`` are overwritten in
    place with the current iterates. The initial ``env_dict['xi']`` is assumed
    to equal ``xi_init``.

    Args:
        env_dict: environment dictionary; reads ``state_space``,
            ``num_states``, ``num_actions``, ``gamma``, ``xi_radius``, ``xi0``
            and ``xi``; writes ``pi`` and ``xi``.
        theta_init: policy used to start every inner loop, S x A.
        xi_init: initial xi, S x A x S.
        K: number of outer (xi) iterations.
        T: number of inner (theta) iterations per outer iteration.
        trial_nums1, trial_lens1: batch sizes / horizons for the inner loop,
            keys ``'lambda'`` and ``'theta'``.
        trial_nums2, trial_lens2: batch sizes / horizons for the xi step, keys
            ``'lambda'`` and ``'xi'``. If ``'xi'`` is absent the ``'theta'``
            entry is used, which is what older callers relied on.
        alpha: policy step size.
        beta: xi step size.

    Returns:
        ``(theta_k, xi_k, norms)`` where ``norms[k]`` is the sum of the theta
        residual measured at the last inner iteration (before that iteration's
        update) and the xi residual measured before the xi update.
    '''
    m_lambda1, m_theta1 = trial_nums1['lambda'], trial_nums1['theta']
    H_lambda1, H_theta1 = trial_lens1['lambda'], trial_lens1['theta']

    # The xi step must use the xi budget. It was previously read from the
    # 'theta' entries, which stay as a fallback for dicts without an 'xi' key.
    m_lambda2 = trial_nums2['lambda']
    H_lambda2 = trial_lens2['lambda']
    m_xi2 = trial_nums2.get('xi', trial_nums2['theta'])
    H_xi2 = trial_lens2.get('xi', trial_lens2['theta'])

    state_space = env_dict['state_space']
    num_states = env_dict['num_states']
    num_actions = env_dict['num_actions']
    gamma = env_dict['gamma']
    xi_radius = env_dict['xi_radius']
    xi0 = env_dict['xi0']

    # F(z) is laid out as [theta block | xi block].
    n_theta = num_states * num_actions

    xi_k = xi_init
    # Defined up front so that K == 0 or T == 0 no longer hits an undefined name.
    theta_k = theta_init
    norm_theta = 0.0
    norms = np.zeros(K)

    for k in range(K):
        theta_k = theta_init
        # Keep the environment in step with the reset, so sampling and the
        # exact gradient below are evaluated at theta_k rather than at the
        # policy left over from the previous outer iteration.
        env_dict['pi'] = theta_k
        for t in range(T):
            # inner iteration for theta
            lambda_trials = generate_trials(env_dict, m_lambda1, H_lambda1, state_init=None)
            lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)
            costs_hat = get_cost_from_grad(gamma, state_space, lambda_hat)

            # get grad theta
            theta_trials = generate_trials(env_dict, m_theta1, H_theta1, state_init=None)
            grad_theta = get_stochastic_grad_theta(num_states, num_actions, gamma, theta_k, theta_trials, costs_hat)
            grad_theta = grad_theta.reshape(num_states, num_actions)

            # calculate norm: residual of the exact projected step at theta_k
            F_barz1 = obtain_exact_gradient(env_dict)
            theta_exact = theta_k - alpha * F_barz1[:n_theta].reshape(num_states, num_actions)
            norm_theta = np.linalg.norm(proj_L2_pi(theta_exact) - theta_k) / alpha
            print(f'Phase 1: Inner Iter {t}/{T}, norm_theta={norm_theta:.7f}')

            # projection on theta
            theta_next = theta_k - alpha * grad_theta
            theta_k = proj_L2_pi(theta_next)  # using alpha
            env_dict['pi'] = theta_k  # update theta

        # Stochastic xi gradient at the refined policy. xi trials are drawn
        # before lambda trials, as in the original sampling order.
        xi_trials = generate_trials(env_dict, m_xi2, H_xi2, state_init=None)
        lambda_trials = generate_trials(env_dict, m_lambda2, H_lambda2, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)
        cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
        grad_xi = get_stochastic_grad_xi(num_states, num_actions, gamma, xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat)

        # Residual of the exact projected ascent step on xi.
        F_barz = obtain_exact_gradient(env_dict)
        xi_exact = xi_k + beta * F_barz[n_theta:].reshape(num_states, num_actions, num_states)
        norm_xi = np.linalg.norm(proj_L2_xi(xi_exact, xi0, xi_radius, Psi=None, Psi_proj=None) - xi_k) / beta
        print(f'Phase 1: Iter {k}/{K}, norm_xi={norm_xi:.7f}, norm={norm_theta+norm_xi:.7f}')
        norms[k] = norm_theta + norm_xi

        # Projected stochastic ascent step on xi.
        xi_prev = xi_k + beta * grad_xi
        xi_k = proj_L2_xi(xi_prev, xi0, xi_radius, Psi=None, Psi_proj=None)
        env_dict['xi'] = xi_k

    return theta_k, xi_k, norms


def corrected_phase(env_dict, theta_init, xi_init, K_prime, T_prime, trial_nums3, trial_lens3, trial_nums4, trial_lens4, L, a, b):
    '''Part 2 of algorithm 1: solve the corrected (proximally regularised) problem.

    Each of the ``K_prime`` outer iterations runs ``T_prime`` projected
    stochastic steps of size ``a`` on ``xi``, using the direction
    ``grad_xi - 2*L*(xi - xi_init)``, followed by one projected stochastic
    descent step of size ``b`` on the policy. ``env_dict['xi']`` and
    ``env_dict['pi']`` are overwritten in place with the current iterates.

    Args:
        env_dict: environment dictionary; reads ``num_actions``,
            ``num_states``, ``state_space``, ``gamma``, ``xi0``, ``xi_radius``
            and ``xi``; writes ``xi`` and ``pi``.
        theta_init: starting policy, S x A.
        xi_init: starting xi, S x A x S; also the anchor of the proximal term.
        K_prime: number of outer iterations.
        T_prime: number of inner xi steps per outer iteration.
        trial_nums3, trial_lens3: inner-loop batch sizes / horizons (keys
            ``'lambda'`` and ``'xi'`` are used).
        trial_nums4, trial_lens4: policy-step batch sizes / horizons (keys
            ``'lambda'`` and ``'theta'`` are used).
        L: coefficient of the proximal term.
        a: xi step size.
        b: policy step size.

    Returns:
        ``(theta, xi, norms_prime)``: the snapshot stored in the last outer
        iteration and the per-iteration policy residuals. The policy snapshot
        is taken before that iteration's policy update. When ``K_prime == 0``
        the inputs are returned with an empty residual array.
    '''
    print('Begin corrected phase to solve the corrected optimization problem (14)...')

    num_actions = env_dict['num_actions']
    num_states = env_dict['num_states']
    state_space = env_dict['state_space']
    gamma = env_dict['gamma']
    xi0, xi_radius = env_dict['xi0'], env_dict['xi_radius']

    m_lambda3, m_xi3 = trial_nums3['lambda'], trial_nums3['xi']
    H_lambda3, H_xi3 = trial_lens3['lambda'], trial_lens3['xi']

    m_lambda4, m_theta4 = trial_nums4['lambda'], trial_nums4['theta']
    H_lambda4, H_theta4 = trial_lens4['lambda'], trial_lens4['theta']

    # With no iterations there is no snapshot to index. Hand the inputs back
    # instead of failing on an empty array.
    if K_prime == 0:
        return theta_init, xi_init, np.zeros(0)

    thetas_tilde = np.zeros(shape=(K_prime, num_states, num_actions))
    xis_tilde = np.zeros(shape=(K_prime, num_states, num_actions, num_states))

    n_theta = num_states * num_actions  # length of the theta block of F(z)

    xi_tilde = xi_init
    xi = xi_init
    pi = theta_init
    norms_prime = np.zeros(K_prime)
    for k in range(K_prime):
        for t in range(T_prime):
            lambda_trials = generate_trials(env_dict, m_lambda3, H_lambda3, state_init=None)
            lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)
            cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)

            xi_trials = generate_trials(env_dict, m_xi3, H_xi3, state_init=None)
            grad_xi = get_stochastic_grad_xi(num_states, num_actions, gamma, xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat)

            # Ascent direction of the regularised objective, then projection.
            xi = proj_L2_xi(xi + a * (grad_xi - 2*L*(xi - xi_tilde)), xi0, xi_radius, Psi=None, Psi_proj=None)  # a[k, t]
            env_dict['xi'] = xi

        # Snapshot taken before the policy update of this iteration.
        xis_tilde[k] = xi
        thetas_tilde[k] = pi

        lambda_trials = generate_trials(env_dict, m_lambda4, H_lambda4, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)
        cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
        theta_trials = generate_trials(env_dict, m_theta4, H_theta4, state_init=None)
        grad_theta = get_stochastic_grad_theta(num_states, num_actions, gamma, pi, theta_trials, cost_hat)
        # Same reshape as the other policy updates in this file. It is a no-op
        # if the estimator already returns S x A.
        grad_theta = grad_theta.reshape(num_states, num_actions)
        pi = proj_L2_pi(pi - b * grad_theta)  # b[k]=0.02
        env_dict['pi'] = pi

        # Residual of the exact projected step: ||proj(pi - b*F_theta) - pi|| / b.
        # The earlier expression projected (1 - b) * theta_exact and never
        # subtracted the current iterate, so it did not measure stationarity.
        F_barz = obtain_exact_gradient(env_dict)
        theta_exact = pi - b * F_barz[:n_theta].reshape(num_states, num_actions)
        norm = np.linalg.norm(proj_L2_pi(theta_exact) - pi) / b
        norms_prime[k] = norm
        print(f'Phase 2: Iter {k}/{K_prime}, norm={norm:.6f}')

    k_prime = K_prime - 1
    return thetas_tilde[k_prime], xis_tilde[k_prime], norms_prime
    

def algo1(num_states, num_actions, K=50, K_prime=50, T=50, T_prime=50, L=10, alpha=0.01, beta=0.02, a=0.01, b=0.02):
    '''
    Run both phases of Algorithm 1 on a random environment and plot the residuals.

    Phase 1 is ``ps_gda`` (original problem) and phase 2 is ``corrected_phase``
    (corrected problem), started from the phase-1 output. The concatenated
    residual norms are plotted against the iteration index and saved as
    ``figs3/algo1_norms_iters_<timestamp>.png``; the directory is created if
    it does not exist.

    Args:
        num_states, num_actions: sizes of the state and action spaces.
        K, T: outer / inner iteration counts of phase 1.
        K_prime, T_prime: outer / inner iteration counts of phase 2.
        L: proximal coefficient of phase 2.
        alpha, beta: policy / xi step sizes of phase 1.
        a, b: xi / policy step sizes of phase 2.

    Returns:
        ``(theta_tilde, xi_tilde, norms_final)``: the phase-2 output pair and
        the residual norms of both phases, length ``K + K_prime``.
    '''
    import os  # local import: only needed to create the output directory

    # The environment seed is itself drawn from the global NumPy RNG.
    seed_init = np.random.randint(0,5)

    state_space = range(0,num_states)
    action_space = range(0,num_actions)
    env_dict = env_setup(seed_init,state_space,action_space,rho=None,Psi=None,xi0=None,xi_radius=0.01,cost=None,gamma=0.95)

    # Sampling budgets: number of trajectories and trajectory length per estimator.
    trial_nums1 = {'lambda':50, 'theta':50, 'xi': 50}
    trial_lens1 = {'lambda':40, 'theta':40, 'xi':40}

    trial_nums2 = {'lambda':50, 'theta':50, 'xi': 50}
    trial_lens2 = {'lambda':40, 'theta':40, 'xi':40}

    theta_init = env_dict['pi']
    xi_init = env_dict['xi']

    print('Begin original phase to solve the original optimization problem (2)...')
    theta_k, xi_k, norms = ps_gda(env_dict, theta_init, xi_init, K, T, trial_nums1, trial_lens1, trial_nums2, trial_lens2, alpha, beta) # algo2

    trial_nums3 = {'lambda':50, 'theta':50, 'xi': 50}
    trial_lens3 = {'lambda':40, 'theta':40, 'xi':40}

    trial_nums4 = {'lambda':50, 'theta':50, 'xi': 50}
    trial_lens4 = {'lambda':40, 'theta':40, 'xi':40}
    theta_tilde, xi_tilde, norms_prime = corrected_phase(env_dict, theta_k, xi_k, K_prime, T_prime, trial_nums3, trial_lens3, trial_nums4, trial_lens4, L, a, b)

    norms_final = np.concatenate((norms, norms_prime))

    # Start a fresh figure so repeated runs do not draw on top of each other.
    # It is left open in case the caller wants to plt.show() it.
    plt.figure()
    plt.plot(np.arange(K+K_prime), norms_final)
    plt.xlabel(f'Iter k/{K+K_prime}', fontsize=14)
    plt.ylabel(r'$\frac{1}{\beta}\|{\rm proj}_{\Theta\times\Xi}{{z}_k + \beta F({z}_k)] - {z}_k\|}$', fontsize=14)
    plt.title('Norm vs. Iterations', fontsize=14)
    plt.grid(True)
    plt.tight_layout()

    # savefig does not create missing directories.
    os.makedirs('figs3', exist_ok=True)
    current_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    plt.savefig('figs3/algo1_norms_iters_{}.png'.format(current_time_str))

    return theta_tilde, xi_tilde, norms_final
    

def algo1_new(num_states, num_actions, K=50, K_prime=50, T=50, T_prime=50, L=10, alpha=1/48, beta=0.2, a=0.02, b=0.02, step_size=1):
    '''
    Algorithm 1 Approximate Gradient Projection Ascent Algorithm for Convex Utility

    Two-phase driver, run on a freshly generated random environment.

    Phase I (K iterations): refine the policy with
    ``project_policy_gradient_decent`` (``T`` passes, step ``alpha``), then
    take one projected stochastic ascent step of size ``beta`` on ``xi``.
    Phase II (K_prime iterations): take ``T_prime`` projected steps of size
    ``a`` on ``xi`` with direction ``grad_xi - 2*L*(xi - xi_tilde)``, where
    ``xi_tilde`` is the Phase I output, then one projected descent step of
    size ``b`` on the policy.

    After every outer iteration the residual
    ``||z_proj(z + step_size*F(z)) - z|| / |step_size|`` and the cumulative
    number of sampled transitions are recorded.

    Args:
        num_states, num_actions: sizes of the state and action spaces.
        K, T: outer / inner iteration counts of Phase I.
        K_prime, T_prime: outer / inner iteration counts of Phase II.
        L: proximal coefficient of Phase II.
        alpha, beta: policy / xi step sizes of Phase I.
        a, b: xi / policy step sizes of Phase II.
        step_size: step used only inside the residual; must be non-zero.

    Returns:
        ``(theta, xi, norms, complexs)``: the Phase II snapshot stored at an
        index drawn uniformly from ``range(K_prime)``, the residuals, and the
        cumulative sample counts (both of length ``K + K_prime``).
    '''
    # The environment seed is itself drawn from the global NumPy RNG.
    seed_init = np.random.randint(0,5)

    state_space = range(0,num_states)
    action_space = range(0,num_actions)
    env_dict = env_setup(seed_init,state_space,action_space,rho=None,Psi=None,xi0=None,xi_radius=0.01,cost=None,gamma=0.95)

    trial_nums = {'lambda':15, 'theta':15, 'xi': 15} # m_lambda, m_theta, m_xi
    trial_lens = {'lambda':100, 'theta':100, 'xi':100} # H_lambda, H_theta, H_xi
    m_lambda, m_theta, m_xi = trial_nums['lambda'], trial_nums['theta'], trial_nums['xi']
    H_lambda, H_theta, H_xi = trial_lens['lambda'], trial_lens['theta'], trial_lens['xi']

    pi = env_dict['pi']
    xi = env_dict['xi']
    gamma = env_dict['gamma']
    xi_radius = env_dict['xi_radius']
    xi0 = env_dict['xi0']

    # Route messages through tqdm so they do not break the progress bars.
    log = tqdm.write
    norms = np.zeros(K+K_prime)
    complexs = np.zeros(K+K_prime)
    sample_cnt = 0

    def residual(policy, kernel):
        # Norm of the projected exact-gradient step at (policy, kernel).
        # Reads the exact gradient from env_dict, which must already hold the
        # same point.
        barz = np.concatenate([policy.flatten(), kernel.flatten()])
        F_barz = obtain_exact_gradient(env_dict)
        moved = z_proj(env_dict, barz + step_size * F_barz, num_states, num_actions, check=False)
        return np.linalg.norm(moved - barz)/abs(step_size)

    # PHASE I (begin)
    log('Start Phase I....')
    for k in tqdm(range(K)):
        pi = project_policy_gradient_decent(env_dict, gamma, pi, xi, alpha, T,  trial_nums, trial_lens)
        env_dict['pi'] = pi

        # xi trials are drawn before lambda trials; keep this sampling order.
        xi_trials = generate_trials(env_dict, m_xi, H_xi, state_init=None)

        lambda_trials = generate_trials(env_dict, m_lambda, H_lambda, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma,  lambda_trials)
        cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)

        grad_xi = get_stochastic_grad_xi(num_states, num_actions, gamma, xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat)
        xi = proj_L2_xi(xi + beta*grad_xi,xi0,xi_radius,Psi=None,Psi_proj=None)
        env_dict['xi'] = xi

        norm = residual(pi, xi)
        # T inner passes (lambda + theta batch each) plus one xi + lambda batch.
        sample_cnt += T * (m_lambda*H_lambda + m_theta*H_theta) + (m_xi*H_xi + m_lambda*H_lambda)
        complexs[k] = sample_cnt
        norms[k] = norm
        log(f'Phase 1: Iter {k}/{K}, norm={norm:.6f}')
    # Phase I output; it anchors the proximal term in Phase II.
    xi_tilde = xi

    # PHASE II
    # Budget 2 feeds the inner xi loop and budget 3 feeds the policy step.
    trial_nums2 = {'lambda':10, 'theta':10, 'xi': 10} # m_lambda, m_theta, m_xi
    trial_lens2 = {'lambda':100, 'theta':100, 'xi':100} # H_lambda, H_theta, H_xi
    m_lambda2, m_xi2 = trial_nums2['lambda'], trial_nums2['xi']
    H_lambda2, H_xi2 = trial_lens2['lambda'], trial_lens2['xi']

    trial_nums3 = {'lambda':10, 'theta':10, 'xi': 10} # m_lambda, m_theta, m_xi
    trial_lens3 = {'lambda':100, 'theta':100, 'xi':100} # H_lambda, H_theta, H_xi
    m_lambda3, m_theta3 = trial_nums3['lambda'], trial_nums3['theta']
    H_lambda3, H_theta3 = trial_lens3['lambda'], trial_lens3['theta']

    thetas_tilde = np.zeros(shape=(K_prime, num_states, num_actions))
    xis_tilde = np.zeros(shape=(K_prime, num_states, num_actions, num_states))

    log('Start Phase II....')
    for k in tqdm(range(K_prime)):
        for t in range(T_prime):
            lambda_trials = generate_trials(env_dict, m_lambda2, H_lambda2, state_init=None)
            lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma,  lambda_trials)
            cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)

            xi_trials = generate_trials(env_dict, m_xi2, H_xi2, state_init=None)
            grad_xi = get_stochastic_grad_xi(num_states, num_actions, gamma, xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat)

            xi = proj_L2_xi(xi + a * (grad_xi - 2*L*(xi - xi_tilde)),xi0,xi_radius,Psi=None,Psi_proj=None) # a[k, t]
            env_dict['xi'] = xi

        # Snapshot taken before the policy update of this iteration.
        xis_tilde[k] = xi
        thetas_tilde[k] = pi
        lambda_trials = generate_trials(env_dict, m_lambda3, H_lambda3, state_init=None)
        lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma,  lambda_trials)
        cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
        theta_trials = generate_trials(env_dict, m_theta3, H_theta3, state_init=None)
        grad_theta = get_stochastic_grad_theta(num_states, num_actions, gamma, pi, theta_trials, cost_hat)
        # Same reshape as the other policy updates in this file. It is a no-op
        # if the estimator already returns S x A.
        grad_theta = grad_theta.reshape(num_states, num_actions)
        pi = proj_L2_pi(pi - b*grad_theta) # b[k]=0.02
        env_dict['pi'] = pi

        norm = residual(pi, xi)
        # Count what was actually drawn: T_prime inner batches from budget 2
        # plus one lambda + theta batch from budget 3. The two budgets were
        # previously swapped in this expression.
        sample_cnt += T_prime * (m_xi2*H_xi2 + m_lambda2*H_lambda2) + (m_lambda3*H_lambda3 + m_theta3*H_theta3)
        norms[k + K] = norm
        complexs[k+K] = sample_cnt
        log(f'Phase 2: Iter {k}/{K_prime}, norm={norm:.6f}')

    # Return a randomly selected Phase II snapshot.
    k_prime = np.random.randint(0, K_prime)
    return thetas_tilde[k_prime], xis_tilde[k_prime], norms, complexs


def algo2(num_states, num_actions, K):
    '''
    Set up a random environment and evaluate its exact gradient ``K`` times.

    NOTE(review): in the code visible here the loop only calls
    ``obtain_exact_gradient`` and discards the result. No parameter is
    updated, the stochastic estimate is commented out, and nothing is
    returned. This looks like unfinished scaffolding; confirm before relying
    on it.

    Args:
        num_states: number of states; the state space is ``range(num_states)``.
        num_actions: number of actions; the action space is ``range(num_actions)``.
        K: number of loop iterations.
    '''
    # The environment seed is itself drawn from the global NumPy RNG.
    seed_init = np.random.randint(0,5)
    
    state_space = range(0,num_states)
    action_space = range(0,num_actions)
    env_dict = env_setup(seed_init,state_space,action_space,rho=None,Psi=None,xi0=None,xi_radius=0.01,cost=None,gamma=0.95)
    
    trial_nums = {'lambda':100, 'theta':100, 'xi': 100} # m_lambda, m_theta, m_xi
    trial_lens = {'lambda':10, 'theta':10, 'xi':10} # H_lambda, H_theta, H_xi
    
    # These locals are only referenced by the commented-out sampling code below.
    pi = env_dict['pi']
    xi = env_dict['xi']
    rho = env_dict['rho']
    gamma = env_dict['gamma']
    m_lambda1, m_theta1 = trial_nums['lambda'], trial_nums['theta'] 
    H_lambda1, H_theta1 = trial_lens['lambda'], trial_lens['theta']
    
    # V(\Xi)
    for k in range(K):
        # lambda_trials = generate_trials(env_dict, m_lambda1, H_lambda1, state_init=None) # depend on theta and xi
        # lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)
        # costs_hat = get_cost_from_grad(gamma, state_space, lambda_hat)
        # env_dict is not modified inside the loop, so each pass evaluates the
        # gradient at the same point.
        exact_grad = obtain_exact_gradient(env_dict)
    