from MDP_util import (
    generate_toy_cmdp,
    solve_cmdp,
    policy_evaluation,
    ope_practical,
    compute_occupancy_measure,
    critic,
    policy_mixture,
    MDP,
    CMDP,
    generate_trajectory,
)
import numpy as np
import sys
from tqdm import tqdm
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument('--n', type=int, default=100)
# parser.add_argument('--C', type=float, default=5)
# parser.add_argument('--B', type=float, default=2)
# parser.add_argument('--K', type=int, default=2000)
# parser.add_argument('--eta', type=float, default=2)
# parser.add_argument('--eta_lambda', type=float, default=2)
# parser.add_argument('--tighten', type=float, default=0)
# parser.add_argument('--k_tighten', type=float, default=0)
# parser.add_argument('--seed', type=int)
# parser.add_argument('--seed_cmdp', type=int, default=1)
# parser.add_argument('--output_file')
# args = parser.parse_args()

def pi_player(pi: np.ndarray, h: np.ndarray, eta: float):
    """Apply one multiplicative-weights (exponentiated-gradient) policy step.

    Each entry of ``pi`` is scaled by ``exp(eta * h)``. The scaled weights are
    clipped into [1e-6, 1e6] so that no action's weight reaches exactly zero
    and overflowing exponentials stay finite. Each row is then renormalised to
    sum to 1.

    Args:
        pi: Current policy; rows index states, columns index actions.
        h: Per-(state, action) scores with the same shape as ``pi``.
        eta: Step size multiplying ``h`` inside the exponential.

    Returns:
        The updated row-stochastic policy, with the same shape as ``pi``.
    """
    weights = np.clip(pi * np.exp(h * eta), 1e-6, 1e6)
    row_totals = weights.sum(axis=1, keepdims=True)
    return weights / row_totals


def generate_data(cmdp: CMDP, mu: np.ndarray, n: int):
    """Draw ``n`` i.i.d. transitions and return them as a count tensor.

    For each sample:
    1. A flat (state, action) index is drawn from ``mu.flatten()``. The index
       is decoded as ``state = index // num_actions`` and
       ``action = index % num_actions``, so ``mu`` is laid out row-major with
       states as rows.
    2. A next state is drawn from ``cmdp.transition[state, action, :]``.

    The global ``np.random`` generator is used, so results depend on the
    current numpy seed.

    Args:
        cmdp: Provides ``num_states``, ``num_actions`` and ``transition``.
        mu: Sampling distribution over (state, action) pairs. Its flattened
            entries must form a valid probability vector, as required by
            ``np.random.choice``.
        n: Number of transitions to draw.

    Returns:
        Array of shape (num_states, num_actions, num_states). Entry
        ``[s, a, s_next]`` counts how many sampled transitions were
        ``(s, a) -> s_next``.
    """
    n_s, n_a = cmdp.num_states, cmdp.num_actions
    counts = np.zeros((n_s, n_a, n_s))
    flat_pairs = np.arange(n_s * n_a)
    pair_probs = mu.flatten()
    next_candidates = np.arange(n_s)
    for _ in range(n):
        pair = np.random.choice(flat_pairs, p=pair_probs)
        s, a = divmod(pair, n_a)
        s_next = np.random.choice(next_candidates, p=cmdp.transition[s, a, :])
        counts[s, a, s_next] += 1
    return counts

# def generate_data(cmdp: CMDP, pi: np.ndarray, seed: int):
#     ret = np.zeros((cmdp.num_states, cmdp.num_actions, cmdp.num_states))
#     dataset = generate_trajectory(cmdp, pi, seed)
#     # print(dataset)
#     for traj in dataset:
#         for data in traj:
#             _, _, state, action, r, c, next_state, _, _ = data
#             ret[state, action, next_state] += 1
#     # for _ in range(n):
#     #     i = np.random.choice(np.arange(cmdp.num_states * cmdp.num_actions), p=mu.flatten())
#     #     state = i // cmdp.num_actions
#     #     action = i % cmdp.num_actions
#     #     next_state = np.random.choice(np.arange(cmdp.num_states), p=cmdp.transition[state, action, :])
#     #     ret[state, action, next_state] += 1
#     return ret

