
import pickle
import numpy as np
from contextlib import contextmanager
from itertools import permutations, product
from clus.models.peftpool.dual_l2m import DyLoRABookModelOracle



# <<<<<<<<<< multi-stage metaworld >>>>>>>>>> #
def get_task_list_equal_easy(all_task_flag='full'):
    """Enumerate skill orderings for the single "easy" skill set.

    Every permutation of ``['puck', 'drawer', 'button', 'door']`` is indexed in
    ``itertools.permutations`` order; a permutation is kept when its index
    modulo ``period`` is one of the selected residues.

    Args:
        all_task_flag: ``'full'`` keeps every permutation; ``'sixth'`` keeps
            residues ``{0, 10, 13, 23}`` modulo 24; ``'half'``, ``'third'`` and
            any other value keep residues ``{0, 3, 4, 7, 8, 11}`` modulo 12.

    Returns:
        A list of dicts with keys ``'skill_list'`` (the skill set in its
        original order) and ``'skill_seq'`` (one ordering of it).
    """
    skill_sets = [['puck', 'drawer', 'button', 'door']]  # easy_door

    period = 12  # hyper parameter for task boundary
    if all_task_flag == 'full':
        keep = list(range(12))
    elif all_task_flag == 'sixth':
        keep = [0, 10, 13, 23]
        period = 24
    else:
        # 'half', 'third' and unrecognised flags share the same selection.
        keep = [0, 3, 4, 7, 8, 11]

    return [
        {'skill_list': list(skills), 'skill_seq': list(order)}
        for skills in skill_sets
        for idx, order in enumerate(permutations(skills))
        if idx % period in keep
    ]

def get_task_list_equal_normal(all_task_flag='full', only_normal=False):
    """Enumerate skill orderings for the "normal" difficulty skill sets.

    Skill sets are drawn from the product of the four object pairs, discarding
    every combination that contains ``'box'`` or ``'stick'``.  Permutations of
    each remaining set are indexed in ``itertools.permutations`` order and kept
    when the index modulo ``period`` is one of the selected residues.

    Args:
        all_task_flag: ``'full'`` keeps every permutation; ``'sixth'`` keeps
            residues ``{0, 10, 13, 23}`` modulo 24; ``'half'``, ``'third'`` and
            any other value keep residues ``{0, 3, 4, 7, 8, 11}`` modulo 12.
        only_normal: when true, additionally discard the combination that
            contains neither ``'handle'`` nor ``'lever'``.

    Returns:
        A list of dicts with keys ``'skill_list'`` and ``'skill_seq'``.
    """
    print(f"tasks normal\n" * 1)

    skill_sets = []
    for combo in product(('box', 'puck'), ('handle', 'drawer'), ('button', 'lever'), ('door', 'stick')):
        if 'box' in combo or 'stick' in combo:
            continue
        if only_normal and 'handle' not in combo and 'lever' not in combo:
            continue
        skill_sets.append(list(combo))

    period = 12  # hyper parameter for task boundary
    if all_task_flag == 'full':
        keep = list(range(12))
    elif all_task_flag == 'sixth':
        keep = [0, 10, 13, 23]
        period = 24
    else:
        # 'half', 'third' and unrecognised flags share the same selection.
        keep = [0, 3, 4, 7, 8, 11]

    return [
        {'skill_list': list(skills), 'skill_seq': list(order)}
        for skills in skill_sets
        for idx, order in enumerate(permutations(skills))
        if idx % period in keep
    ]

def get_task_list_equal_hard(all_task_flag='full'):
    """Enumerate skill orderings for the "hard" difficulty skill sets.

    All sixteen combinations from the product of the four object pairs are
    used as skill sets.  Permutations of each set are indexed in
    ``itertools.permutations`` order and kept when the index modulo ``period``
    is one of the selected residues.

    Args:
        all_task_flag: ``'full'`` keeps every permutation; ``'sixth'`` keeps
            residues ``{0, 10, 13, 23}`` modulo 24; ``'half'``, ``'third'`` and
            any other value keep residues ``{0, 3, 4, 7, 8, 11}`` modulo 12.

    Returns:
        A list of dicts with keys ``'skill_list'`` and ``'skill_seq'``.
    """
    print(f"tasks hard\n" * 1)

    skill_sets = [
        list(combo)
        for combo in product(('box', 'puck'), ('handle', 'drawer'), ('button', 'lever'), ('door', 'stick'))
    ]

    period = 12  # hyper parameter for task boundary
    if all_task_flag == 'full':
        keep = list(range(12))
    elif all_task_flag == 'sixth':
        keep = [0, 10, 13, 23]
        period = 24
    else:
        # 'half', 'third' and unrecognised flags share the same selection.
        keep = [0, 3, 4, 7, 8, 11]

    return [
        {'skill_list': list(skills), 'skill_seq': list(order)}
        for skills in skill_sets
        for idx, order in enumerate(permutations(skills))
        if idx % period in keep
    ]


