import argparse
from datetime import datetime
import gym
import numpy as np
import torch
import pickle
import random
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE
import os
import wandb
from tqdm import trange
from diffuser.utils.arrays import to_np
from diffuser.models.diffusion import GaussianInvDynDiffusion
from diffuser.models.temporal import TemporalUnet, AttTemporalUnet, TransformerNoise
from diffuser.models.encoder_transformer import EncoderTransformer
from diffuser.training.prefdiffuser_trainer import PrefDiffuserTrainer
import warnings
warnings.simplefilter(action='ignore', category=DeprecationWarning)
from tensorboardX import SummaryWriter

# Command-line interface. Every option ends up in `variant`, a plain dict
# that the rest of the script reads its configuration from.
parser = argparse.ArgumentParser()
parser.add_argument('--env_name', type=str, default='hopper-medium-expert')
# K is passed to the models as the diffusion horizon and the encoder max_length.
parser.add_argument('--K', type=int, default=100)
parser.add_argument('--pct_traj', type=float, default=1.0)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--learning_rate', '-lr', type=float, default=2e-4)
parser.add_argument('--seed', type=int, default=3333)
parser.add_argument('--max_iters', type=int, default=5000)
parser.add_argument('--z_dim', type=int, default=16)
parser.add_argument('--num_steps_per_iter', type=int, default=100)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--condition_guidance_w', type=float, default=1.5)
parser.add_argument('--n_timesteps', type=int, default=200)
parser.add_argument('--repre_type', type=str, choices=['vec', 'dist', 'vq_vec'], default='dist')
# Used to build the path of the pretrained encoder checkpoint to load.
parser.add_argument('--encoder_code', type=int, required=True)
# The script looks up variant.get('dirpath', '.') to locate the dataset and
# checkpoints; exposing the flag makes that root configurable. The default
# keeps the previous behaviour (current working directory).
parser.add_argument('--dirpath', type=str, default='.',
                    help='root directory containing data/ and saved_models/')
args = parser.parse_args()
variant = vars(args)

