import argparse
from datetime import datetime
import gym
import numpy as np
import torch
import pickle
import random
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE
import os
import wandb
from tqdm import trange
from diffuser.utils.arrays import to_np
from diffuser.models.diffusion import GaussianInvDynDiffusion
from diffuser.models.temporal import TemporalUnet, AttTemporalUnet, TransformerNoise
from diffuser.models.encoder_transformer import EncoderTransformer
from diffuser.models.state_encoder_transformer import StateEncoderTransformer
from diffuser.training.all_trainer import AllTrainer
import warnings
warnings.simplefilter(action='ignore', category=DeprecationWarning)
from tensorboardX import SummaryWriter
import collections
from env import get_envs
from typing import Callable
from data.d4rl import get_dataset
import metaworld
from data.sequence import SequenceDataset

# Command-line options. Each entry is (flag names, add_argument keyword
# arguments). Entries are registered in list order, which is also the order
# they appear in --help.
_CLI_OPTIONS = [
    (('--env_name',), dict(type=str, default='reach-v2')),
    (('--K',), dict(type=int, default=100)),
    (('--batch_size',), dict(type=int, default=32)),
    (('--learning_rate', '-lr'), dict(type=float, default=2e-4)),
    (('--seed',), dict(type=int, default=100)),
    (('--max_iters',), dict(type=int, default=2000)),
    (('--z_dim',), dict(type=int, default=16)),
    (('--num_steps_per_iter',), dict(type=int, default=1000)),
    (('--device',), dict(type=str, default='cuda')),
    (('--condition_guidance_w',), dict(type=float, default=1.2)),
    (('--n_timesteps',), dict(type=int, default=200)),
    (('--repre_type',), dict(type=str, choices=['vec', 'dist', 'none'], default='dist')),
    (('--phi_norm_loss_ratio',), dict(type=float, default=0.1)),
    (('--w_lr',), dict(type=float, default=0.01)),
    (('--predict_x',), dict(action='store_true', default=False)),
    (('--use_transformer',), dict(action='store_true', default=False)),
    (('--info_loss_weight',), dict(type=float, default=0.1)),
]

parser = argparse.ArgumentParser()
for _flags, _kwargs in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_kwargs)
args = parser.parse_args()
# Plain-dict view of the parsed options; the rest of the script reads settings from it.
variant = vars(args)

def seed(seed: int = 0):
    """Seed NumPy, PyTorch (CPU and all CUDA devices) and Python's ``random``.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
seed(variant['seed']) # seed every RNG from --seed (CLI default: 100)

def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it is exhausted.

    Args:
        dl: Any re-iterable object (used in this file with a DataLoader).

    Yields:
        The items of ``dl``, pass after pass, without end.
    """
    while True:
        yield from dl

# Device string from --device; the 'cuda' fallback mirrors the argparse default.
device = variant.get('device', 'cuda')
mt10 = metaworld.MT10()
# Instantiate the environment class selected by --env_name and bind one of its
# MT10 training tasks, chosen at random among the tasks with a matching env_name.
env = mt10.train_classes[variant['env_name']]()
task = random.choice([task for task in mt10.train_tasks if task.env_name == variant['env_name']])
env.set_task(task)
# Passed to the encoder below as max_ep_len; also matches max_traj_length=500
# used when loading the datasets.
max_ep_len = 500
# state_dim and act_dim are read from the --env_name environment and reused for
# every task; the original author notes they are the same for all MT-10 tasks, (39, 4).
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Number of evaluation environments created per MT10 task.
n_eval = 5
# eval_envs is task-major: n_eval consecutive environments for eval_names[0],
# then n_eval for eval_names[1], and so on. eval_names holds one entry per task.
eval_envs, eval_names = [], []
# NOTE: this loop rebinds the module-level `env` and `task`; once it finishes
# they refer to the last evaluation environment/task, not the --env_name one.
for name, env_cls in mt10.train_classes.items():
    eval_names.append(name)
    for i in range(n_eval):
        env = env_cls()
        # Each evaluation env gets its own randomly drawn training task of this class.
        task = random.choice([task for task in mt10.train_tasks if task.env_name == name])
        env.set_task(task)
        eval_envs.append(env)

# One sequence dataset and one endless batch stream per task; both lists follow
# the order of eval_names, so index j refers to task eval_names[j].
dataset, dataloader = [], []
for task_name in eval_names:
    trajectories = get_dataset(
        task_name,
        max_traj_length=500,
        include_next_obs=False,
        termination_penalty=0,
    )
    task_dataset = SequenceDataset(
        trajectories,
        horizon=variant['K'],
        max_traj_length=500,
        include_returns=True,
        use_padding=True,
    )
    dataset.append(task_dataset)

    loader = torch.utils.data.DataLoader(
        task_dataset,
        sampler=torch.utils.data.RandomSampler(task_dataset),
        batch_size=variant['batch_size'],
        drop_last=True,
        num_workers=8,
    )
    # cycle() turns the finite loader into a never-ending generator of batches.
    dataloader.append(cycle(loader))

