# Modified implementation of DDPG from the D4RL evaluations repository.
# Original source: https://github.com/Farama-Foundation/D4RL-Evaluations/tree/master

import os
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Actor(nn.Module):
	"""Tanh-squashed MLP policy with fixed-std Gaussian noise.

	Args:
		state_dim: size of the flat state vector fed to the first Linear layer.
		action_dim: size of the action vector produced by the last Linear layer.
		hidden_dim: sequence of hidden layer widths. Must be non-empty, because
			``hidden_dim[0]`` and ``hidden_dim[-1]`` are read unconditionally.
		max_action: scale applied to the tanh output, so the noise-free action
			lies in ``[-max_action, max_action]``.
		std: standard deviation of the additive Gaussian noise in ``forward``
			and of the Normal distribution used by ``log_prob``.
	"""
	def __init__(self, state_dim, action_dim, hidden_dim, max_action, std=0.3):
		super().__init__()
		self.std = std
		self.max_action = max_action
		self.input_layer = nn.Sequential(
			nn.Linear(state_dim, hidden_dim[0]),
			nn.ReLU(),
		)
		# One Linear+ReLU block per consecutive pair of widths, then the output head.
		# Attribute names and nesting are kept so saved state_dict keys still match.
		self.hidden_layers = nn.Sequential(
			*[nn.Sequential(
				nn.Linear(width_in, width_out),
				nn.ReLU()
			) for width_in, width_out in zip(hidden_dim, hidden_dim[1:])],
			nn.Linear(hidden_dim[-1], action_dim)
		)

	def _mean(self, state):
		"""Return the noise-free action ``tanh(mlp(state)) * max_action``."""
		embed = self.input_layer(state)
		return torch.tanh(self.hidden_layers(embed)) * self.max_action

	def forward(self, state):
		"""Return ``mean(state) + std * eps`` with ``eps ~ N(0, I)``.

		The noise is not clipped, so returned actions may exceed ``max_action``.
		"""
		mean = self._mean(state)
		# randn_like already allocates on mean's device and dtype.
		return mean + torch.randn_like(mean) * self.std

	def log_prob(self, state, action):
		"""Log-density of ``action`` under N(mean(state), std^2), summed over dim 1."""
		normal = torch.distributions.Normal(self._mean(state), self.std)
		return normal.log_prob(action).sum(1)


class Critic(nn.Module):
	"""Q-network mapping a (state, action) pair to a scalar value estimate.

	State and action are concatenated along dim 1 and passed through an MLP
	whose hidden widths come from ``hidden_dim``. The output has a trailing
	dimension of size 1.
	"""
	def __init__(self, state_dim, action_dim, hidden_dim):
		super().__init__()
		first_width = hidden_dim[0]
		self.input_layer = nn.Sequential(
			nn.Linear(state_dim + action_dim, first_width),
			nn.ReLU(),
		)
		# Build the hidden Linear+ReLU blocks in order, then the scalar head;
		# creation order (and hence RNG-driven init) matches a single Sequential.
		blocks = []
		for idx in range(len(hidden_dim) - 1):
			blocks.append(nn.Sequential(
				nn.Linear(hidden_dim[idx], hidden_dim[idx + 1]),
				nn.ReLU(),
			))
		blocks.append(nn.Linear(hidden_dim[-1], 1))
		self.hidden_layers = nn.Sequential(*blocks)

	def forward(self, state, action):
		"""Return Q(state, action)."""
		joint = torch.cat((state, action), dim=1)
		return self.hidden_layers(self.input_layer(joint))