def configs_task_list(configs) :
    """Resolve the task names listed in phase configs to task dictionaries.

    The full "hard" task catalogue is generated and each entry is tagged with
    a ``'data_name'`` built by joining its ``'skill_seq'`` with ``'-'``.  Every
    comma-separated name in each phase's ``'data_name'`` string is then looked
    up in that catalogue.

    Args:
        configs: iterable of phase dicts, each with a ``'data_name'`` string of
            comma-separated task names.

    Returns:
        The matching catalogue entries in the order the names appear in
        ``configs``.  Names without a match are silently skipped; a name listed
        more than once yields the same dict object more than once.
    """
    catalog = get_task_list_equal_hard()
    for entry in catalog:
        entry['data_name'] = "-".join(entry['skill_seq'])

    selected = []
    for phase in configs:
        for name in phase['data_name'].split(','):
            # Take the first catalogue entry carrying this name, if any.
            match = next((entry for entry in catalog if entry['data_name'] == name), None)
            if match is not None:
                selected.append(match)

    return selected

from tqdm import tqdm

try: 
    from mmworld.envs.mujoco.sawyer_xyz.v2.sawyer_non_stationary_v2 import SawyerNonStationaryEnvV2
except :
    print("mmworld not installed")
import random as py_rand
import gym
from clus.env.base_evaluator  import BaseEvaluator
import cv2
import matplotlib.pyplot as plt
from PIL import Image

