"""Behavior Cloning Environment."""

from pathlib import Path
from typing import Any, Optional
from PIL import Image

from envs.meta_world import MetaWorldIndexedMultiTaskTester, MT30_TASK, MT10_TASK, MT50_TASK, SELECTED_TASK, JUST_ONE_TASK
from envs.env_base import DynamicResourceEnv, DynamicSequentialRouting, DDNNRouting, DynamicDistillRouting, PolicySelectEnv

import carla
import flax
import gym.spaces
import gym
import numpy as np
import tqdm
import copy
import os
import random
import cv2

from stable_baselines3.sac.policies import MultiInputPolicy as torch_MultiInputPolicy
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.buffers import DictReplayBuffer as torch_DictReplayBuffer

from algos.models.drnet import DRNet_Network
from algos.models.slimmable import SlimmableNetwork
from algos.network import SensorMultiInputCNN as torch_SensorMultiInputCNN
from algos.bc import BC as torch_BC
from algos.multipolicyreinforce import MultiPolicyREINFORCE
from algos.policies import REINFORCEPolicy
from algos.buffers import BCReplayBuffer as torch_BCReplayBuffer

from carla_env.utils.roaming_agent import RoamingAgent
from carla_env.base import BaseCarlaEnvironment
from carla_env.simulator.actor import Actor
from carla_env.dataset import load_datasets
from carla_env.utils.config import ExperimentConfigs
from carla_env.utils.lidar import generate_lidar_bin
from carla_env.utils.vector import to_array
from carla_env.utils.visualize import process_image

from utils.callbacks import RoutingCallback

from offline_baselines_jax.common.callbacks import CheckpointCallback as jax_CheckpointCallback
from offline_baselines_jax.common.jax_layers import ConvSoftModule, MultiInputCNN, SensorConvSoftModule, SensorMultiInputCNN, SensorSubPolicyMultiInputCNN
from offline_baselines_jax.common.buffers import DictReplayBuffer, TaskDictReplayBuffer, BCReplayBuffer, TeacherReplaybuffer
from offline_baselines_jax.bc.bc import BC
from offline_baselines_jax.bc.policies import MultiInputPolicy as BC_MultiInputPolicy
from offline_baselines_jax.reinforce.reinforce import REINFORCE
from offline_baselines_jax.reinforce.policies import CnnPolicy
from offline_baselines_jax.sac import SAC
from offline_baselines_jax.sac.policies import MultiInputPolicy as SAC_MultiInputPolicy
from offline_baselines_jax.soft_modularization_reinforce.soft_modularization_reinforce import SoftModularization
from offline_baselines_jax.soft_modularization_reinforce.policies import MultiInputPolicy
# from offline_baselines_jax.soft_modularization.soft_modularization import SoftModularization
# from offline_baselines_jax.soft_modularization.policies import MultiInputPolicy
from offline_baselines_jax.distillation.distillation import Distillation


# Type alias: a flax FrozenDict mapping str keys to arbitrary values.
# NOTE(review): presumably meant for annotating model parameter trees; it is not
# referenced in the visible part of this file — confirm it is still used.
Params = flax.core.FrozenDict[str, Any]