# Short aliases for the CLI settings used most often below
# (--K, --batch_size, --z_dim, --repre_type).
K, batch_size, z_dim, repre_type = (
    variant[key] for key in ('K', 'batch_size', 'z_dim', 'repre_type')
)
print(f'z_dim is: {z_dim}')
# Discount weights 0.99**t for t = 0..K-1 (the slice caps this at 1000 entries).
discounts = np.power(0.99, np.arange(1000)[:K])

def eval_episodes(envs, model, phi):
    """Roll out ``model`` in every evaluation environment and report per-task metrics.

    ``envs`` must be laid out task-major: ``n_eval`` consecutive environments
    for each entry of the module-level ``eval_names`` (this is how
    ``eval_envs`` is built in this file). The function also reads the
    module-level ``dataset`` (for each task's normalizer), ``device`` and
    ``state_dim``.

    Args:
        envs: Environments to evaluate, grouped by task as described above.
        model: Model exposing ``conditional_sample(conditions, returns=...)``
            and ``inv_model``.
        phi: Conditioning tensor with one row per task, in ``eval_names``
            order (the training loop passes ``w``).

    Returns:
        A dict with, for every task name in ``eval_names``, the keys
        ``'<name>-reward'`` (episode return averaged over the task's
        ``n_eval`` environments) and ``'<name>-success'`` (fraction of those
        environments that reported ``info['success']`` at least once).
    """
    num_envs = len(envs)
    # Environment i belongs to task i // n_eval because of the task-major layout.
    task_of = [i // n_eval for i in range(num_envs)]

    episode_return = [0.] * num_envs
    success = [0.] * num_envs
    dones = [False] * num_envs

    # One conditioning row per environment, in the same task-major order as
    # `envs`. `repeat` would tile the rows as [t0..tN, t0..tN, ...] and pair
    # environments with the wrong task; `repeat_interleave` keeps each task's
    # row next to its own n_eval environments.
    phi = phi.repeat_interleave(n_eval, dim=0)

    # Normalize the reset observation exactly like the observations returned
    # by step() below, so the model is conditioned on the same scale at every step.
    obs_list = [
        dataset[task_of[i]].normalizer.normalize(single_env.reset()[0], 'observations')
        for i, single_env in enumerate(envs)
    ]

    # Use the horizon of the environments being evaluated instead of whatever
    # the module-level `env` happens to point at.
    for _ in trange(envs[0].max_path_length, desc='eval'):
        observation = np.array(obs_list)
        conditions = torch.from_numpy(observation).to(device=device, dtype=torch.float32)
        with torch.no_grad():
            # Condition on (s_t, R).
            samples = model.conditional_sample(conditions, returns=phi)
            # Inverse dynamics on the first two planned states [s0, s1].
            obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
            obs_comb = obs_comb.reshape(-1, 2 * state_dim)
            action = model.inv_model(obs_comb)
        action = to_np(action)

        for i, single_env in enumerate(envs):
            task_idx = task_of[i]
            if dones[i]:
                print(f'Env {eval_names[task_idx]} Done.')
                continue
            normalizer = dataset[task_idx].normalizer
            # The model works in normalized space; the simulator needs raw actions.
            action[i] = normalizer.unnormalize(action[i], 'actions')
            next_obs, reward, done, _, info = single_env.step(action[i])
            dones[i] = done
            episode_return[i] += reward
            obs_list[i] = normalizer.normalize(next_obs, 'observations')
            if info['success']:
                success[i] = 1.

    results = {}
    for task_idx, task_name in enumerate(eval_names):
        task_slice = slice(task_idx * n_eval, (task_idx + 1) * n_eval)
        results[f'{task_name}-reward'] = np.array(episode_return[task_slice]).sum() / n_eval
        results[f'{task_name}-success'] = np.array(success[task_slice]).sum() / n_eval
    print(results)
    return results

# Denoising backbone: transformer when --use_transformer is given, temporal U-Net otherwise.
noise_predictor = (
    TransformerNoise(horizon=K, obs_dim=state_dim, z_dim=z_dim)
    if variant['use_transformer']
    else TemporalUnet(horizon=K, transition_dim=state_dim, cond_dim=state_dim)
)

# Diffusion model wrapping the backbone; --predict_x switches off epsilon prediction.
diffusion_kwargs = dict(
    horizon=K,
    observation_dim=state_dim,
    action_dim=act_dim,
    condition_guidance_w=variant['condition_guidance_w'],
    n_timesteps=variant['n_timesteps'],
    predict_epsilon=not variant['predict_x'],
    info_loss_weight=variant['info_loss_weight'],
    repre_type=variant['repre_type'],
)
model = GaussianInvDynDiffusion(noise_predictor, **diffusion_kwargs).to(device=device)
optimizer = torch.optim.AdamW(model.parameters(), lr=variant['learning_rate'])

# Encoder that maps trajectories of up to K steps to z_dim-sized outputs,
# trained with its own AdamW optimizer.
encoder_config = {
    'state_dim': state_dim,
    'act_dim': act_dim,
    'hidden_size': 128,
    'output_size': z_dim,
    'max_length': K,
    'max_ep_len': max_ep_len,
    'num_hidden_layers': 3,
    'num_attention_heads': 2,
    'intermediate_size': 4 * 128,
    'max_position_embeddings': 1024,
    'repre_type': variant['repre_type'],
    'hidden_act': 'relu',
    'hidden_dropout_prob': 0.1,
    'attention_probs_dropout_prob': 0.1,
}
en_model = StateEncoderTransformer(**encoder_config).to(device=device)
et_optimizer = torch.optim.AdamW(
    en_model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

# Learnable per-task tensors, one row of size z_dim per task in eval_names.
# They are sampled on the CPU and then moved to `device`, with `w` drawn
# before `w_std`. (Presumably a mean/std pair for the task latent; confirm
# against AllTrainer.)
latent_shape = (len(eval_names), z_dim)
w = torch.randn(latent_shape).to(device=device).requires_grad_(True)
w_std = torch.randn(latent_shape).to(device=device).requires_grad_(True)
w_optimizer = torch.optim.AdamW(
    [w, w_std],
    lr=variant["w_lr"],
    weight_decay=1e-4,
)

# Joint trainer for the encoder, the diffusion model and the per-task tensors.
# `get_batch` receives the list of endless per-task batch generators.
trainer_config = {
    'en_model': en_model,
    'de_model': model,
    'optimizer': optimizer,
    'batch_size': batch_size,
    'get_batch': dataloader,
    'device': device,
    'et_optimizer': et_optimizer,
    'w': w,
    'w_std': w_std,
    'w_optimizer': w_optimizer,
    'repre_type': variant['repre_type'],
    'phi_norm_loss_ratio': variant["phi_norm_loss_ratio"],
    'info_loss_weight': variant['info_loss_weight'],
}
trainer = AllTrainer(**trainer_config)

# Run identifier: <env_name>-<seed>-<timestamp>, shared by the TensorBoard log
# directory and the checkpoint folder.
t = datetime.now().strftime('%Y%m%d%H%M%S')
name = f"{variant['env_name']}-{variant['seed']}-{t}"
net_name = 'tfm' if variant['use_transformer'] else 'unet'
info_loss_weight = variant['info_loss_weight']
condition_w = variant['condition_guidance_w']
supfix = f'{info_loss_weight}info-{condition_w}guide'
writer = SummaryWriter(f'./logs/{repre_type}-{net_name}-{name}-{supfix}-normalize')

folder = f"./saved_models/all_model_{repre_type}/{name}"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(folder, exist_ok=True)
# Snapshot of the freshly initialized diffusion model, encoder and task
# tensor `w`, taken before any training step. `folder` already starts with './'.
torch.save((model.state_dict(), en_model.state_dict(), w), f"{folder}/params_0.pt")

# Evaluate the EMA model every EVAL_INTERVAL outer iterations (never at
# iteration 0). A checkpoint is written when the iteration is also a multiple
# of CHECKPOINT_INTERVAL.
EVAL_INTERVAL = 100
CHECKPOINT_INTERVAL = 500

# `epoch` is used instead of `iter` so the builtin iter() is not shadowed.
for epoch in trange(variant['max_iters'], desc='epoch'):
    if epoch != 0 and epoch % EVAL_INTERVAL == 0:
        if epoch % CHECKPOINT_INTERVAL == 0:
            torch.save(
                (trainer.ema_model.state_dict(), en_model.state_dict(), w),
                f"{folder}/diffusion_model_{epoch}.pt",
            )
        trainer.ema_model.eval()
        outputs = eval_episodes(eval_envs, trainer.ema_model, w)
        for key, values in outputs.items():
            writer.add_scalar(f'evaluation/{key}', values, global_step=epoch)

    # One outer iteration runs --num_steps_per_iter training steps; the trainer
    # receives a 1-based iteration number.
    outputs = trainer.train_iteration(
        num_steps=variant['num_steps_per_iter'],
        iter_num=epoch + 1,
        print_logs=True,
    )
    for key, values in outputs.items():
        writer.add_scalar(key, values, global_step=epoch)

# Final checkpoint after the last iteration.
torch.save(
    (trainer.ema_model.state_dict(), en_model.state_dict(), w),
    f"{folder}/diffusion_model_{variant['max_iters']}.pt",
)
# Close the writer so that any buffered scalars are flushed to disk.
writer.close()