def run_wsac_tabular(
    dataset,
    cmdp,
    initial_pi=None,
    beta=2,
    *,
    eta=2,
    C=5,
    lambda_initial=0,
    lambda_eta=1e-2,
    lambda_max=2,
    K=500,
):
    """Run the tabular WSAC primal-dual loop on an offline dataset.

    Each of the ``K - 1`` iterations does the following:

    1. Fits a reward critic ``f`` (``sign=1``) and a cost critic ``g``
       (``sign=-1``) for the current policy via ``critic``.
    2. Takes a projected step on the Lagrange multiplier. The step is the
       positive part of ``g - cost_threshold``, averaged under the policy at
       ``cmdp.initial_state``.
    3. Updates the policy with the multiplicative-weights step ``pi_player``.

    Every policy, including the initial one, is evaluated with
    ``policy_evaluation`` for logging only; these values are not fed back
    into the updates.

    Args:
        dataset: Offline data forwarded unchanged to ``critic``. In this file
            it is the (S, A, S) count array returned by ``generate_data``.
        cmdp: The constrained MDP. Only the first cost (``costs[0]``) and the
            first threshold (``cost_thresholds[0]``) are used.
        initial_pi: Starting policy of shape (num_states, num_actions). If
            None, the uniform policy is used.
        beta: The policy-player scores are divided by ``beta + 1``.
        eta: Step size for ``pi_player``.
        C: Forwarded to ``critic`` as its ``C`` argument.
        lambda_initial: Starting multiplier value; also its lower clip bound.
        lambda_eta: Step size of the multiplier update.
        lambda_max: Upper clip bound for the multiplier.
        K: Number of logged policies (the initial policy plus ``K - 1``
            updates). Must be at least 1.

    Returns:
        A dict of equal-length lists:
        - "reward_value": reward value at ``cmdp.initial_state``, as returned
          by ``policy_evaluation``.
        - "cost_value": the matching first-cost value.
        - "lambda": the multiplier that produced each policy. The first entry
          is ``lambda_initial``; earlier versions logged ``None`` here.

    Raises:
        ValueError: If ``K < 1``.
    """
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}")
    if initial_pi is None:
        initial_pi = np.full(
            (cmdp.num_states, cmdp.num_actions), 1.0 / cmdp.num_actions
        )

    cost = cmdp.costs[0, :, :]
    cost_lim = cmdp.cost_thresholds[0]
    s0 = cmdp.initial_state

    log = {
        "reward_value": [],
        "cost_value": [],
        "lambda": [],
    }

    def _record(policy, lam):
        # Evaluate `policy` in the true CMDP purely for logging.
        v_r, _, v_c, _ = policy_evaluation(cmdp, policy)
        log["reward_value"].append(v_r[s0])
        log["cost_value"].append(v_c[0, s0])
        log["lambda"].append(lam)

    pi = initial_pi
    _record(pi, lambda_initial)

    lambda_ = lambda_initial
    for _ in tqdm(range(K - 1)):
        # Critics for the current policy.
        f, _, _ = critic(cmdp, cmdp.reward, pi, dataset, sign=1, C=C)
        g, _, _ = critic(cmdp, cost, pi, dataset, sign=-1, C=C)

        # Only the amount by which the cost critic exceeds the threshold
        # contributes to the penalty.
        g_plus = np.clip(g - cost_lim, 0, None)
        g_plus0 = np.sum(g_plus[s0, :] * pi[s0, :])

        # Projected multiplier step. g_plus0 >= 0, so lambda never decreases.
        lambda_ = np.clip(
            lambda_ + g_plus0 * lambda_eta, lambda_initial, lambda_max
        )

        # Policy player: exponentiated-gradient step on the Lagrangian scores.
        z = (f - lambda_ * g_plus) * (1 - cmdp.gamma) / (beta + 1)
        pi = pi_player(pi, z, eta)

        _record(pi, lambda_)

    return log

if __name__ == '__main__':
    num_states = 10
    num_next_states = 2
    num_actions = 5
    gamma = 0.8
    n = 2000
    seed = 12354
    # NOTE(review): the CMDP is built before np.random is seeded below. If
    # generate_toy_cmdp draws from np.random, its output is not controlled by
    # `seed`; confirm whether that is intended.
    cmdp = generate_toy_cmdp(num_states, num_actions, gamma, num_next_states)
    if seed is not None:
        np.random.seed(seed)
    else:
        np.random.seed()

    pi_star = solve_cmdp(cmdp)
    pi_uniform = np.ones((num_states, num_actions)) / num_actions

    # Behavior policy: a 50/50 mixture of the CMDP solution and the uniform
    # policy. The offline dataset is sampled from its occupancy measure.
    pi_mu = policy_mixture(pi_star, pi_uniform, alpha=0.5)
    d_mu = compute_occupancy_measure(cmdp, pi_mu)
    dataset = generate_data(cmdp, d_mu, n)

    # run_wsac_tabular needs a starting policy; start from uniform.
    log = run_wsac_tabular(dataset, cmdp, pi_uniform)
    print(
        "final reward value:", log["reward_value"][-1],
        "final cost value:", log["cost_value"][-1],
    )