class DDPG:
	"""DDPG agent: a noisy actor, a Q-critic, and soft-updated target copies of both.

	Args:
		agent_info: dict that must contain ``state_dim``, ``action_dim``,
			``actor_hidden_dim``, ``critic_hidden_dim``, ``max_action``,
			``discount_factor``, ``tau`` and ``learning_rate``. Actor and critic
			use the same Adam learning rate.
		device: torch device where the networks live and checkpoints are loaded.
	"""
	def __init__(self, agent_info, device):
		state_dim = agent_info['state_dim']
		action_dim = agent_info['action_dim']
		actor_hidden_dim = agent_info['actor_hidden_dim']
		critic_hidden_dim = agent_info['critic_hidden_dim']
		max_action = agent_info['max_action']

		self.device = device
		self.discount = agent_info['discount_factor']
		self.tau = agent_info['tau']

		learning_rate = agent_info['learning_rate']

		self.actor = Actor(state_dim, action_dim, actor_hidden_dim, max_action).to(device)
		self.actor_target = copy.deepcopy(self.actor)
		self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), learning_rate)

		self.critic = Critic(state_dim, action_dim, critic_hidden_dim).to(device)
		self.critic_target = copy.deepcopy(self.critic)
		self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), learning_rate)

	def select_action(self, state):
		"""Return one action (including the actor's exploration noise) as a flat numpy array.

		``state`` must support ``.reshape`` (presumably a numpy array). It is
		flattened into a batch of one before the forward pass.
		"""
		state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
		# Inference only: do not record an autograd graph.
		with torch.no_grad():
			action = self.actor(state)
		return action.cpu().numpy().flatten()

	def train(self, replay_buffer, batch_size):
		"""Do one critic step and one actor step on a sampled batch, then soft-update the targets.

		``replay_buffer.sample(batch_size)`` must return the tensors
		``(state, action, next_state, reward, not_done)``. They are presumably
		already on ``self.device``; this method does not move them.
		"""
		state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

		# Bootstrapped TD target; no gradient may flow into the target networks.
		with torch.no_grad():
			target_Q = self.critic_target(next_state, self.actor_target(next_state))
			target_Q = reward + not_done * self.discount * target_Q

		current_Q = self.critic(state, action)
		critic_loss = F.mse_loss(current_Q, target_Q)

		self.critic_optimizer.zero_grad()
		critic_loss.backward()
		self.critic_optimizer.step()

		# Maximize Q of the actor's (reparameterized, noisy) action. This backward
		# also writes grads into critic params; they are zeroed before the next
		# critic step, so they never reach the critic optimizer.
		actor_loss = -self.critic(state, self.actor(state)).mean()

		self.actor_optimizer.zero_grad()
		actor_loss.backward()
		self.actor_optimizer.step()

		self._soft_update(self.critic, self.critic_target)
		self._soft_update(self.actor, self.actor_target)

	def _soft_update(self, source, target):
		"""Polyak-average in place: target <- tau * source + (1 - tau) * target."""
		with torch.no_grad():
			for param, target_param in zip(source.parameters(), target.parameters()):
				target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

	@staticmethod
	def _checkpoint_dir(path, total_timestep):
		"""Return the checkpoint directory ``<path>/weight/<total_timestep>``."""
		return os.path.join(path, 'weight', str(total_timestep))

	def save(self, path, total_timestep):
		"""Write actor/critic weights and optimizer states under ``<path>/weight/<total_timestep>``.

		An existing directory for the same timestep is reused (files are overwritten).
		Target networks are not saved; ``load`` rebuilds them from the online nets.
		"""
		ckpt_dir = self._checkpoint_dir(path, total_timestep)
		os.makedirs(ckpt_dir, exist_ok=True)
		torch.save(self.critic.state_dict(), os.path.join(ckpt_dir, "critic"))
		torch.save(self.critic_optimizer.state_dict(), os.path.join(ckpt_dir, "critic_optimizer"))

		torch.save(self.actor.state_dict(), os.path.join(ckpt_dir, "actor"))
		torch.save(self.actor_optimizer.state_dict(), os.path.join(ckpt_dir, "actor_optimizer"))

	def load(self, path, total_timestep):
		"""Restore weights and optimizer states written by ``save``.

		Tensors are mapped onto ``self.device``, so a checkpoint written on GPU
		can be loaded on a CPU-only machine. Target networks are reset to copies
		of the loaded online networks.
		"""
		ckpt_dir = self._checkpoint_dir(path, total_timestep)

		def _read(name):
			return torch.load(os.path.join(ckpt_dir, name), map_location=self.device)

		self.critic.load_state_dict(_read("critic"))
		self.critic_optimizer.load_state_dict(_read("critic_optimizer"))
		self.critic_target = copy.deepcopy(self.critic)

		self.actor.load_state_dict(_read("actor"))
		self.actor_optimizer.load_state_dict(_read("actor_optimizer"))
		self.actor_target = copy.deepcopy(self.actor)
		