class SingleTask(gym.Env):
    """Gym wrapper around one ``SawyerNonStationaryEnvV2`` instance.

    The wrapper enforces an episode step limit and selects what is returned as
    the observation according to ``obs_type``: the wrapped env's own
    observation (``'sensor'`` and any unrecognised value), a rendered frame
    (``'vision'``), or a dict holding both (``'mixed'``).
    """

    def __init__(self, seed: int, skill_list, obs_type='sensor', max_episode_length=1000, partially_observable=False):
        """Create the wrapped environment and declare the spaces.

        Args:
            seed: seed applied to Python's global ``random`` module.
            skill_list: forwarded to ``SawyerNonStationaryEnvV2``.
            obs_type: ``'sensor'``, ``'vision'`` or ``'mixed'``.
            max_episode_length: step count at which ``step`` forces ``done``.
            partially_observable: copied onto the wrapped env for the
                ``'sensor'`` and ``'vision'`` modes; ``'mixed'`` always uses
                ``True``.
        """
        py_rand.seed(seed)
        self.env = SawyerNonStationaryEnvV2(skill_list)

        self.max_episode_length = max_episode_length
        self.time_steps = 0

        self.obs_type = obs_type
        self.partially_observable = partially_observable

        # Statement order inside each branch is kept as-is: the wrapped env's
        # observation space is read only after its observability flag is set.
        if self.obs_type == 'vision':
            self.observation_space = gym.spaces.Box(
                low=np.zeros((80, 80, 3)),
                high=np.ones((80, 80, 3)),
                dtype=np.uint8,
            )
            self.env._partially_observable = self.partially_observable
        elif self.obs_type == 'mixed':
            self.env._partially_observable = True
            image_space = gym.spaces.Box(
                low=np.zeros((80, 80, 3)),
                high=np.ones((80, 80, 3)),
                dtype=np.uint8,
            )
            self.observation_space = gym.spaces.Dict({'image': image_space,
                                                      'sensor': self.env.observation_space})
        elif self.obs_type == 'sensor':
            self.env._partially_observable = self.partially_observable
            self.observation_space = self.env.observation_space

        self.action_space = self.env.action_space

    def _compose_obs(self, sensor_obs):
        """Build the observation for the configured ``obs_type``."""
        if self.obs_type == 'vision':
            return self.render()
        if self.obs_type == 'mixed':
            return {'image': self.render(), 'sensor': sensor_obs}
        return sensor_obs

    def step(self, action, action_noise=None):
        '''
        Advance the wrapped environment by one step.

        action : normalized action (-1, 1)
        action_noise : action noise; forwarded to the wrapped env only when
            it is not None, and always recorded in ``info['action_noise']``

        Returns ``(obs, reward, done, info)``; ``done`` is forced to True on
        the step where the counter reaches ``max_episode_length``.
        '''
        if action_noise is None:
            sensor_obs, reward, done, info = self.env.step(action)
        else:
            sensor_obs, reward, done, info = self.env.step(action, action_noise)
        self.time_steps += 1

        if self.time_steps == self.max_episode_length:
            done = True

        obs = self._compose_obs(sensor_obs)

        info['action_noise'] = action_noise
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env and the step counter; return the first observation."""
        sensor_obs = self.env.reset()
        self.time_steps = 0
        return self._compose_obs(sensor_obs)

    def render(self, mode='corner3', resolution=(224,224)):
        """Render an offscreen frame; ``mode`` is passed as the camera name."""
        return self.env.render(offscreen=True, resolution=resolution, camera_name=mode)

from metaworld.envs.mujoco.env_dict import ALL_V2_ENVIRONMENTS 
from clus.env.continual_config import *
import metaworld 

class CWEvaluator(BaseEvaluator) :
    """Success-rate evaluator for the Meta-World environments of a phase config.

    One environment is instantiated for every entry of every phase's
    ``'data_paths'``.  The environment name is the second-to-last path
    component (text before its first ``'.'``); it indexes
    ``ALL_V2_ENVIRONMENTS`` and selects tasks from ``metaworld.MT50``.
    """

    def __init__(
            self,
            phase_configures=CW10,
            eval_mode='obs', # obs or traj
            traj_length=10, # used for traj eval mode
            eval_episodes=3,
        ) -> None:
        """Build the environments listed in ``phase_configures``.

        Args:
            phase_configures: iterable of phase dicts with a ``'data_paths'``
                list of path strings.
            eval_mode: ``'obs'`` feeds the current observation to the policy,
                ``'traj'`` feeds a window of the last ``traj_length``
                observations.
            traj_length: window length used in ``'traj'`` mode.
            eval_episodes: number of evaluation episodes per environment.
        """
        print("[Continual World Evaluator]")
        self.phase_configures = phase_configures

        self.MT50 = metaworld.MT50(seed=777)

        self.evaluation_list = []
        self.env_list = []
        for data_dict in self.phase_configures :
            for path in data_dict['data_paths'] :
                env_name = path.split('/')[-2].split('.')[0]
                print("[env setting ...]" , env_name)
                env = ALL_V2_ENVIRONMENTS[env_name]()
                task = [
                    task for task in self.MT50.train_tasks if task.env_name == env_name
                ][0]
                env.set_task(task)
                self.evaluation_list.append(env_name)
                self.env_list.append(env)

        self.eval_horizons = 200

        self.threshold = len(self.evaluation_list)
        self.eval_mode = eval_mode
        self.traj_length = traj_length
        self.eval_episodes = eval_episodes

    def evaluate_base(
            self,
            model,
            eval_fn = None,
        ) :
        """Roll out the policy in every environment and report success rates.

        Args:
            model: object whose ``eval_model`` is used as the policy when
                ``eval_fn`` is None; it is not touched otherwise.
            eval_fn: optional policy callable.  In ``'obs'`` mode it receives
                the stacked observations with an extra axis inserted at
                position 1 and returns either the actions or an
                ``(actions, unique)`` tuple.  In ``'traj'`` mode it receives
                the observation window and must return ``(actions, unique)``.

        Returns:
            ``{'skill_seq': [...], 'skill_rew': [...]}`` where ``skill_seq``
            holds the environment names and ``skill_rew`` the fraction of
            episodes in which ``info['success'] == 1`` was observed.

        Raises:
            ValueError: if ``self.eval_mode`` is neither ``'obs'`` nor ``'traj'``.
            IndexError: if an environment has fewer matching MT50 train tasks
                than ``eval_episodes``.
        """
        eval_episodes = self.eval_episodes
        rew_info = {'skill_seq':[], 'skill_rew' : []}
        eval_fn = model.eval_model if eval_fn is None else eval_fn

        used_unique = []
        unique = None

        # The candidate tasks per environment do not change between episodes.
        tasks_per_env = [
            [task for task in self.MT50.train_tasks if task.env_name == env_name]
            for env_name in self.evaluation_list
        ]

        for eval_seed in range(eval_episodes) :
            # Episode k uses the k-th matching train task of each environment.
            for eid, env in enumerate(self.env_list) :
                env.set_task(tasks_per_env[eid][eval_seed])

            history_obs = None
            obs_list = []
            for env in self.env_list :
                obs, _ = env.reset()
                obs_list.append(obs)
            done_list = [False] * len(self.env_list)

            # Placeholder observation fed to the policy for finished envs.
            dummy_obs = np.zeros_like(obs_list[0])

            for _ in tqdm(range(self.eval_horizons)) :
                obs = np.array(obs_list)

                if self.eval_mode == 'obs' :
                    eval_res = eval_fn(obs[:,None,:])
                    if isinstance(eval_res, tuple) :
                        actions , unique = eval_res
                    else :
                        actions = eval_res
                    actions = np.array(actions) # mmworld action space
                elif self.eval_mode == 'traj' :
                    if history_obs is None :
                        # Seed the window by repeating the first observation.
                        history_obs = np.tile(obs[:,None,:], (1,self.traj_length,1))
                    else :
                        history_obs = np.concatenate([history_obs, obs[:,None,:]], axis=1)
                    # Trim along the time axis (axis 1); axis 0 indexes envs.
                    if history_obs.shape[1] > self.traj_length :
                        history_obs = history_obs[:, -self.traj_length:, :]
                    actions, unique = eval_fn(history_obs)
                    actions = np.array(actions)
                else :
                    raise ValueError(f"eval_mode {self.eval_mode} is not defined")

                if unique is not None :
                    used_unique.append(unique)

                obs_list = []
                actions = actions[...,:4]
                for e_idx, env in enumerate(self.env_list) :
                    # pass if done
                    if done_list[e_idx] :
                        obs_list.append(dummy_obs) # dummy
                        continue
                    # Exactly one env step per policy action.
                    obs, _, _, _, env_info = env.step(actions[e_idx].squeeze())

                    obs_list.append(obs)
                    if env_info['success'] == 1 :
                        done_list[e_idx] = True

                if all(done_list) :
                    break

            # Accumulate by position so that an environment listed more than
            # once keeps its own score instead of adding to the first entry.
            for eid, env_name in enumerate(self.evaluation_list) :
                episode_score = float(int(done_list[eid]))
                if eval_seed == 0 :
                    rew_info['skill_seq'].append(env_name)
                    rew_info['skill_rew'].append(episode_score)
                else :
                    rew_info['skill_rew'][eid] += episode_score

        # eval episodes for loop end
        reward_sum = 0
        for i , env_name in enumerate(rew_info['skill_seq']) :
            rew_info['skill_rew'][i] /= eval_episodes
            print("[{}]skill is  {} rew : {:.2f}".format(i, env_name, rew_info['skill_rew'][i]))
            reward_sum += rew_info['skill_rew'][i]

        print("total reward : ", reward_sum/len(rew_info['skill_seq']))
        if len(used_unique) > 0 :
            print("unique : ", np.unique(np.concatenate(used_unique)))

        return rew_info

class MMEvaluator_wo_h(BaseEvaluator) :
    """Evaluate a policy on multi-stage ``SingleTask`` envs without skill conditioning.

    The policy receives only the stacked raw observations, wrapped in a dict;
    no skill embedding is appended to them.
    """

    def __init__(
            self,
            base_evaluation_sequences,
            eval_mode='obs', # obs or traj
            traj_length=10, # used for traj eval mode
            eval_episodes=3,
            phase_configures=None,
        ) -> None:
        """Load the skill embeddings and build one environment per task.

        Args:
            base_evaluation_sequences: sequence of task dicts with keys
                ``'skill_list'`` and ``'skill_seq'``.
            eval_mode: stored on the instance; not read by ``evaluate_base``
                of this class.
            traj_length: stored on the instance; not read by ``evaluate_base``
                of this class.
            eval_episodes: number of evaluation episodes per environment.
            phase_configures: accepted for signature compatibility; unused.
        """
        print("[MMevaluator]")
        skill_embedding_path='data/continual_dataset/evolving_world/mm_lang_embedding.pkl'
        # NOTE(review): pickle.load can execute arbitrary code; this path must
        # only ever point at a trusted file.
        with open( skill_embedding_path , 'rb' ) as f :
            self.skill_semantics = pickle.load(f)

        self.eval_horizons = 600

        self.base_evaluation_sequences = base_evaluation_sequences
        self.threshold = len(self.base_evaluation_sequences)
        self.eval_mode = eval_mode
        self.traj_length = traj_length
        self.eval_episodes = eval_episodes

        self.env_list = []
        for task in tqdm(self.base_evaluation_sequences) :
            env = SingleTask(seed=777, skill_list=task['skill_list']) # From m-metaworld
            # Force the order in which the env's skills are to be executed.
            env.env.skill_list = task['skill_seq']
            self.env_list.append(env)

    def evaluate_base(
            self,
            model,
            eval_fn = None,
        ) :
        """Roll out the policy in every environment and report mean scores.

        Args:
            model: object whose ``eval_model`` is used as the policy when
                ``eval_fn`` is None; it is not touched otherwise.
            eval_fn: optional policy callable.  It receives a dict with keys
                ``'obs'`` (stacked observations), ``'meta_data'`` and
                ``'instruction'`` (both None) and must return a mapping with
                an ``'actions'`` entry.

        Returns:
            ``{'skill_seq': [...], 'skill_rew': [...]}`` where ``skill_seq``
            holds each environment's skill list and ``skill_rew`` the
            per-environment mean of ``int(env.env.mode)`` over the episodes
            (presumably the number of completed stages — confirm against
            ``SawyerNonStationaryEnvV2``).
        """
        eval_episodes = self.eval_episodes
        rew_info = {'skill_seq':[], 'skill_rew' : []}
        eval_fn = model.eval_model if eval_fn is None else eval_fn

        for eval_seed in range(eval_episodes) :
            # reset the environment
            obs_list = [env.reset() for env in self.env_list]
            done_list = [False] * len(self.env_list)

            # Placeholder observation fed to the policy for finished envs.
            dummy_obs = np.zeros_like(obs_list[0])

            for _ in tqdm(range(self.eval_horizons)) :
                eval_inputs = {
                    'obs' : np.array(obs_list),
                    'meta_data' : None,
                    'instruction' : None,
                }

                eval_res = eval_fn(eval_inputs)
                # Only the first four action components are sent to the envs.
                actions = np.array(eval_res['actions'])[...,:4]

                obs_list = []
                for e_idx, env in enumerate(self.env_list) :
                    # pass if done
                    if done_list[e_idx] :
                        obs_list.append(dummy_obs) # dummy
                        continue
                    obs, _, done, env_info = env.step(actions[e_idx].squeeze())

                    obs_list.append(obs)
                    # An env is retired only when success and done coincide.
                    if env_info['success'] == 1 and done :
                        done_list[e_idx] = True
                if all(done_list) :
                    break

            # Accumulate by position so that two environments sharing the same
            # skill sequence keep separate scores.
            for e_idx, env in enumerate(self.env_list) :
                episode_score = int(env.env.mode)
                if eval_seed == 0 :
                    rew_info['skill_seq'].append(env.env.skill_list)
                    rew_info['skill_rew'].append(episode_score)
                else :
                    rew_info['skill_rew'][e_idx] += episode_score

        # eval episodes for loop end
        reward_sum = 0
        for i , skill_seq in enumerate(rew_info['skill_seq']) :
            rew_info['skill_rew'][i] /= eval_episodes
            print("[{}]skill is  {} rew : {:.2f}".format(i, skill_seq, rew_info['skill_rew'][i]))
            reward_sum += rew_info['skill_rew'][i]

        print("total reward : ", reward_sum/len(rew_info['skill_seq']))

        return rew_info

class MMEvaluator(BaseEvaluator) :
    """Evaluate a skill-conditioned policy on multi-stage ``SingleTask`` envs.

    At every step the embedding of each environment's current skill (looked up
    in ``self.skill_semantics``) is concatenated to its observation before the
    batch is handed to the policy.
    """

    def __init__(
            self,
            base_evaluation_sequences,
            eval_mode='obs', # obs or traj
            traj_length=10, # used for traj eval mode
            eval_episodes=3,
            phase_configures=None,
        ) -> None:
        """Load the skill embeddings, build the envs and optional DaCo queries.

        Args:
            base_evaluation_sequences: sequence of task dicts with keys
                ``'skill_list'`` and ``'skill_seq'``.
            eval_mode: ``'obs'`` feeds the current observation to the policy,
                ``'traj'`` feeds a window of the last ``traj_length``
                observations.
            traj_length: window length used in ``'traj'`` mode.
            eval_episodes: number of evaluation episodes per environment.
            phase_configures: optional iterable of phase dicts with a
                ``'data_paths'`` list of pickle files; when given, one context
                vector per file is stored in ``self.daco_query``.
        """
        print("[MMevaluator]")
        skill_embedding_path='data/continual_dataset/evolving_world/mm_lang_embedding.pkl'
        # NOTE(review): pickle.load can execute arbitrary code; this path (and
        # the dataset paths below) must only ever point at trusted files.
        with open( skill_embedding_path , 'rb' ) as f :
            self.skill_semantics = pickle.load(f)

        self.eval_horizons = 600

        self.base_evaluation_sequences = base_evaluation_sequences
        self.threshold = len(self.base_evaluation_sequences)
        self.eval_mode = eval_mode
        self.traj_length = traj_length
        self.eval_episodes = eval_episodes

        self.env_list = self._make_envs(tqdm(self.base_evaluation_sequences))

        self.daco_query = []
        if phase_configures is not None :
            self.phase_configures = phase_configures
            print("[DacoRL feature lazy loading]")
            for config in tqdm(self.phase_configures) :
                for path in config['data_paths'] :
                    with open(path, 'rb') as f :
                        data = pickle.load(f)
                    # Steps before the first terminal flag form the first trajectory.
                    ep_done = np.where(np.array(data['terminals']) == 1)[0][0]
                    first_traj = np.array(data['observations'])[:ep_done]

                    skills = np.array(data['skills'])[:ep_done]
                    skill_embs = np.array([self.skill_semantics[i] for i in skills])
                    first_traj = np.concatenate([first_traj, skill_embs], axis=-1)

                    # The context is the per-feature mean over that trajectory.
                    context = np.mean(first_traj, axis=0)
                    self.daco_query.append(context)
            self.daco_query = np.array(self.daco_query)[:,None,:] # ( env ,1, 60) per the original note — TODO confirm
            print("daco query : ", len(self.daco_query))
            print("daco query shape : ", self.daco_query.shape)

    @staticmethod
    def _make_envs(sequences) :
        """Build one ``SingleTask`` per task dict and force its skill order."""
        envs = []
        for task in sequences :
            env = SingleTask(seed=777, skill_list=task['skill_list']) # From m-metaworld
            env.env.skill_list = task['skill_seq']
            envs.append(env)
        return envs

    def _skill_conditioned_obs(self, env_list, obs_list, skill_idx_list) :
        """Append each env's current skill embedding to its observation."""
        skill_embs = [
            # The skill index is clamped to 3, i.e. the fourth skill of a sequence.
            self.skill_semantics[env.env.skill_list[min(skill_idx, 3)]]
            for env, skill_idx in zip(env_list, skill_idx_list)
        ]
        return np.concatenate([obs_list, skill_embs], axis=-1)

    def _extend_history(self, history, obs) :
        """Append ``obs`` to the observation window, keeping ``traj_length`` steps."""
        step = obs[:,None,:]
        if history is None :
            # Seed the window by repeating the first observation.
            history = np.tile(step, (1,self.traj_length,1))
        else :
            history = np.concatenate([history, step], axis=1)
        # Trim along the time axis (axis 1); axis 0 indexes environments.
        if history.shape[1] > self.traj_length :
            history = history[:, -self.traj_length:, :]
        return history

    def evaluate_base(
            self,
            model,
            eval_fn = None,
        ) :
        """Roll out the policy in every environment and report mean scores.

        Args:
            model: object whose ``eval_model`` is used as the policy when
                ``eval_fn`` is None.  When its exact type is
                ``DyLoRABookModelOracle`` the policy is additionally given
                ``daco_query=self.daco_query`` in ``'obs'`` mode.
            eval_fn: optional policy callable.  In ``'obs'`` mode it may
                return the actions or an ``(actions, unique)`` tuple; in
                ``'traj'`` mode it must return ``(actions, unique)``.

        Returns:
            ``{'skill_seq': [...], 'skill_rew': [...]}`` where ``skill_seq``
            holds each environment's skill list and ``skill_rew`` the
            per-environment mean of ``int(env.env.mode)`` over the episodes
            (presumably the number of completed stages — confirm against
            ``SawyerNonStationaryEnvV2``).

        Raises:
            ValueError: if ``self.eval_mode`` is neither ``'obs'`` nor ``'traj'``.
        """
        eval_episodes = self.eval_episodes
        rew_info = {'skill_seq':[], 'skill_rew' : []}
        daco_flag = type(model) is DyLoRABookModelOracle
        eval_fn = model.eval_model if eval_fn is None else eval_fn
        used_unique = []
        unique = None

        for eval_seed in range(eval_episodes) :
            # reset the environment
            history_obs = None
            obs_list = [env.reset() for env in self.env_list]
            done_list = [False] * len(self.env_list)
            skill_idx_list = [0] * len(self.env_list)

            # Placeholder observation fed to the policy for finished envs.
            dummy_obs = np.zeros_like(obs_list[0])

            for _ in tqdm(range(self.eval_horizons)) :
                obs = self._skill_conditioned_obs(self.env_list, obs_list, skill_idx_list)

                if self.eval_mode == 'obs' :
                    if daco_flag :
                        eval_res = eval_fn(obs[:,None,:], daco_query=self.daco_query)
                    else :
                        eval_res = eval_fn(obs[:,None,:])
                    # post processing
                    if isinstance(eval_res, tuple) :
                        actions , unique = eval_res
                    else :
                        actions = eval_res
                    actions = np.array(actions) # mmworld action space
                elif self.eval_mode == 'traj' :
                    history_obs = self._extend_history(history_obs, obs)
                    actions, unique = eval_fn(history_obs)
                    actions = np.array(actions)
                else :
                    raise ValueError(f"eval_mode {self.eval_mode} is not defined")

                if unique is not None :
                    used_unique.append(unique)

                obs_list = []
                # Only the first four action components are sent to the envs.
                actions = actions[...,:4]
                for e_idx, env in enumerate(self.env_list) :
                    # pass if done
                    if done_list[e_idx] :
                        obs_list.append(dummy_obs) # dummy
                        continue
                    obs, _, done, env_info = env.step(actions[e_idx].squeeze())

                    obs_list.append(obs)
                    if env_info['success'] == 1 :
                        # Success advances to the next skill of the sequence.
                        skill_idx_list[e_idx] += 1
                        if done :
                            done_list[e_idx] = True
                if all(done_list) :
                    break

            # Accumulate by position so that two environments sharing the same
            # skill sequence keep separate scores.
            for e_idx, env in enumerate(self.env_list) :
                episode_score = int(env.env.mode)
                if eval_seed == 0 :
                    rew_info['skill_seq'].append(env.env.skill_list)
                    rew_info['skill_rew'].append(episode_score)
                else :
                    rew_info['skill_rew'][e_idx] += episode_score

        # eval episodes for loop end
        reward_sum = 0
        for i , skill_seq in enumerate(rew_info['skill_seq']) :
            rew_info['skill_rew'][i] /= eval_episodes
            print("[{}]skill is  {} rew : {:.2f}".format(i, skill_seq, rew_info['skill_rew'][i]))
            reward_sum += rew_info['skill_rew'][i]

        print("total reward : ", reward_sum/len(rew_info['skill_seq']))
        if len(used_unique) > 0 :
            print("unique : ", np.unique(np.concatenate(used_unique)))

        return rew_info

    def evaluate_base_vid(
            self,
            model,
            eval_fn = None,
        ) :
        """Run a single evaluation episode while recording a video.

        Every env step appends one frame to the video: the env render on the
        left and a polar plot of that env's action on the right.

        Args:
            model: object whose ``eval_model`` is used as the policy when
                ``eval_fn`` is None.
            eval_fn: optional policy callable returning the actions.

        Returns:
            The mean of ``int(env.env.mode)`` over all environments.

        Raises:
            ValueError: if ``self.eval_mode`` is neither ``'obs'`` nor ``'traj'``.
        """
        eval_episodes = 1
        rew_info = {'skill_seq':[], 'skill_rew' : []}
        eval_fn = model.eval_model if eval_fn is None else eval_fn

        # TEMPORAL
        mp4v = cv2.VideoWriter_fourcc(*'mp4v')
        k = 720
        video_path = f"../videos/1120_mmworld_trace_example.mp4"
        vid_size = (2*k,k) # (W,H)
        video = cv2.VideoWriter(video_path, mp4v, 30, vid_size)

        def render_action(action, temp_path='data/tmp_act.png' ) :
            """Draw ``action`` as a closed polar trace and return it as an RGB array."""
            plt.clf()
            plt.rcParams.update({'font.size': 6})
            data = action

            # NOTE(review): this jittered copy is never plotted; the draw is
            # kept so the global NumPy RNG stream is consumed as before.
            a2 = action.copy()
            a2 += np.random.uniform(-0.1,0.1,4)

            angles = np.linspace(0, 2 * np.pi, len(data), endpoint=False).tolist()
            # Repeat the first point so the trace closes on itself.
            angles += angles[:1]
            data = np.concatenate([data, data[:1]])

            # 7.2 in at dpi=100 gives a 720-pixel square, matching ``k``.
            fig = plt.figure(figsize=(7.20, 7.20))
            ax = fig.add_subplot(111, polar=True)

            ax.plot(angles, data, 'o-')
            ax.fill(angles, data, alpha=0.25)

            ax.set_xticks(angles[:-1])
            ax.set_xticklabels([''] * (len(data) - 1))
            ax.set_ylim([-1.5, 1.6])
            ax.set_yticks(np.arange(-1, 1.51, 0.5))
            plt.tight_layout()

            plt.savefig(temp_path, dpi=100)
            plt.close()

            image = Image.open(temp_path).convert('RGB')
            action_trace_image = np.array(image)

            return action_trace_image

        try :
            for eval_seed in range(eval_episodes) :
                # reset the environment
                history_obs = None
                obs_list = [env.reset() for env in self.env_list]
                done_list = [False] * len(self.env_list)
                skill_idx_list = [0] * len(self.env_list)

                # Placeholder observation fed to the policy for finished envs.
                dummy_obs = np.zeros_like(obs_list[0])

                for _ in tqdm(range(self.eval_horizons)) :
                    obs = self._skill_conditioned_obs(self.env_list, obs_list, skill_idx_list)

                    if self.eval_mode == 'obs' :
                        actions = np.array(eval_fn(obs[:,None,:]))
                    elif self.eval_mode == 'traj' :
                        history_obs = self._extend_history(history_obs, obs)
                        actions = np.array(eval_fn(history_obs))
                    else :
                        raise ValueError(f"eval_mode {self.eval_mode} is not defined")

                    obs_list = []
                    for e_idx, env in enumerate(self.env_list) :
                        # pass if done
                        if done_list[e_idx] :
                            obs_list.append(dummy_obs) # dummy
                            continue
                        obs, _, done, env_info = env.step(actions[e_idx].squeeze())

                        img = env.render(resolution=(k,k))
                        img = cv2.putText(img, f"skill : {env.env.skill_list[min(skill_idx_list[e_idx],3)]}", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                        # Plot the action of the env being rendered.
                        actimg = render_action(actions[e_idx].squeeze()) #(H,W,C)
                        img = np.concatenate([img, actimg], axis=1)
                        BGR = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                        video.write(BGR)

                        obs_list.append(obs)
                        if env_info['success'] == 1 :
                            skill_idx_list[e_idx] += 1
                            if done :
                                done_list[e_idx] = True
                    if all(done_list) :
                        break

                for e_idx, env in enumerate(self.env_list) :
                    episode_score = int(env.env.mode)
                    if eval_seed == 0 :
                        rew_info['skill_seq'].append(env.env.skill_list)
                        rew_info['skill_rew'].append(episode_score)
                    else :
                        rew_info['skill_rew'][e_idx] += episode_score
        finally :
            # Release the writer so the video file is finalised even when the
            # rollout raises.
            video.release()

        # eval episodes for loop end
        reward_sum = 0
        for i , skill_seq in enumerate(rew_info['skill_seq']) :
            rew_info['skill_rew'][i] /= eval_episodes
            print("skill is  {} rew : {:.2f}".format(skill_seq, rew_info['skill_rew'][i]))
            reward_sum += rew_info['skill_rew'][i]
        print("total reward : ", reward_sum/len(rew_info['skill_seq']))

        eval_reward = reward_sum/len(rew_info['skill_seq'])
        return eval_reward

    def multi_metaworld_evaluate(
            self,
            model,
            evaluation_sequences,
            eval_episodes=3,
        ) :
        """Evaluate ``model.eval_model`` on freshly built envs for the given tasks.

        Unlike ``evaluate_base``, a new set of environments is created from
        ``evaluation_sequences`` and every (episode, env) pair contributes its
        own entry to the mean.

        Args:
            model: object providing ``eval_model``; it receives the
                skill-conditioned observations with an extra axis at position 1.
            evaluation_sequences: task dicts with keys ``'skill_list'`` and
                ``'skill_seq'``.
            eval_episodes: number of evaluation episodes per environment.

        Returns:
            The mean of ``int(env.env.mode)`` over all episodes and envs.
        """
        rew_info = {'skill_seq': [], 'skill_rew' : []}
        env_list = self._make_envs(tqdm(evaluation_sequences))

        # With no tasks there is nothing to roll out; the final division then
        # raises ZeroDivisionError as before.
        episodes = range(eval_episodes) if env_list else range(0)
        for eval_seed in episodes :
            # reset the environment
            obs_list = [env.reset() for env in env_list]
            done_list = [False] * len(env_list)
            skill_idx_list = [0] * len(env_list)

            # Placeholder observation fed to the policy for finished envs.
            dummy_obs = np.zeros_like(obs_list[0])

            for _ in tqdm(range(self.eval_horizons)) :
                obs = self._skill_conditioned_obs(env_list, obs_list, skill_idx_list)
                actions = np.array(model.eval_model(obs[:,None,:]))

                obs_list = []
                for e_idx, env in enumerate(env_list) :
                    # pass if done
                    if done_list[e_idx] :
                        obs_list.append(dummy_obs) # dummy
                        continue
                    obs, _, done, env_info = env.step(actions[e_idx].squeeze())

                    obs_list.append(obs)
                    if env_info['success'] == 1 :
                        skill_idx_list[e_idx] += 1
                        if done :
                            done_list[e_idx] = True
                if all(done_list) :
                    break

            for env in env_list :
                skill_seq = env.env.skill_list
                episode_score = int(env.env.mode)
                print("skill is  ", skill_seq, " rew : " , episode_score)
                rew_info['skill_seq'].append(skill_seq)
                rew_info['skill_rew'].append(episode_score)

        reward_sum = sum(rew_info['skill_rew'])
        print("total reward : ", reward_sum/len(rew_info['skill_seq']))

        eval_reward = reward_sum/len(rew_info['skill_seq'])
        return eval_reward


if __name__ == '__main__' :
    # Smoke test: run the Continual World evaluator with a random policy.
    # cwdataloader = ContiualMetaworldDataloader()
    def _random_policy(_obs) :
        """Return uniform random actions of shape (10, 4), ignoring the input."""
        return np.random.uniform(-1,1,(10,4))

    evaluator = CWEvaluator()
    evaluator.evaluate_base(None, eval_fn=_random_policy)

    