class BehaviorCloningCarlaEnvironment(BaseCarlaEnvironment):
    """Behavior Cloning Environment.

    CARLA driving environment whose observation dict contains:

    * ``"obs"``: a 24-dim sensor vector built in :meth:`_simulator_step`,
    * ``"task"``: a ``route_len``-dim task vector with entries in [0, 1],
    * ``"image"``: a 160x160x3 uint8 camera image.

    Depending on ``mode`` an extra resource observation is added
    (``"module_select"`` for ours/iter/distill/ddnn, ``"column"`` for slimmable,
    ``"policy_restrict"`` for rlaa). With ``dynamic_velocity`` the observation
    also carries ``"velocity"`` and ``"target_velocity"``.
    """

    # Distance (as returned by get_distance_vehicle_target()) below which the
    # target counts as reached.
    _GOAL_RADIUS = 15

    def __init__(
        self,
        config: ExperimentConfigs,
        dynamic_velocity: bool = False,
        mode: Optional[str] = None,
        routing_model=None,
        eval_mode: bool = False,
        port: Optional[int] = None,
    ):
        """Create the simulator and the action/observation spaces.

        Args:
            config: Experiment configuration forwarded to the base environment.
            dynamic_velocity: If True, use a route sampling resolution of 0.25
                (instead of 0.1), add ``velocity``/``target_velocity``
                observations and scale the reward when exceeding the target speed.
            mode: Experiment variant selecting the extra resource observation
                ("ours", "iter", "distill", "ddnn", "slimmable", "rlaa"); any
                other value (including ``None``) adds none.
            routing_model: Optional model whose ``predict`` drives
                :meth:`infer_module_select`.
            eval_mode: If True, record per-step velocity logs and report the
                goal-radius-adjusted ``percentage`` at episode end.
            port: Simulator port forwarded to the base environment.
        """
        sampling_resolution = 0.25 if dynamic_velocity else 0.1
        super().__init__(config, port, sampling_resolution=sampling_resolution)
        # Two continuous actions in [-1, 1]: [throttle (>=0) / brake (<0), steer].
        self.action_space = gym.spaces.Box(shape=(2,), low=-1, high=1)

        route_len = self.sim.route_manager.route_selector.route_len
        observation_space = {
            "obs": gym.spaces.Box(shape=(24,), low=-1, high=1),
            "task": gym.spaces.Box(low=np.zeros(route_len), high=np.ones(route_len)),
            "image": gym.spaces.Box(shape=(160, 160, 3), low=0, high=255, dtype=np.uint8),
        }

        # Name of the extra resource observation for this mode (None if absent).
        self.new_key = None
        if mode in ["ours", "iter", "distill", "ddnn"]:
            observation_space["module_select"] = gym.spaces.Box(low=0, high=1, shape=(16,))
            self.res_num_obs = 16
            self.res_max = 16
            self.num_resource = np.ones(self.res_num_obs)
            self.ddnn_bias = 0
            self.new_key = "module_select"
        elif mode == "slimmable":
            observation_space["column"] = gym.spaces.Box(low=0, high=1, shape=(5,))
            self.res_num_obs = 5
            self.num_resource = np.ones(self.res_num_obs)
            self.new_key = "column"
        elif mode == "rlaa":
            observation_space["policy_restrict"] = gym.spaces.Box(low=0, high=1, shape=(4,))
            self.res_num_obs = 4
            self.num_resource = np.ones(self.res_num_obs)
            self.new_key = "policy_restrict"

        if dynamic_velocity:
            # NOTE: throt_max is numerically *smaller* than throt_min. reset()
            # computes throt_max - (throt_max - throt_min) * res / res_num_obs,
            # so with this ordering a larger resource index gives a faster target.
            self.throt_max = 10
            self.throt_min = 20
            # Box bounds must satisfy low <= high, otherwise no target speed is
            # ever contained in the space.
            observation_space["target_velocity"] = gym.spaces.Box(
                low=min(self.throt_min, self.throt_max),
                high=max(self.throt_min, self.throt_max),
                shape=(1,),
            )
            observation_space["velocity"] = gym.spaces.Box(low=0, high=100, shape=(1,))

        self.agent = None
        self.mode = mode
        self.num_task = route_len

        self.observation_space = gym.spaces.Dict(observation_space)
        self.routing_model = routing_model
        self.count = 0
        self.res = 0
        self.fixed = False
        self.dynamic_velocity = dynamic_velocity
        self.velocity_reset_count = 0

        self.target_velocity = 20

        self.eval_mode = eval_mode
        self.image_save_mode = False
        self.eval_log = []
        if self.eval_mode:
            self.velocity_success_log = []
            self.velocity_log = []

        # Smallest / initial distance to the target within the episode; both are
        # re-initialised in reset() and feed info["percentage"].
        self.min_reward = -1
        self.max_reward = -1
        self.last_expert_action = None

    def goal_reaching_reward(self):
        """Extend the base goal-reaching reward with collision and lane terms.

        Returns:
            ``(total_reward, reward_dict, done_dict)``. The base reward is
            reduced by 10 on a collision or when ``lane_types`` is non-empty.
            ``reward_dict`` gains ``"dist"`` (distance to the target).
            ``done_dict`` gains ``"lane_collision_done"`` (collision or a
            Solid/SolidSolid marking crossed) and ``"dist_done"`` (distance
            below the goal radius).
        """
        ego = self.sim.ego_vehicle
        has_collided = ego.collision_sensor.has_collided
        lane_invasion = ego.lane_invasion_sensor.lane_types

        # Only solid markings end the episode; the penalty below applies to any
        # recorded lane type.
        lane_done = (
            has_collided
            or carla.LaneMarkingType.Solid in lane_invasion
            or carla.LaneMarkingType.SolidSolid in lane_invasion
        )

        dist = self.get_distance_vehicle_target()

        total_reward, reward_dict, done_dict = super().goal_reaching_reward()
        if has_collided or lane_invasion:
            total_reward -= 10.0

        reward_dict = {"dist": dist, **reward_dict}
        done_dict = {
            "lane_collision_done": lane_done,
            "dist_done": dist < self._GOAL_RADIUS,
            **done_dict,
        }
        return total_reward, reward_dict, done_dict

    def _simulator_step(
        self,
        action: Optional[np.ndarray] = None,
        traffic_light_color: Optional[str] = None,
    ):
        """Apply ``action``, tick the simulator once and build the transition.

        Args:
            action: ``[throttle_or_brake, steer]``. A non-negative first entry is
                the throttle; a negative one is the brake magnitude. When ``None``
                no control is applied and zero controls are reported.
            traffic_light_color: Unused here.

        Returns:
            ``(state, reward, done, info)``. ``done`` is True when the step
            horizon is reached or any entry of the done dict is truthy.
        """
        # Expert control for the current state, reported as info["expert_action"].
        expert_action = self.compute_action()[0]

        if action is None:
            throttle, steer, brake = 0.0, 0.0, 0.0
        else:
            steer = float(action[1])
            if action[0] >= 0.0:
                throttle, brake = float(action[0]), 0.0
            else:
                # carla.VehicleControl expects brake in [0, 1]; the negative
                # action component encodes the brake magnitude.
                throttle, brake = 0.0, -float(action[0])

            self.sim.ego_vehicle.apply_control(
                carla.VehicleControl(
                    throttle=throttle,  # [0,1]
                    steer=steer,  # [-1,1]
                    brake=brake,  # [0,1]
                    hand_brake=False,
                    reverse=False,
                    manual_gear_shift=False,
                )
            )

        self.last_expert_action = expert_action
        self.sync_mode.tick(timeout=100.0)

        reward, reward_dict, done_dict = self.goal_reaching_reward()
        reward *= 1e-4

        velocity_loss = 0
        speed = None
        if self.dynamic_velocity:
            speed = self.agent.get_speed()
            if self.eval_mode:
                self.velocity_log.append(speed)
                self.velocity_success_log.append(1 if self.target_velocity > speed else 0)
            if self.target_velocity < speed:
                # Shrink the reward when the overshoot exceeds 1 (factor capped at 1).
                reward *= min(1, 1 / abs(speed - self.target_velocity))

        # Track the closest approach; once inside the goal radius it is clamped to 0.
        if reward_dict['dist'] < self.min_reward:
            self.min_reward = reward_dict['dist']
            if self.min_reward < self._GOAL_RADIUS:
                self.min_reward = 0

        self.count += 1

        ego = self.sim.ego_vehicle
        rotation = ego.rotation
        # Concatenated (in insertion order) into state["obs"]; expected to match
        # the 24-dim "obs" space.
        next_obs = {
            "control": np.array([throttle, steer, brake]),
            "acceleration": to_array(ego.acceleration),
            "angular_veolcity": to_array(ego.angular_velocity),
            "location": to_array(ego.location),
            "rotation": to_array(rotation),
            "forward_vector": to_array(rotation.get_forward_vector()),
            "veolcity": to_array(ego.velocity),
            "target_location": to_array(self.sim.target_location),
        }

        done = self.count >= self.config.max_steps
        if done:
            print(
                f"Episode success: I've reached the episode horizon "
                f"({self.config.max_steps})."
            )

        info = {
            **{f"reward_{key}": value for key, value in reward_dict.items()},
            **{f"done_{key}": value for key, value in done_dict.items()},
            "control_repeat": self.config.frame_skip,
            "weather": self.config.weather,
            "settings_map": self.sim.world.map.name,
            "settings_multiagent": self.config.multiagent,
            "traffic_lights_color": "UNLABELED",
            "reward": reward,
            "expert_action": np.array(
                [
                    expert_action.throttle - expert_action.brake,
                    expert_action.steer,
                ],
                dtype=np.float64,
            ),
            "success": done_dict['dist_done'],
            "is_success": done_dict['dist_done'],
            "percentage": np.round(1 - (self.min_reward / self.max_reward), 3),
        }
        if self.mode in ("iter", "distill"):
            info['resource'] = np.sum(self.num_resource)
            info['num_resource'] = self.num_resource

        state = {
            "obs": np.hstack(list(next_obs.values())),
            "image": ego.camera.image,
            "task": self.sim.route_manager.route_selector.get_route_task(),
        }

        if self.image_save_mode:
            img = ego.large_camera.image
            cv2.imwrite(f'./carla_image/task_{self.sim.route_manager.route_selector.get_route_idx()}_{self.count}.png', img)

        if self.dynamic_velocity:
            state["velocity"] = speed
            state["target_velocity"] = self.target_velocity
            info['velocity_reward'] = velocity_loss
            self.velocity_reset_count = (self.velocity_reset_count + 1) % 100

        if self.mode in ["ours", "iter", "distill", "ddnn"]:
            state["module_select"] = self.infer_module_select(state)
        elif self.mode == "slimmable":
            state["column"] = self.num_resource
        elif self.mode == "rlaa":
            state["policy_restrict"] = self.num_resource

        terminated = done or any(done_dict.values())
        # Evaluation only: report progress relative to the goal radius at the
        # end of the episode.
        if terminated and self.eval_mode:
            info['percentage'] = np.round(
                1 - ((self.min_reward - self._GOAL_RADIUS) / self.max_reward), 3
            )

        return state, reward, terminated, info

    def reset(self, get_info: bool = False):
        """Reset the simulator and expert agent and return the first observation.

        Args:
            get_info: If True, also return a dict with the active resource count
                (``"resource"``) and resource vector (``"num_resource"``).

        Returns:
            ``obs``, or ``(obs, info)`` when ``get_info`` is True.
        """
        self.reset_simulator()
        self.agent = RoamingAgent(
            self.sim.ego_vehicle.carla,
            follow_traffic_lights=self.config.lights,
        )
        self.agent._local_planner.set_global_plan(self.sim.route_manager.waypoints)

        # The initial distance is the reference for the "percentage" metric.
        _, reward_dict, _ = self.goal_reaching_reward()
        self.min_reward = self.max_reward = reward_dict['dist']

        self.count = 0

        obs, _, _, _ = self.step()

        # Index of the last enabled resource slot. It stays None when the mode
        # has no resource observation or is "ours" (all slots enabled).
        res = None
        if self.new_key is not None:
            if self.mode != 'ours':
                if self.fixed:
                    res = self.fixed_resource
                else:
                    # Default budget: the first five slots, capped to the slot
                    # count (rlaa has only four, so a bare 4 overflowed).
                    res = min(4, self.res_num_obs - 1)
                self.num_resource = np.zeros(self.res_num_obs)
                for i in range(res + 1):
                    self.num_resource[i] = 1
            else:
                self.num_resource = np.ones(self.res_num_obs)

            if self.new_key == 'module_select':
                if self.mode == 'ddnn':
                    del obs['module_select']
                new_value = self.infer_module_select(obs)
            else:
                new_value = self.num_resource

            obs[self.new_key] = new_value

            if self.eval_mode:
                self.eval_log.append(np.sum(new_value))

        if self.dynamic_velocity:
            planner = self.agent._local_planner
            planner.dynamic_velocity = True
            planner._vehicle_controller.max_throt = 2.0
            planner._vehicle_controller.max_brake = 1.0
            self.velocity_reset_count = 0
            if res is None:
                # No resource budget applies: sample a target speed between the bounds.
                self.target_velocity = np.random.uniform(self.throt_min, self.throt_max)
            else:
                self.target_velocity = self.throt_max - (self.throt_max - self.throt_min) * res / self.res_num_obs
            planner._target_speed = self.target_velocity
            obs["velocity"] = self.agent.get_speed()
            obs["target_velocity"] = self.target_velocity

        if get_info:
            info = {
                'resource': np.sum(self.num_resource),
                'num_resource': self.num_resource,
            }
            return obs, info
        return obs

    def infer_module_select(self, obs):
        """Compute the binary module-selection vector for ``obs``.

        The number of modules to enable is ``sum(self.num_resource)``. With a
        routing model, the selection depends on ``self.mode``:

        * ``"iter"``: query the router once per module, feeding the current
          selection back in as ``module_select`` and ``mask``.
        * ``"distill"``: one router query; enable the highest-scoring modules.
        * ``"ddnn"``: one router query reshaped to ``(4, 4, 2)``; ``ddnn_bias``
          is added to channel 0 before taking the argmax.
        * any other mode: all 16 modules.

        Without a routing model every slot is enabled.
        """
        use_resource = int(np.sum(self.num_resource))
        if self.routing_model is None:
            return np.ones(self.res_num_obs)

        if self.mode == 'iter':
            module_select = np.zeros(self.res_num_obs)
            routing_obs = copy.deepcopy(obs)
            routing_obs['module_select'] = module_select
            routing_obs['mask'] = np.copy(module_select)
            routing_obs['num_resource'] = self.num_resource
            for _ in range(use_resource):
                action, _ = self.routing_model.predict(routing_obs)
                module_select[int(action)] = 1
                routing_obs = copy.deepcopy(routing_obs)
                routing_obs['module_select'] = module_select
                routing_obs['mask'] = np.copy(module_select)
            return module_select

        if self.mode == "distill":
            module_select = np.zeros(self.res_num_obs)
            routing_obs = copy.deepcopy(obs)
            routing_obs.pop('module_select', None)
            routing_obs['num_resource'] = self.num_resource
            action, _ = self.routing_model.predict(routing_obs)
            if np.ndim(action) != 1:
                action = action[0]
            # Guard needed: a [-0:] slice would select every module.
            if use_resource > 0:
                module_select[np.argsort(action)[-use_resource:]] = 1
            return module_select

        if self.mode == 'ddnn':
            action, _ = self.routing_model.predict(copy.deepcopy(obs))
            module_select = action.reshape(4, 4, 2)
            module_select[..., 0] += self.ddnn_bias
            return module_select.argmax(-1).reshape(self.res_max)

        return np.ones(16)

    def select_route(self, idx: int = 0):
        """Select route ``idx`` on the route manager."""
        self.sim.route_manager.select_route_by_idx(idx)

    def set_target_velocity(self, velocity):
        """Pin both target-speed bounds to ``velocity`` (dynamic-velocity mode only)."""
        if self.dynamic_velocity:
            self.throt_max = velocity
            self.throt_min = velocity

    def get_velocity(self):
        """Return the expert agent's current speed."""
        return self.agent.get_speed()

    def get_last_inf_time(self):
        """Return 0; inference time is not tracked by this environment."""
        return 0

    def set_resource(self, num_resource):
        """Enable the first ``num_resource + 1`` slots and keep them for later resets."""
        self.fixed_resource = num_resource
        self.num_resource = np.zeros(self.res_num_obs)
        res = self.fixed_resource
        for i in range(res + 1):
            self.num_resource[i] = 1
        self.fixed = True

    def eval_reset(self):
        """Return the accumulated eval log and clear the eval and velocity logs."""
        ret = self.eval_log
        self.eval_log = []
        self.velocity_success_log = []
        self.velocity_log = []
        return ret

    def texture_setter(self):
        """Delegate to ``self.sim.texture_setter()``."""
        self.sim.texture_setter()





