import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import time

from agent import Agent
import utils

import hydra


class SACAgent(Agent):
    """Soft Actor-Critic (SAC) agent with twin Q critics.

    The actor, critic and critic target are built from Hydra configs with
    ``hydra.utils.instantiate``. The critic target starts as a copy of the
    critic and is afterwards moved towards it with
    ``utils.soft_update_params`` (rate ``critic_tau``) every
    ``critic_target_update_frequency`` steps. The entropy temperature
    ``alpha = exp(log_alpha)`` is optimized only when
    ``learnable_temperature`` is true.
    """
    def __init__(self, obs_dim, action_dim, action_range, device, critic_cfg,
                 actor_cfg, discount, init_temperature, alpha_lr, alpha_betas,
                 actor_lr, actor_betas, actor_update_frequency, critic_lr,
                 critic_betas, critic_tau, critic_target_update_frequency,
                 batch_size, learnable_temperature):
        """Build networks and optimizers.

        Args:
            obs_dim: observation dimensionality (not used directly here; the
                network sizes come from ``critic_cfg`` / ``actor_cfg``).
            action_dim: action dimensionality; the target entropy is
                ``-action_dim``.
            action_range: ``(low, high)`` pair unpacked into ``clamp`` in
                :meth:`act`.
            device: anything accepted by ``torch.device``.
            critic_cfg, actor_cfg: Hydra configs for the networks.
            discount: reward discount factor.
            init_temperature: initial value of ``alpha``.
            alpha_lr, alpha_betas: Adam settings for ``log_alpha``.
            actor_lr, actor_betas: Adam settings for the actor.
            actor_update_frequency: update the actor every this many steps.
            critic_lr, critic_betas: Adam settings for the critic.
            critic_tau: soft-update rate for the critic target.
            critic_target_update_frequency: soft-update the target every this
                many steps.
            batch_size: replay batch size used by :meth:`update`.
            learnable_temperature: whether ``alpha`` is optimized.
        """
        super().__init__()

        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature

        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(
            self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        # The temperature is learned in log space so that alpha stays positive.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr,
                                                betas=actor_betas)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr,
                                                 betas=critic_betas)

        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr,
                                                    betas=alpha_betas)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Put the actor and critic in training (or eval) mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def act(self, obs, sample=False):
        """Return an action for a single observation as a numpy array.

        ``obs`` gets a leading batch dimension of size 1. The action is a
        sample from the policy when ``sample`` is true, otherwise the
        distribution mean. It is then clamped to ``action_range``.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)
        dist = self.actor(obs)
        action = dist.sample() if sample else dist.mean
        action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, action, reward, next_obs, not_done, logger,
                      step):
        """Take one gradient step on the twin-Q soft Bellman loss."""
        # Soft target: min of the two target Qs minus the entropy term,
        # evaluated at a fresh action sampled from the current policy.
        dist = self.actor(next_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
        target_V = torch.min(target_Q1,
                             target_Q2) - self.alpha.detach() * log_prob
        target_Q = reward + (not_done * self.discount * target_V)
        # The target is a fixed regression label: no gradient reaches the
        # actor or the target network through it.
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, logger, step):
        """Update the policy and, if enabled, the entropy temperature."""
        dist = self.actor(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, action)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/target_entropy', self.target_entropy, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor. The backward pass also writes gradients into
        # the critic's parameters; update_critic clears them with
        # critic_optimizer.zero_grad() before its own backward.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            # The entropy gap is detached so only log_alpha is trained here.
            alpha_loss = (self.alpha *
                          (-log_prob - self.target_entropy).detach()).mean()
            logger.log('train_alpha/loss', alpha_loss, step)
            logger.log('train_alpha/value', self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, step):
        """Run one training step on a batch drawn from ``replay_buffer``.

        The critic is updated on every call. The actor and temperature are
        updated every ``actor_update_frequency`` steps, and the critic target
        every ``critic_target_update_frequency`` steps.
        """
        obs, action, reward, next_obs, _not_done, not_done_no_max = \
            replay_buffer.sample(self.batch_size)

        logger.log('train/batch_reward', reward.mean(), step)

        # Bootstrap with ``not_done_no_max`` rather than ``not_done``.
        # NOTE(review): presumably this mask ignores time-limit terminations;
        # confirm against the replay buffer implementation.
        self.update_critic(obs, action, reward, next_obs, not_done_no_max,
                           logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)

    def save(self, path):
        """Save critic and actor ``state_dict``s under directory ``path``."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.critic.state_dict(), path + '/critic')
        torch.save(self.actor.state_dict(), path + '/actor')

    def load(self, path):
        """Load the critic and actor weights written by :meth:`save`.

        Tensors are mapped onto ``self.device``, so a checkpoint written on
        another device (e.g. a GPU) can be restored on this one.
        """
        self.critic.load_state_dict(
            torch.load(path + '/critic', map_location=self.device))
        self.actor.load_state_dict(
            torch.load(path + '/actor', map_location=self.device))
        # The target network is not checkpointed. Without this copy it would
        # keep whatever weights it held before loading (after construction,
        # a copy of the untrained critic).
        self.critic_target.load_state_dict(self.critic.state_dict())

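# SAC variant in which the actor and critic both consume features from a shared
# encoder; the encoder itself is trained only through the critic loss.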
class SACAgent_shared(Agent):
    """SAC agent whose actor and critic share a learned encoder.

    The encoder and its target are attached afterwards with
    :meth:`add_encoder`, which must be called before :meth:`act`,
    :meth:`update`, :meth:`save` or :meth:`load`. Network inputs are built by
    :meth:`encoder_forward`. It concatenates the raw observation, the
    flattened stacked deltas, and the encoder's latent output. The encoder is
    trained only through the critic loss; the actor sees detached features.
    """
    def __init__(self, obs_dim, action_dim, action_range, device, critic_cfg,
                 actor_cfg, discount, init_temperature, alpha_lr, alpha_betas,
                 actor_lr, actor_betas, actor_update_frequency, critic_lr,
                 critic_betas, critic_tau, critic_target_update_frequency,
                 batch_size, learnable_temperature):
        """Build actor/critic networks and their optimizers.

        Takes the same arguments as ``SACAgent``. The encoder (and its
        optimizer) is added later by :meth:`add_encoder`.
        """
        super().__init__()

        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature

        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(
            self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        # The temperature is learned in log space so that alpha stays positive.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr,
                                                betas=actor_betas)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr,
                                                 betas=critic_betas)

        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr,
                                                    betas=alpha_betas)

        # self.train() is deferred to add_encoder, once the encoder exists.
        self.critic_target.train()

    def add_encoder(self, encoder, encoder_target, encoder_lr=1e-4,
                    encoder_betas=(0.9, 0.999)):
        """Attach the shared encoder and its target copy.

        Args:
            encoder: module trained through the critic loss.
            encoder_target: module with the same ``state_dict`` layout. It is
                initialized here as a copy of ``encoder`` and then
                soft-updated in :meth:`update`.
            encoder_lr, encoder_betas: Adam settings for the encoder. The
                defaults are the values that were previously hard-coded.

        NOTE(review): the modules are not moved to ``self.device`` here; the
        caller is presumably responsible for that.
        """
        self.encoder = encoder
        self.encoder_target = encoder_target
        self.encoder_target.load_state_dict(self.encoder.state_dict())
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(),
                                                  lr=encoder_lr,
                                                  betas=encoder_betas)
        self.train()

    def train(self, training=True):
        """Put actor, critic and (if attached) encoder in train/eval mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        # The encoder only exists after add_encoder(); skip it until then
        # instead of raising AttributeError.
        encoder = getattr(self, 'encoder', None)
        if encoder is not None:
            encoder.train(training)

    @property
    def alpha(self):
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    # obs: (obs, ); delta_obses: (stack, obs); actions:(stack, act)
    def act(self, obs, delta_obses, actions, sample=False):
        """Return an action for a single (unbatched) observation.

        ``obs`` and ``delta_obses`` get a leading batch dimension of size 1
        and go through :meth:`encoder_forward`. ``actions`` is accepted for
        interface compatibility but is not used to build the policy input.
        """
        obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
        delta_obses = torch.FloatTensor(delta_obses).to(self.device).unsqueeze(0)
        obs = self.encoder_forward(obs, delta_obses)
        dist = self.actor(obs)
        action = dist.sample() if sample else dist.mean
        action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, action, reward, next_obs, not_done, logger,
                      step):
        """One gradient step on the twin-Q loss for both critic and encoder.

        ``obs`` is expected to carry gradients back into the encoder, so the
        encoder optimizer is stepped together with the critic optimizer.
        """
        dist = self.actor(next_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
        target_V = torch.min(target_Q1,
                             target_Q2) - self.alpha.detach() * log_prob
        target_Q = reward + (not_done * self.discount * target_V)
        # The target is a fixed regression label.
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic and encoder
        self.critic_optimizer.zero_grad()
        self.encoder_optimizer.zero_grad()
        if step > 5000 and step < 5005:
            print("updating encoder with critic loss")

        critic_loss.backward()

        self.critic_optimizer.step()
        self.encoder_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, logger, step):
        """Update the policy and, if enabled, the entropy temperature."""
        dist = self.actor(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, action)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/target_entropy', self.target_entropy, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor. Gradients that land in the critic's parameters
        # are cleared by critic_optimizer.zero_grad() in update_critic.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            # The entropy gap is detached so only log_alpha is trained here.
            alpha_loss = (self.alpha *
                          (-log_prob - self.target_entropy).detach()).mean()
            logger.log('train_alpha/loss', alpha_loss, step)
            logger.log('train_alpha/value', self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    # could put inside encoder
    def encoder_forward(self, ori_obs, delta_obses, ema=False, detach=False):
        """Build the network input ``cat(ori_obs, flat_deltas, latent)``.

        Args:
            ori_obs: batch of raw observations; its first dimension is the
                batch size.
            delta_obses: batched stacked deltas, flattened to
                ``(batch, -1)`` before encoding.
            ema: use ``encoder_target`` instead of ``encoder``.
            detach: return a tensor that carries no gradient.
        """
        encoder = self.encoder_target if ema else self.encoder
        b = ori_obs.shape[0]
        x = delta_obses.reshape(b, -1)
        if detach:
            # The result would be detached anyway, so skip recording an
            # autograd graph through the encoder.
            with torch.no_grad():
                return torch.cat((ori_obs, x, encoder(x)), -1)
        # (b, obs_dim * (stacknum + 1) + latent_dim) per the original note
        return torch.cat((ori_obs, x, encoder(x)), -1)

    def update(self, replay_buffer, logger, step):
        """Run one training step on a batch drawn from ``replay_buffer``."""
        (ori_obs, delta_obses, actions, reward, ori_next_obs,
         delta_next_obses, _next_actions, _not_done,
         not_done_no_max) = replay_buffer.sample(self.batch_size)

        # The critic input keeps its graph so the critic loss trains the
        # encoder. Next-state features come from the target encoder, detached.
        obs = self.encoder_forward(ori_obs, delta_obses)
        next_obs = self.encoder_forward(ori_next_obs, delta_next_obses,
                                        ema=True, detach=True)
        # The action actually taken is the last one in the stacked history.
        action = actions[:, -1]

        logger.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done_no_max,
                           logger, step)

        if step % self.actor_update_frequency == 0:
            # Recompute with the just-updated encoder, detached: the encoder
            # is only updated through the critic loss.
            obs = self.encoder_forward(ori_obs, delta_obses, detach=True)
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)
            utils.soft_update_params(self.encoder, self.encoder_target,
                                     self.critic_tau)

    def save(self, path):
        """Save critic, actor and encoder ``state_dict``s under ``path``."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.critic.state_dict(), path + '/critic')
        torch.save(self.actor.state_dict(), path + '/actor')
        torch.save(self.encoder.state_dict(), path + '/encoder')

    def load(self, path):
        """Load the weights written by :meth:`save` (after add_encoder).

        Tensors are mapped onto ``self.device``. Target networks are not
        checkpointed, so they are re-synchronised from the loaded online
        networks instead of keeping their pre-load weights.
        """
        self.critic.load_state_dict(
            torch.load(path + '/critic', map_location=self.device))
        self.actor.load_state_dict(
            torch.load(path + '/actor', map_location=self.device))
        self.encoder.load_state_dict(
            torch.load(path + '/encoder', map_location=self.device))
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.encoder_target.load_state_dict(self.encoder.state_dict())




# """ for shared encoder """        
# class SACAgent_shared(Agent):
#     """SAC algorithm."""
#     def __init__(self, obs_dim, action_dim, action_range, device, critic_cfg,
#                  actor_cfg, discount, init_temperature, alpha_lr, alpha_betas,
#                  actor_lr, actor_betas, actor_update_frequency, critic_lr,
#                  critic_betas, critic_tau, critic_target_update_frequency,
#                  batch_size, learnable_temperature):
#         super().__init__()

#         self.action_range = action_range
#         self.device = torch.device(device)
#         self.discount = discount
#         self.critic_tau = critic_tau
#         self.actor_update_frequency = actor_update_frequency
#         self.critic_target_update_frequency = critic_target_update_frequency
#         self.batch_size = batch_size
#         self.learnable_temperature = learnable_temperature

#         self.critic_cfg = critic_cfg
#         self.actor_cfg = actor_cfg

#         self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
#         self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

#         self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
#         self.log_alpha.requires_grad = True
#         # set target entropy to -|A|
#         self.target_entropy = -action_dim

#         self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
#                                             lr=alpha_lr,
#                                             betas=alpha_betas)

#         self.actor_lr, self.actor_betas = actor_lr, actor_betas
#         self.critic_lr, self.critic_betas = critic_lr, critic_betas

#     def add_encoder(self, encoder):
#         # add encoder
#         self.actor.add_encoder(encoder)
#         self.critic.add_encoder(encoder)

#         # optimizers
#         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
#                                                 lr=self.actor_lr,
#                                                 betas=self.actor_betas)

#         self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
#                                                  lr=self.critic_lr,
#                                                  betas=self.critic_betas)
        
#         self.critic_target = hydra.utils.instantiate(self.critic_cfg).to(self.device)
#         self.critic_target.load_state_dict(self.critic.state_dict())

#         self.train()
#         self.critic_target.train()
        

#     def train(self, training=True):
#         self.training = training
#         self.actor.train(training)
#         self.critic.train(training)

#     @property
#     def alpha(self):
#         return self.log_alpha.exp()

#     def act(self, obs, sample=False):
#         obs = torch.FloatTensor(obs).to(self.device)
#         obs = obs.unsqueeze(0)
#         dist = self.actor(obs)
#         action = dist.sample() if sample else dist.mean
#         action = action.clamp(*self.action_range)
#         assert action.ndim == 2 and action.shape[0] == 1
#         return utils.to_np(action[0])

#     def update_critic(self, obs, action, reward, next_obs, not_done, 
#                       logger, step):
#         dist = self.actor(next_obs)
#         next_action = dist.rsample()
#         log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
#         target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
#         target_V = torch.min(target_Q1,
#                              target_Q2) - self.alpha.detach() * log_prob
#         target_Q = reward + (not_done * self.discount * target_V)
#         target_Q = target_Q.detach()

#         # get current Q estimates
#         current_Q1, current_Q2 = self.critic(obs, action)
#         critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
#             current_Q2, target_Q)
#         logger.log('train_critic/loss', critic_loss, step)

#         # Optimize the critic
#         self.critic_optimizer.zero_grad()
#         critic_loss.backward()
#         self.critic_optimizer.step()

#         self.critic.log(logger, step)

#     def update_actor_and_alpha(self, obs, logger, step):
#         dist = self.actor(obs)
#         action = dist.rsample()
#         log_prob = dist.log_prob(action).sum(-1, keepdim=True)
#         actor_Q1, actor_Q2 = self.critic(obs, action)

#         actor_Q = torch.min(actor_Q1, actor_Q2)
#         actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

#         logger.log('train_actor/loss', actor_loss, step)
#         logger.log('train_actor/target_entropy', self.target_entropy, step)
#         logger.log('train_actor/entropy', -log_prob.mean(), step)

#         # optimize the actor
#         self.actor_optimizer.zero_grad()
#         actor_loss.backward()
#         self.actor_optimizer.step()

#         self.actor.log(logger, step)

#         if self.learnable_temperature:
#             self.log_alpha_optimizer.zero_grad()
#             alpha_loss = (self.alpha *
#                           (-log_prob - self.target_entropy).detach()).mean()
#             logger.log('train_alpha/loss', alpha_loss, step)
#             logger.log('train_alpha/value', self.alpha, step)
#             alpha_loss.backward()
#             self.log_alpha_optimizer.step()

#     def update(self, replay_buffer, logger, step):
#         ori_obses, delta_obses, actions, reward, ori_next_obses, delta_next_obses, not_dones, not_dones_no_max  = replay_buffer.sample(
#             self.batch_size)

#         logger.log('train/batch_reward', reward.mean(), step)

#         self.update_critic(ori_obses, delta_obses, actions, reward, ori_next_obses, delta_next_obses, not_dones,
#                            logger, step)

#         if step % self.actor_update_frequency == 0:
#             self.update_actor_and_alpha(obs, logger, step)

#         if step % self.critic_target_update_frequency == 0:
#             utils.soft_update_params(self.critic, self.critic_target,
#                                      self.critic_tau)

#     def save(self, path):
#         os.makedirs(path, exist_ok=True)
#         torch.save(self.critic.state_dict(), path + '/critic')
#         torch.save(self.actor.state_dict(), path + '/actor')
#         torch.save(self.encoder.state_dict(), path + '/encoder')

#     def load(self, path):
#         self.critic.load_state_dict(torch.load(path + '/critic'))
#         self.actor.load_state_dict(torch.load(path + '/actor'))
#         self.encoder.load_state_dict(torch.load(path + '/actor'))