def seed(seed: int = 0):
    """Seed NumPy, PyTorch (CPU and every CUDA device) and the stdlib RNG.

    Args:
        seed: value handed unchanged to each random number generator.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
seed(variant['seed'])  # seed every RNG before the environment and dataset are touched

device = variant.get('device', 'cuda')
env_name = variant['env_name']
env = gym.make(f'{env_name}-v2')
env.reset(seed=variant['seed'])
max_ep_len = 1000  # upper bound on steps per episode; also passed to the encoder
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# load dataset
# Root directory for data and checkpoints; falls back to '.' when `variant`
# has no 'dirpath' entry.
dir_path = variant.get('dirpath', '.')
dataset_path = f'{dir_path}/data/{env_name}-v2.pkl'
# NOTE(review): pickle.load can execute arbitrary code; only load dataset
# files from a trusted source.
with open(dataset_path, 'rb') as f:
    trajectories = pickle.load(f)

# save all path information into separate lists
# Each trajectory is indexed by the keys 'observations' and 'rewards' here
# (and 'actions' in get_batch).
states, traj_lens, returns = [], [], []
for path in trajectories:
    states.append(path['observations'])
    traj_lens.append(len(path['observations']))
    returns.append(path['rewards'].sum())  # undiscounted return of the trajectory
traj_lens, returns = np.array(traj_lens), np.array(returns)

# used for input normalization
# All observations are stacked along axis 0 -- presumably giving a
# (total_timesteps, state_dim) array; TODO confirm against the dataset.
states = np.concatenate(states, axis=0)
# The small epsilon keeps the later division by state_std away from zero.
state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

num_timesteps = sum(traj_lens)

print('=' * 50)
print(f'Starting new experiment: {env_name}')
print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
print('=' * 50)

K = variant['K']  # sequence length sampled per batch element
batch_size = variant['batch_size']
pct_traj = variant.get('pct_traj', 1.)  # fraction of timesteps to keep (1.0 = everything)

z_dim = variant['z_dim']  # output size of the encoder
print(f'z_dim is: {z_dim}')

# D4RL reference scores, used to normalise evaluation returns.
expert_score = REF_MAX_SCORE[f"{variant['env_name']}-v2"]
random_score = REF_MIN_SCORE[f"{variant['env_name']}-v2"]
print(f"max score is: {expert_score}, min score is {random_score}")

# only train on top pct_traj trajectories (for %BC experiment)
num_timesteps = max(int(pct_traj*num_timesteps), 1)  # timestep budget after applying pct_traj
sorted_inds = np.argsort(returns)  # trajectory indices ordered from lowest to highest return
# Always keep the best trajectory, then walk downwards through the ranking
# and keep adding trajectories while the timestep budget is not exceeded.
num_trajectories = 1
timesteps = traj_lens[sorted_inds[-1]]  # timesteps accumulated so far
ind = len(trajectories) - 2
while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
    timesteps += traj_lens[sorted_inds[ind]]
    num_trajectories += 1
    ind -= 1
# Keep only the indices of the selected (highest-return) trajectories.
sorted_inds = sorted_inds[-num_trajectories:]

def get_batch(batch_size=256, max_len=K):
    """Sample a batch of fixed-length sub-sequences from the retained trajectories.

    Trajectories are drawn uniformly, with replacement, from the
    `num_trajectories` entries of `sorted_inds`. For each one a random start
    index is chosen and up to `max_len` consecutive steps are taken; windows
    that run past the end of the trajectory are padded at the end.

    Args:
        batch_size: number of sub-sequences to sample.
        max_len: length every sub-sequence is padded to (defaults to the
            module-level K at definition time).

    Returns:
        Tuple `(s, a, timesteps, mask)` of tensors on `device`:
            s: normalised observations, float32, (batch_size, max_len, state_dim).
               Padding rows are zeros that go through the same normalisation.
            a: actions, float32, (batch_size, max_len, act_dim); padding is -10.
            timesteps: step indices within the trajectory, long,
               (batch_size, max_len); padding is 0.
            mask: 1 for real steps and 0 for padding, (batch_size, max_len).
    """
    picks = np.random.choice(np.arange(num_trajectories), size=batch_size, replace=True)
    obs_rows, act_rows, step_rows, mask_rows = [], [], [], []
    for pick in picks:
        traj = trajectories[int(sorted_inds[pick])]
        # Any step of the trajectory may start the window (bounds inclusive).
        start = random.randint(0, traj['rewards'].shape[0] - 1)
        stop = start + max_len

        obs = traj['observations'][start:stop].reshape(1, -1, state_dim)
        act = traj['actions'][start:stop].reshape(1, -1, act_dim)
        real_len = obs.shape[1]  # may be shorter than max_len near the end
        pad_len = max_len - real_len

        # Pad first, normalise afterwards.
        obs = np.concatenate([obs, np.zeros((1, pad_len, state_dim))], axis=1)
        obs_rows.append((obs - state_mean) / state_std)

        act_pad = np.ones((1, pad_len, act_dim)) * -10.
        act_rows.append(np.concatenate([act, act_pad], axis=1))

        steps = np.arange(start, start + real_len).reshape(1, -1)
        step_rows.append(np.concatenate([steps, np.zeros((1, pad_len))], axis=1))

        mask_rows.append(
            np.concatenate([np.ones((1, real_len)), np.zeros((1, pad_len))], axis=1)
        )

    def _stack(rows):
        # Merge the per-sample (1, ...) arrays along the batch axis.
        return torch.from_numpy(np.concatenate(rows, axis=0))

    s = _stack(obs_rows).to(dtype=torch.float32, device=device)
    a = _stack(act_rows).to(dtype=torch.float32, device=device)
    timesteps = _stack(step_rows).to(dtype=torch.long, device=device)
    mask = _stack(mask_rows).to(device=device)
    return s, a, timesteps, mask

def eval_episodes(model, phi, state_mean, state_std, num_envs=10):
    """Roll out the diffusion policy in parallel environments and summarise returns.

    At each step the current normalised observations are used as the
    condition for `model.dpm_sample`; the first two states of every sampled
    plan are concatenated and fed to `model.inv_model` to obtain the action.
    Only the first episode of each environment is scored: once an
    environment reports `done` it stops contributing, even though the vector
    env keeps stepping it.

    Written against the Gym API in which `reset()` returns the observations
    only and `step()` returns a 4-tuple.

    Args:
        model: diffusion model exposing `dpm_sample` and `inv_model`.
        phi: conditioning tensor; `phi[0]` is repeated for every environment
            and passed as `returns`.
        state_mean: NumPy array subtracted from raw observations.
        state_std: NumPy array the centred observations are divided by.
        num_envs: number of parallel environments (one episode each).

    Returns:
        Dict with mean/std of the raw return, the normalised return
        (relative to `random_score` and `expert_score`, in percent) and the
        episode length over the `num_envs` episodes.
    """
    envs = gym.vector.make(f'{env_name}-v2', num_envs=num_envs)
    try:
        states = envs.reset()
        # Cast to float32 so the conditions match the dtype used for training inputs.
        s_mean = torch.from_numpy(state_mean).to(device=device, dtype=torch.float32)
        s_std = torch.from_numpy(state_std).to(device=device, dtype=torch.float32)
        # NumPy accumulators: `list += ndarray` would extend the list rather
        # than add element-wise.
        episode_returns = np.zeros(num_envs, dtype=np.float64)
        episode_lengths = np.zeros(num_envs, dtype=np.int64)
        finished = np.zeros(num_envs, dtype=bool)
        phi = phi[0].unsqueeze(0).repeat(num_envs, 1)
        for _ in trange(max_ep_len, desc='evaluation'):
            states = torch.from_numpy(states).to(device=device, dtype=torch.float32)
            conditions = (states - s_mean) / s_std
            # Alternative sampler:
            # samples = model.conditional_sample(conditions, returns=phi)
            with torch.no_grad():
                samples = model.dpm_sample(conditions, returns=phi)
                # [s0, s1]: current and next planned state for each environment.
                obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
                obs_comb = obs_comb.reshape(-1, 2 * state_dim)
                # The inverse dynamics model maps (s0, s1) to the action.
                actions = model.inv_model(obs_comb)
            actions = to_np(actions)
            states, rewards, dones, _ = envs.step(actions)

            # Credit this step (including a terminal one) only to episodes
            # that were still running before it was taken.
            active = ~finished
            episode_returns += rewards * active
            episode_lengths += active
            finished |= np.asarray(dones, dtype=bool)

            if finished.all():
                break
    finally:
        # Release the vectorised environments and any resources they hold.
        envs.close()

    norm_ret = (episode_returns - random_score) / (expert_score - random_score) * 100
    return {
        'target_return_mean': np.mean(episode_returns),
        'target_return_std': np.std(episode_returns),
        'target_norm_return_mean': np.mean(norm_ret),
        'target_norm_return_std': np.std(norm_ret),
        'target_length_mean': np.mean(episode_lengths),
        'target_length_std': np.std(episode_lengths),
    }

# Alternative noise predictors, kept for reference:
# noise_predictor = TemporalUnet(horizon=K, transition_dim=state_dim, cond_dim=state_dim)
# noise_predictor = AttTemporalUnet(horizon=K, transition_dim=state_dim)
noise_predictor = TransformerNoise(horizon=K, obs_dim=state_dim)
model = GaussianInvDynDiffusion(noise_predictor, horizon=K, observation_dim=state_dim, action_dim=act_dim,
                                condition_guidance_w=variant['condition_guidance_w'],
                                n_timesteps=variant['n_timesteps']).to(device=device)
# Only the diffusion model's parameters are given to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=variant['learning_rate'])

en_model = EncoderTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    hidden_size=128,
    output_size=z_dim,
    max_length=K,
    max_ep_len=max_ep_len,
    num_hidden_layers=3,
    num_attention_heads=2,
    intermediate_size=4*128,
    max_position_embeddings=1024,
    repre_type=variant['repre_type'],
    hidden_act='relu',
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
en_model = en_model.to(device=device)
repre_type = variant['repre_type']
encoder_code = variant['encoder_code']
# Pretrained encoder checkpoint; the '3333' and 'params_100.pt' parts of the
# path are fixed here rather than taken from the command line.
encoder_path = f'saved_models/encoder_{repre_type}/{env_name}-3333-{encoder_code}/params_100.pt'
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
saved_model = torch.load(os.path.join(dir_path, encoder_path), map_location=device)
# The checkpoint is indexed as a pair: [0] is the encoder state dict, [1] is
# kept as `w`, which is only referenced by the disabled evaluation call below.
en_model.load_state_dict(saved_model[0])
w = saved_model[1]

trainer = PrefDiffuserTrainer(
    en_model=en_model,
    de_model=model,
    optimizer=optimizer,
    batch_size=batch_size,
    get_batch=get_batch,
    device=device,
)

# Timestamped run name, shared by the log directory and the checkpoint folder.
t = datetime.now().strftime('%Y%m%d%H%M%S')
name = f"{variant['env_name']}-{variant['seed']}-{t}"
writer = SummaryWriter(f'./logs/train_diffusion-{repre_type}-{name}')

folder = f"{dir_path}/saved_models/diffusion_model_dist/{name}"
# makedirs also creates missing parent directories, which os.mkdir would
# fail on during a first run.
os.makedirs(folder, exist_ok=True)

try:
    for iteration in trange(variant['max_iters'], desc='epoch'):
        # Checkpoint the diffusion model every 50 iterations (including iteration 0).
        if iteration % 50 == 0:
            torch.save(model.state_dict(), f"{folder}/diffusion_model_{iteration}.pt")
            # Periodic evaluation, currently disabled:
            # trainer.de_model.eval()
            # logs = dict()
            # outputs = eval_episodes(trainer.de_model, w, state_mean, state_std)
            # for key, values in outputs.items():
            #     writer.add_scalar(f'evaluation/{key}', values, global_step=iteration)

        outputs = trainer.train_iteration(
            num_steps=variant['num_steps_per_iter'],
            iter_num=iteration + 1,
            print_logs=True,
        )
        for key, values in outputs.items():
            writer.add_scalar(key, values, global_step=iteration)

    # Final checkpoint after the last iteration.
    torch.save(model.state_dict(), f"{folder}/diffusion_model_{variant['max_iters']}.pt")
finally:
    # Flush pending scalars and close the event file even if training is interrupted.
    writer.close()