def behavior_cloning(config: ExperimentConfigs, mode: str = None, port: int = None, dynamic_velocity: bool = False):
    """Behavior cloning experiment.
    
    Args:
        config (ExperimentConfigs): Experiment configs.
    """

    data_path = config.data_path

    env = BehaviorCloningCarlaEnvironment(config, mode=mode, port=port, dynamic_velocity=dynamic_velocity)
    # env.texture_setter()

    if mode == "rlaa":
        sub_policy_size = 32

        layer_idx = 1
        sub_policy_kwags = {
            'net_arch': [sub_policy_size],
            'features_extractor_class': SensorSubPolicyMultiInputCNN,
            'features_extractor_kwargs': {
                'layer_length': layer_idx,
                'output_dim': sub_policy_size
            },
        }

        model = BC(
            policy=BC_MultiInputPolicy,  # type: ignore
            env=env,
            verbose=1,
            gradient_steps=1,
            buffer_size=30000,
            train_freq=1,
            batch_size=128,
            learning_rate=1e-3,
            tensorboard_log="log/BC/RLAA_{}".format(layer_idx),
            policy_kwargs=sub_policy_kwags,
            without_exploration=False,
        )

        path = './models/CARLA/CARLA/'
        check_callback = jax_CheckpointCallback(save_freq=30_000, name_prefix="ModMA", save_path=path, )

    elif mode == "slimmable":
        size = 32


        policy_kwargs = {'net_arch': [size],
                         'features_extractor_class': SlimmableNetwork,
                         'features_extractor_kwargs': {'module_select_cond': True,
                                                       'task_num': env.sim.route_manager.route_selector.route_len,
                                                       'width_mult_list': [0.125, 0.25, 0.5, 0.75, 1.0],
                                                       'module_arch': [
                                                           {'features': [size], 'kernel_sizes': [3], 'strides': [1]},
                                                           {'features': [size], 'kernel_sizes': [3], 'strides': [1]},
                                                           {'features': [size * 2], 'kernel_sizes': [3], 'strides': [1]},
                                                           {'features': [size * 2], 'kernel_sizes': [3], 'strides': [1]},
                                                           {'features': [size * 3], 'kernel_sizes': [3], 'strides': [1]},
                                                           {'features': [size * 3], 'kernel_sizes': [3], 'strides': [1]}],
                                                       },
                         'use_sde': False,
                         }

        # model = torch_BC(
        #     policy=torch_MultiInputPolicy,  # type: ignore
        #     env=env,
        #     verbose=1,
        #     gradient_steps=1,
        #     buffer_size=10000,
        #     replay_buffer_class=torch_BCReplayBuffer,
        #     learning_starts=100,
        #     train_freq=1,
        #     batch_size=120,
        #     learning_rate=3e-4,
        #     tensorboard_log="log/BC",
        #     policy_kwargs=policy_kwargs,
        # )
        # only_actor_model = torch_BC.load('./models/CARLA/DSNet/32_slimmable.zip', env=env,)
        # model.policy.actor = only_actor_model.policy.actor
        # model._create_aliases()

        model = torch_BC.load('./models/CARLA/DSNet/32_slimmable.zip', env=env,)

        # model.set_restrict_action()
        # model.actor_original.features_extractor.set_full()


        # model = torch_BC(
        #     policy=torch_MultiInputPolicy,  # type: ignore
        #     env=env,
        #     verbose=1,
        #     gradient_steps=1,
        #     buffer_size=10000,
        #     replay_buffer_class=torch_BCReplayBuffer,
        #     learning_starts=100,
        #     train_freq=1,
        #     batch_size=120,
        #     learning_rate=3e-4,
        #     tensorboard_log="log/BC",
        #     policy_kwargs=policy_kwargs,
        # )
        # env.set_resource(4)



        path = './models/CARLA/DSNet/'

        check_callback = CheckpointCallback(save_freq=30_000, name_prefix="DSNet", save_path=path, )


    elif mode == "ddnn":
        size = 64

        coef = 0.3
        bias = 0.7
        env.ddnn_bias = bias

        model = BC.load('./models/CARLA/ModMA/32_sm_model.zip', tensorboard_log="./log/BC/DDNN/{}_{}/".format(coef, bias), env=env, gradient_steps=5, batch_size=120, train_freq=120,)
        model._load_policy()
        model.off_restrict_action()

        # model = BC.load(f'./models/CARLA/DDNN/{coef}_{bias}/32_sm_model.zip', env=env)


        routing_envs = DDNNRouting(env, 4, 4, bias=bias, coef=coef, dynamic_model=model.actor, simple_sample=True)

        iterative_policy_kwargs = {
            'features_extractor_class': SensorMultiInputCNN,
            'features_extractor_kwargs': {'feature_dim': size * 2 + routing_envs.other_dims, 'net_arch': [size, size * 2]},
        }

        iterative_model = SAC(SAC_MultiInputPolicy, env=routing_envs,
                              tensorboard_log='./log/DDNN_conv/routing/{}_{}/'.format(coef, bias), gamma=1,
                              policy_kwargs=iterative_policy_kwargs,
                              learning_rate=1e-4, replay_buffer_class=DictReplayBuffer, buffer_size=126,
                              learning_starts=126, train_freq=120, batch_size=120)
        # iterative_model = SAC.load(f'./models/CARLA/DDNN/{coef}_{bias}/32_routing_model.zip', env=routing_envs)

        env.routing_model = iterative_model
        routing_envs = DDNNRouting(env, 4, 4, bias=bias, coef=coef, dynamic_model=model.actor, simple_sample=True)
        iterative_model.env = SAC._wrap_env(routing_envs)
        model.env = BC._wrap_env(env)

        path = "./models/CARLA/DDNN/{}_{}".format(coef, bias)

        callback = RoutingCallback(iterative_model, save_freq=30_000, save_path=path)

        tot_timestep = 3_000_000
        iter_step = 3000
        iters = tot_timestep // iter_step

        iterative_model.learn(total_timesteps=iter_step, log_interval=5000, reset_num_timesteps=True)
        model.learn(total_timesteps=iter_step, log_interval=1, reset_num_timesteps=True, callback=callback, )

        for idx in range(iters):
            iterative_model.learn(total_timesteps=iter_step, log_interval=5000, reset_num_timesteps=False)
            model.learn(total_timesteps=iter_step, log_interval=1, reset_num_timesteps=False, callback=callback, )

    elif mode == "drnet":
        size = 32

        coef = 10
        bias = 10

        cell_num = 2
        input_num = 1
        internode_num = 4
        branch_num = 2

        if os.path.exists("./connection_{}_{}_{}_{}.npy".format(cell_num, input_num, internode_num, branch_num)):
            print("Connection File Exists")
            with open("./connection_{}_{}_{}_{}.npy".format(cell_num, input_num, internode_num, branch_num), 'rb') as f:
                connection_list = np.load(f)
        else:
            print("Connection File not Exists")
            connection_list = []  # (cell, internode, input num)
            for i in range(cell_num):
                cell_connection_list = []
                for j in range(1, 1 + internode_num):
                    idx_list = np.arange(j)
                    print(i, j, idx_list)
                    first_idx = 0
                    if j != 1:
                        idx_list = np.delete(idx_list, first_idx)
                    second_idx = random.choice(idx_list)
                    cell_connection_list.append([first_idx, second_idx])

        policy_kwargs = {'net_arch': [size],
                         'features_extractor_kwargs': {'connection_list': connection_list,
                                                       'task_num': env.num_task,
                                                       'input_num': input_num,
                                                       'device': 'cuda',
                                                       'bias': bias,
                                                       'cell_num': cell_num,
                                                       'internode_num': internode_num,
                                                       'training': True,
                                                       'module_arch': [
                                                           {'features': [size], 'kernel_sizes': [3], 'strides': [1]},
                                                           {'features': [size * 2], 'kernel_sizes': [3], 'strides': [1]},
                                                       ],
                                                       },
                         'features_extractor_class': DRNet_Network, }

        path = './models/CARLA/DRNet/{}_{}'.format(coef, bias)


        if not os.path.isdir(path):
            os.mkdir(path)

        check_callback = CheckpointCallback(save_freq=30_000, name_prefix="DRNet", save_path='./models/CARLA/DRNet/{}_{}'.format(coef, bias), )

        model = torch_BC(
            policy=torch_MultiInputPolicy,  # type: ignore
            env=env,
            verbose=1,
            gradient_steps=1,
            buffer_size=30000,
            replay_buffer_class=torch_BCReplayBuffer,
            learning_starts=100,
            train_freq=1,
            batch_size=120,
            learning_rate=3e-4,
            tensorboard_log="log/BC/DRNet/{}_{}".format(coef, bias),
            effi_coef=coef,
            policy_kwargs=policy_kwargs,
        )

    elif mode == "iter":
        size = 64
        iter_size = 64

        if dynamic_velocity:
            log_path = './log/Dynamic/'
            model = SoftModularization.load('./models/CARLA/ModMA/dynamic_sm_model.zip', tensorboard_log=os.path.join(log_path, 'sm'), env=env, gradient_steps=1, batch_size=120, train_freq=120, )
            # model = SoftModularization.load('./models/CARLA/ModMA/dynamic_sm.zip', tensorboard_log=os.path.join(log_path, 'sm'), env=env, gradient_steps=1, batch_size=120, train_freq=120, )
            print(model.learning_rate)
        else:
            log_path = './log/BC/'
            model = BC.load('./models/CARLA/ModMA/32_sm_model.zip', tensorboard_log=os.path.join(log_path, 'sm'), env=env, gradient_steps=1, batch_size=120, train_freq=120,)
            model.gradient_steps = 1

        routing_envs = DynamicSequentialRouting(env, 4, 4, model.actor, simple_sample=not dynamic_velocity)

        iterative_policy_kwargs = {
            'net_arch': [iter_size, iter_size],
            'features_extractor_class': SensorMultiInputCNN,
            'features_extractor_kwargs': {'feature_dim': iter_size * 2 + routing_envs.other_dims, 'net_arch': [iter_size, iter_size * 2]},
        }

        # iterative_model = REINFORCE.load('./models/CARLA/ModMA/dynamic_iter.zip', env=routing_envs)

        # iterative_model = REINFORCE.load('./models/CARLA/ModMA/32_routing_model.zip', env=routing_envs, tensorboard_log="./log/BC/routing",
        #                             policy_kwargs=iterative_policy_kwargs, gamma=1,
        #                             verbose=1, gradient_steps=1, batch_size=120, train_freq=120, learning_rate=1e-4,
        #                             replay_buffer_class=DictReplayBuffer)
        # iterative_model.learning_rate = 5e-5
        # iterative_model._setup_lr_schedule()

        iterative_model = REINFORCE(CnnPolicy, routing_envs, tensorboard_log=os.path.join(log_path, 'routing'),
                                    policy_kwargs=iterative_policy_kwargs, gamma=1,
                                    verbose=1, gradient_steps=1, buffer_size=128, batch_size=120, train_freq=120, learning_starts=120, learning_rate=3e-4,
                                    replay_buffer_class=DictReplayBuffer)

        env.routing_model = iterative_model
        routing_envs = DynamicSequentialRouting(env, 4, 4, model.actor, simple_sample=not dynamic_velocity)
        iterative_model.env = REINFORCE._wrap_env(routing_envs)

        if dynamic_velocity:
            model.env = SoftModularization._wrap_env(env)
        else:
            model.env = BC._wrap_env(env)

        path = './models/CARLA/ModMA/'

        callback = RoutingCallback(iterative_model, save_freq=30_000, save_path=path)

        tot_timestep = 3_000_000
        iter_step = 3000
        iters = tot_timestep // iter_step

        model.set_restrict_action()

        iterative_model.learn(total_timesteps=iter_step, log_interval=100, reset_num_timesteps=True)
        model.learn(total_timesteps=iter_step, log_interval=1, reset_num_timesteps=True, callback=callback, )

        for idx in range(iters):
            iterative_model.learn(total_timesteps=iter_step, log_interval=100, reset_num_timesteps=False)
            model.learn(total_timesteps=iter_step, log_interval=1, reset_num_timesteps=False, callback=callback, )

    elif mode == 'distill':
        size = 16
        distill_size = 64
        path = './models/CARLA/ModMA'

        sm_model = BC.load(os.path.join(path, '32_sm_model.zip'), env=env)

        routing_envs = DynamicSequentialRouting(env, 4, 4, sm_model.actor, simple_sample=not dynamic_velocity)

        iterative_model = REINFORCE.load(os.path.join(path, '32_routing_model.zip'), routing_envs)

        reset_num_timesteps = True
        routing_envs = DynamicSequentialRouting(env, 4, 4, sm_model.actor, simple_sample=not dynamic_velocity)
        iterative_model.env = REINFORCE._wrap_env(routing_envs)
        sm_model.env = SoftModularization._wrap_env(env)
        sm_model.n_envs = 1

        distill_env = DynamicDistillRouting(env, 4, 4, sm_model.actor, iterative_model, routing_envs, simple_sample=not dynamic_velocity)

        distill_policy_kwargs = {
            'net_arch': [distill_size, distill_size],
            'features_extractor_class': SensorMultiInputCNN,
            'features_extractor_kwargs': {'feature_dim': distill_size + distill_env.other_dims, 'net_arch': [distill_size, distill_size * 2]},
        }

        # model = Distillation.load('./models/CARLA/Distill/distill_5.zip', env=distill_env, policy_kwargs=distill_policy_kwargs)
        # model.learning_rate = 5e-5
        # model.train_freq = 1
        # model._convert_train_freq()
        # model._setup_lr_schedule()

        model = Distillation(SAC_MultiInputPolicy, distill_env, tensorboard_log='./log/BC/distill', verbose=1,
                             learning_rate=5e-5, batch_size=120, policy_kwargs=distill_policy_kwargs,
                             replay_buffer_class=TeacherReplaybuffer, learning_starts=2000, gradient_steps=1,
                             train_freq=12, buffer_size=30000, )

        path = './models/CARLA/Distill/'
        check_callback = jax_CheckpointCallback(save_freq=30_000, save_path=path, )

    elif dynamic_velocity:
        module_num = 4
        layer_num = 4

        size = 32

        module_arch = [{'features': [size], 'kernel_sizes': [3], 'strides': [1]}] * (layer_num // 2) + \
                      [{'features': [size * 2], 'kernel_sizes': [3], 'strides': [1]}] * (layer_num // 2)
        policy_kwargs = {
            'net_arch': [size, size],
            'features_extractor_kwargs': {'module_select_cond': True,
                                          'module_arch': module_arch,
                                          'net_arch': [module_num] * layer_num + [1]
                                          },
            'features_extractor_class': SensorConvSoftModule}


        # model = SoftModularization(
        #     policy=MultiInputPolicy,  # type: ignore
        #     env=env,
        #     verbose=1,
        #     learning_starts=120,
        #     gradient_steps=1,
        #     replay_buffer_class=TeacherReplaybuffer,
        #     buffer_size=30000,
        #     train_freq=1,
        #     gamma=0.99,
        #     num_tasks=12,
        #     batch_size=120,
        #     learning_rate=3e-4,
        #     tensorboard_log="log/Dynamic",
        #     policy_kwargs=policy_kwargs,
        #     without_exploration=False,
        #     bc_mode=True,
        # )
        model = SoftModularization.load('./models/CARLA/ModMA/dynamic_sm_model.zip', env=env, tensorboard_log="log/Dynamic",)
        path = './models/CARLA/ModMA/'
        check_callback = jax_CheckpointCallback(save_freq=30_000, name_prefix="ModMA", save_path=path, )

    else:
        module_num = 4
        layer_num = 4

        size = 32

        module_arch = [{'features': [size], 'kernel_sizes': [3], 'strides': [1]}] * (layer_num // 2) + \
                      [{'features': [size * 2], 'kernel_sizes': [3], 'strides': [1]}] * (layer_num // 2)
        policy_kwargs = {
            'net_arch': [size, size],
            'features_extractor_kwargs': {'module_select_cond': True,
                                          'module_arch': module_arch,
                                          'net_arch': [module_num] * layer_num + [1]
                                          },
            'features_extractor_class': SensorConvSoftModule}

        # model = BC.load('./models/CARLA/ModMA/32_sm_model.zip', env=env, )
        # model.off_restrict_action()
        # model = BC.load('./models/ours_route_128_0.zip', env=env,)

        model = BC(
            policy=BC_MultiInputPolicy,  # type: ignore
            env=env,
            verbose=1,
            gradient_steps=1,
            buffer_size=30000,
            train_freq=1,
            batch_size=128,
            learning_rate=3e-4,
            tensorboard_log="log/BC",
            policy_kwargs=policy_kwargs,
            without_exploration=False,
        )
        path = './models/CARLA/ModMA/'
        check_callback = jax_CheckpointCallback(save_freq=30_000, name_prefix="ModMA", save_path=path, )


    if mode != "iter" and mode != "ddnn":

        model.learn(total_timesteps=10000, log_interval=1, callback=check_callback)
        for i in range(500):
            model.learn(total_timesteps=10000, log_interval=1, reset_num_timesteps=False, callback=check_callback)
