import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
#os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gym
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import Episode, ReplayBuffer
import logger
torch.backends.cudnn.benchmark = True
__CONFIG__, __LOGS__ = 'cfgs', 'logs'

#from dreamfusion.guidance.sd_utils import StableDiffusion
import clip
from PIL import Image

import open_clip

def set_seed(seed):
    """Seed every RNG used in training with the same ``seed``.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator and the generators of all CUDA devices, in that order.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Evaluate a trained agent and optionally save a video.

    Runs ``num_episodes`` eval-mode rollouts and sums the reward returned by
    ``env.step`` in each one (the environment's own reward, not the CLIP score
    used for training).

    When ``video`` is truthy, it is initialised at the start of every episode
    with ``enabled`` set only for the first episode. Every step is passed to
    ``video.record``, and ``video.save(env_step)`` is called after each episode.

    Returns the NaN-ignoring mean of the per-episode returns.
    """
    returns = []
    for episode_no in range(num_episodes):
        obs = env.reset()
        done, total, t = False, 0, 0
        if video:
            video.init(env, enabled=episode_no == 0)
        while not done:
            # t0 tells the planner this is the first step of the episode.
            plan = agent.plan(obs, eval_mode=True, step=step, t0=(t == 0))
            obs, step_reward, done, _ = env.step(plan.cpu().numpy())
            total += step_reward
            if video:
                video.record(env)
            t += 1
        returns.append(total)
        if video:
            video.save(env_step)
    return np.nanmean(returns)


# Default open_clip checkpoint used to score rendered frames against the text
# prompt; a run may override it with an optional ``clip_model`` config entry.
_CLIP_MODEL_NAME = 'hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'


def _load_clip(cfg, device):
    """Load the open_clip model and embed ``cfg.text_prompt``.

    Returns ``(model, preprocess, prompt_features)``:
    - ``model`` is in eval mode on ``device``.
    - ``preprocess`` is the model's validation image transform.
    - ``prompt_features`` is the L2-normalised text embedding of the prompt.
    """
    model_name = getattr(cfg, 'clip_model', None) or _CLIP_MODEL_NAME
    model, _, preprocess = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval().to(device)
    with torch.no_grad():
        prompt_tokens = tokenizer([cfg.text_prompt]).to(device)
        prompt_features = model.encode_text(prompt_tokens).float()
        prompt_features /= prompt_features.norm(dim=-1, keepdim=True)
    return model, preprocess, prompt_features


def _clip_reward(frame, model, preprocess, prompt_features, device):
    """Return the CLIP cosine similarity between one rendered frame and the prompt.

    ``frame`` is whatever ``env.render`` returns; it is wrapped with
    ``PIL.Image.fromarray`` before preprocessing. Both embeddings are
    L2-normalised, so their dot product is the cosine similarity.

    The score is returned as a Python float so that episode bookkeeping and
    logging do not accumulate CUDA tensors.
    """
    image = preprocess(Image.fromarray(frame)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return (prompt_features @ image_features.T).item()


def train(cfg):
    """Training script for TD-MPC with a CLIP reward. Requires a CUDA-enabled device.

    Every environment step is rendered and scored against ``cfg.text_prompt``
    with an open_clip model. That score is the reward stored in the episode and
    replay buffer. The environment's own reward is stored alongside it as
    ``gt_reward`` and logged as ``episode_gt_reward``.

    Periodic evaluation (``evaluate``) reports the environment reward.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not torch.cuda.is_available():
        raise RuntimeError('TD-MPC training requires a CUDA-enabled device')
    set_seed(cfg.seed)
    work_dir = Path.cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)
    env, agent, buffer = make_env(cfg), TDMPC(cfg), ReplayBuffer(cfg)
    device = torch.device('cuda')

    # Render settings for the reward frames. Quadruped tasks are viewed from
    # camera 2, everything else from camera 0. Taking only the first element
    # avoids an unpacking error for task names without a separator.
    domain = cfg.task.replace('-', '_').split('_', 1)[0]
    camera_id = dict(quadruped=2).get(domain, 0)
    render_kwargs = dict(height=480, width=480, camera_id=camera_id)

    clip_model, clip_preprocess, prompt_features = _load_clip(cfg, device)

    # Run training
    L = logger.Logger(work_dir, cfg)
    episode_idx, start_time = 0, time.time()
    for step in range(0, cfg.train_steps + cfg.episode_length, cfg.episode_length):

        # Collect a trajectory, rewarding each step by its CLIP similarity.
        obs = env.reset()
        episode = Episode(cfg, obs)
        while not episode.done:
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, gt_reward, done, _ = env.step(action.cpu().numpy())
            reward = _clip_reward(env.render(**render_kwargs), clip_model,
                                  clip_preprocess, prompt_features, device)
            episode += (obs, action, reward, gt_reward, done)
        assert len(episode) == cfg.episode_length
        buffer += episode

        # Update model. The first update round after seeding performs
        # cfg.seed_steps updates; later rounds do one per collected step.
        train_metrics = {}
        if step >= cfg.seed_steps:
            num_updates = cfg.seed_steps if step == cfg.seed_steps else cfg.episode_length
            for i in range(num_updates):
                train_metrics.update(agent.update(buffer, step + i))

        # Log training episode
        episode_idx += 1
        env_step = int(step * cfg.action_repeat)
        common_metrics = {
            'episode': episode_idx,
            'step': step,
            'env_step': env_step,
            'total_time': time.time() - start_time,
            'episode_reward': episode.cumulative_reward,
            'episode_gt_reward': episode.cumulative_gt_reward}
        train_metrics.update(common_metrics)
        L.log(train_metrics, category='train', agent=agent)

        # Evaluate agent periodically
        if env_step % cfg.eval_freq == 0:
            common_metrics['episode_reward'] = evaluate(env, agent, cfg.eval_episodes, step, env_step, L.video)
            L.log(common_metrics, category='eval', agent=agent)

    L.finish(agent)
    print('Training completed successfully')


if __name__ == '__main__':
    # Parse the experiment config from ./cfgs (relative to the working
    # directory) and launch training with it.
    config = parse_cfg(Path.cwd() / __CONFIG__)
    train(config)
