from lpcmdp.algorithm.utils import *
import torch
from lpcmdp.algorithm.model import *
from torch.utils.data import DataLoader
import itertools
from tqdm import tqdm
from lpcmdp.env.FrozenLake import FrozenLakeEnv, FrozenLakeEnv_nocost
# from test import importance_test

def get_f_div_fn(f_type: str):
    """
    Return the ``(f_fn, f_prime_inv_fn)`` pair for an f-divergence.

    ``f_fn`` is the generator ``f`` of the divergence and ``f_prime_inv_fn`` is
    the inverse of its derivative, ``(f')^{-1}``. Both act element-wise on torch
    tensors.

    Supported ``f_type`` values:
        'chi2':    f(x) = 0.5 (x - 1)^2
                   (f')^{-1}(y) = y + 1
        'softchi': f(x) = x log x - x + 1 for x < 1, else 0.5 (x - 1)^2
                   (f')^{-1}(y) = exp(y) for y < 0, else y + 1
        'kl':      f(x) = x log x
                   (f')^{-1}(y) = exp(y - 1)

    Raises:
        NotImplementedError: if ``f_type`` is not one of the values above.
    """
    if f_type == 'chi2':
        f_fn = lambda x: 0.5 * (x - 1)**2
        f_prime_inv_fn = lambda x: x + 1

    elif f_type == 'softchi':
        # 1e-10 keeps log() finite at x == 0.
        f_fn = lambda x: torch.where(x < 1,
                                     x * (torch.log(x + 1e-10) - 1) + 1,
                                     0.5 * (x - 1)**2)
        # torch.where evaluates both branches; clamping stops the unused exp()
        # branch from overflowing for large positive inputs.
        f_prime_inv_fn = lambda x: torch.where(x < 0, torch.exp(x.clamp(max=0.0)), x + 1)

    elif f_type == 'kl':
        f_fn = lambda x: x * torch.log(x + 1e-10)
        f_prime_inv_fn = lambda x: torch.exp(x - 1)

    else:
        # A single formatted message; passing two positional args made the
        # exception render as a tuple.
        raise NotImplementedError(f'Not implemented f_fn: {f_type}')

    return f_fn, f_prime_inv_fn

class QCritic(nn.Module):
    """Scalar-output MLP critic over ``obs`` (or ``obs`` concatenated with ``act``).

    The network comes from the project's ``mlp`` helper with layer sizes
    ``[obs_dim + act_dim, *hidden_sizes, 1]``. ``nn.ReLU`` is passed as its
    activation and ``activation`` as its ``output_activation``. ``num_q`` is
    accepted but unused. ``mean``/``var`` are only stored as ``net_mean`` and
    ``net_var``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, num_q=1, mean=0.01, var=0.01):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q_net = mlp(layer_sizes, nn.ReLU, output_activation=activation)
        self.net_mean, self.net_var = mean, var

    def forward(self, obs, act=None):
        """Evaluate the critic; when ``act`` is given it is appended to ``obs`` on the last axis."""
        if act is None:
            return self.q_net(obs)
        return self.q_net(torch.cat([obs, act], dim=-1))

class Coptidice(nn.Module):
    """Learner that estimates importance weights ``w(s, a)`` from offline transitions.

    Four quantities are trained jointly by ``update``:

    * ``nu_network``: state critic whose residual ``e_nu_lambda`` yields
      ``w_sa = relu(f_prime_inv_fn(e_nu_lambda / alpha))``.
    * ``chi_network``: state critic whose residual ``ell`` drives a softmax
      re-weighting ``weights`` of the batch. That re-weighting gives the cost
      estimate ``weighted_c``.
    * ``tau``: temperature of that softmax, trained with
      ``softplus(tau) * (cost_ub_epsilon - D_kl)``.
    * ``lmbda``: cost multiplier, trained with
      ``softplus(lmbda) * (qc_thres - weighted_c)``.

    ``tau`` and ``lmbda`` are unconstrained and only used through ``softplus``.
    They are plain tensors, not ``nn.Parameter``, so they do not appear in
    ``parameters()`` / ``state_dict()``.
    """

    def __init__(self, state_dim, init_state_propotion, c_hidden_sizes=None, gamma=0.9, alpha=0.5, cost_ub_epsilon=0.01, f_type='softchi', device='cpu'):
        super().__init__()
        self.state_dim = state_dim
        # None -> a fresh [64, 64] per instance; a list default would be shared
        # by every instance that relied on it.
        self.c_hidden_sizes = [64, 64] if c_hidden_sizes is None else c_hidden_sizes
        self.gamma = gamma
        self.alpha = alpha
        self.cost_ub_epsilon = cost_ub_epsilon
        self.device = device

        self.tau = torch.ones(1, requires_grad=True, device=device)
        self.lmbda = torch.ones(1, requires_grad=True, device=device)
        self.nu_network = QCritic(self.state_dim, 0, self.c_hidden_sizes, nn.ReLU).to(self.device)
        self.chi_network = QCritic(self.state_dim, 0, self.c_hidden_sizes, nn.ReLU).to(self.device)

        self.f_fn, self.f_prime_inv_fn = get_f_div_fn(f_type)

        # Divides the is_init mask so that averaging over a batch approximates an
        # average over initial states (the trainer passes
        # num_trajectories / dataset size).
        self.init_state_propotion = init_state_propotion
        # Cost threshold used by the lambda loss.
        self.qc_thres = 0

    def _optimal_w(self, observation, next_observation, rewards, costs, done):
        """Return ``(nu_s, nu_s_next, e_nu_lambda, w_sa)`` for a batch.

        ``e_nu_lambda = (r - lambda * c) + (gamma * (1 - done) * nu(s') - nu(s))``
        with ``lambda`` detached, and
        ``w_sa = relu(f_prime_inv_fn(e_nu_lambda / alpha))``.
        Reads ``self._lmbda``, which ``update`` sets before calling this.
        """
        nu_s = self.nu_network(observation, None)
        nu_s_next = self.nu_network(next_observation, None)

        # lambda is detached: it is trained only through lmbda_loss in update().
        e_nu_lambda = (rewards - self._lmbda.detach() * costs) + \
            (self.gamma * (1.0 - done) * nu_s_next - nu_s)

        w_sa = F.relu(self.f_prime_inv_fn(e_nu_lambda / self.alpha))
        return nu_s, nu_s_next, e_nu_lambda, w_sa

    def update(self, batch):
        """Run one optimisation step on chi, tau, nu and lambda (in that order).

        Args:
            batch: sequence ``(observations, actions, rewards, costs,
                next_observations, done, hole, is_init)``. ``actions`` and
                ``hole`` are not used here.

        Returns:
            ``(nu_network, mean of w_sa recomputed after the steps, nu_loss,
            td_error, chi_loss, tau_loss, lmbda_loss, softplus(lmbda),
            softplus(tau))``. The two softplus values are those computed at the
            start of this call.

        ``setup_optimizers`` must have been called first.
        """
        observations, _actions, rewards, costs, next_observations, done, _hole, is_init = batch
        self._lmbda = F.softplus(self.lmbda)  # positive view of the multiplier

        # 1. Importance weights from the nu critic; treated as constants by chi/tau.
        nu_s, _nu_s_next, e_nu_lambda, w_sa = self._optimal_w(observations, next_observations, rewards, costs, done)
        nu_init = nu_s * is_init / self.init_state_propotion
        w_sa_no_grad = w_sa.detach()

        # 2. chi / tau: softmax re-weighting of the per-sample residual ell.
        self._tau = F.softplus(self.tau)
        batch_size = observations.shape[0]

        chi_s = self.chi_network(observations, None)
        chi_s_next = self.chi_network(next_observations, None)
        chi_init = chi_s * is_init / self.init_state_propotion

        ell = (1 - self.gamma) * chi_init + w_sa_no_grad * (costs + self.gamma * (1 - done) * chi_s_next - chi_s)
        logits = ell / self._tau.detach()
        # weights are normalised to mean 1 over the batch.
        weights = torch.softmax(logits, dim=0) * batch_size
        log_weights = torch.log_softmax(logits, dim=0) + np.log(batch_size)
        D_kl = (weights * log_weights - weights + 1).mean()

        weighted_c = (weights * w_sa_no_grad * costs).mean()

        # The four losses below have disjoint autograd graphs. Each one only sees
        # the others' tensors through .detach(), so no retain_graph is needed.
        chi_loss = (weights * ell).mean()
        self.chi_optim.zero_grad()
        chi_loss.backward()
        self.chi_optim.step()

        tau_loss = self._tau * (self.cost_ub_epsilon - D_kl.detach())
        self.tau_optim.zero_grad()
        tau_loss.backward()
        self.tau_optim.step()

        # 3. nu loss.
        nu_loss = (1 - self.gamma) * nu_init.mean() + \
            (w_sa * e_nu_lambda - self.alpha * self.f_fn(w_sa)).mean()
        td_error = e_nu_lambda.pow(2).mean()

        self.nu_optim.zero_grad()
        nu_loss.backward()
        self.nu_optim.step()

        # 4. lambda loss.
        lmbda_loss = self._lmbda * (self.qc_thres - weighted_c.detach())

        self.lmbda_optim.zero_grad()
        lmbda_loss.backward()
        self.lmbda_optim.step()

        # Recompute w with the freshly updated nu network for reporting.
        with torch.no_grad():
            _, _, _, w_sa = self._optimal_w(observations, next_observations, rewards, costs, done)

        return (self.nu_network, w_sa.mean(), nu_loss.item(), td_error.item(), chi_loss.item(),
                tau_loss.item(), lmbda_loss.item(), self._lmbda.item(), self._tau.item())

    def setup_optimizers(self, critic_lr, scalar_lr):
        """Create Adam optimisers: ``critic_lr`` for nu/chi, ``scalar_lr`` for lambda/tau."""
        self.nu_optim = torch.optim.Adam(self.nu_network.parameters(), lr=critic_lr)
        self.chi_optim = torch.optim.Adam(self.chi_network.parameters(), lr=critic_lr)

        self.lmbda_optim = torch.optim.Adam([self.lmbda], lr=scalar_lr)
        self.tau_optim = torch.optim.Adam([self.tau], lr=scalar_lr)


class CoptidiceTrainer():
    """Train a ``Coptidice`` model on an offline FrozenLake dataset and extract a policy.

    Every ``test_ever`` updates, the current nu network is turned into tabular
    weights ``w(s, a)`` (``optimal_w``). Those weights re-weight a behavior policy
    into a policy (``policy_extraction``), which is evaluated with the project's
    ``test`` helper for ``test_episode`` episodes. The returned rewards/costs are
    stored in ``logger``.
    """

    def __init__(
            self,
            env,
            datacollector,
            epochs=80000,
            critic_lr=0.00001,
            scalar_lr=0.0001,
            device='cpu',
            test_ever=1000,
            test_episode=100,
            behavior_policy_style='real',
    ):
        self.env = env
        self.epochs = epochs
        self.device = device
        self.test_ever = test_ever
        self.test_episode = test_episode

        self.behav_policy_prob = datacollector.get_behave_policy_prob()
        self.real_behav_policy = datacollector.get_real_behave_policy()
        self.behavior_policy_style = behavior_policy_style
        offline_dataset = datacollector.get_onehot_encode_dataset()
        action = torch.tensor(offline_dataset['action'], dtype=torch.float32)
        observation = torch.tensor(offline_dataset['observation'], dtype=torch.float32)
        new_observation = torch.tensor(offline_dataset['new_observation'], dtype=torch.float32)
        cost = torch.tensor(offline_dataset['cost'], dtype=torch.float32).view(-1, 1)
        reward = torch.tensor(offline_dataset['reward'], dtype=torch.float32).view(-1, 1)
        goal = torch.tensor(offline_dataset['goal'], dtype=torch.float32).view(-1, 1)
        hole = torch.tensor(offline_dataset['hole'], dtype=torch.float32).view(-1, 1)
        is_init = torch.tensor(offline_dataset['is_init'], dtype=torch.float32).view(-1, 1)
        # Field order must match the unpacking in Coptidice.update, where the
        # ``goal`` column is consumed as the ``done`` mask.
        self.dataset = torch.utils.data.TensorDataset(observation, action, reward, cost, new_observation, goal, hole, is_init)
        init_state_propotion = datacollector.num_trajectories / len(self.dataset)
        # Build the model on ``device`` so it matches the batches moved there in train().
        self.model = Coptidice(state_dim=observation.shape[1], init_state_propotion=init_state_propotion,
                               device=device)
        self.model.setup_optimizers(critic_lr, scalar_lr)

        self.logger = {}

    def optimal_w(self, model, lmbda):
        """Evaluate ``w(s, a)`` for every row of ``env.state_action_onehot_encode()``.

        Successor state, reward, cost and done flag for each pair are read from
        ``env.step[s][a]``. The result uses the same f-divergence and ``alpha``
        as ``self.model``.

        Args:
            model: critic called as ``model(obs, None)`` (the trained nu network).
            lmbda: scalar cost multiplier.

        Returns:
            numpy array of ``relu(f_prime_inv((r - lmbda*c + gamma*(1-done)*nu(s') - nu(s)) / alpha))``,
            one row per state-action pair.
        """
        obs_encode, acts_encode = self.env.state_action_onehot_encode()
        num_pairs = obs_encode.shape[0]
        nxt_obs_encode = torch.zeros_like(obs_encode)
        rewards = torch.zeros(num_pairs, 1)
        costs = torch.zeros(num_pairs, 1)
        dones = torch.zeros(num_pairs, 1)
        for i in range(num_pairs):
            obs = int(obs_encode[i].argmax())
            act = int(acts_encode[i].argmax())
            nxt_obs, reward, cost, done, _hole = self.env.step[obs][act]
            # One-hot encode the successor state: the entry must be 1, not the state index.
            nxt_obs_encode[i, int(nxt_obs)] = 1.0
            rewards[i], costs[i], dones[i] = reward, cost, done

        alpha = self.model.alpha
        f_prime_inv_fn = self.model.f_prime_inv_fn
        device = self.device
        with torch.no_grad():
            nu_s = model(obs_encode.to(device), None)
            nu_s_next = model(nxt_obs_encode.to(device), None)
            # NOTE(review): the discount here is env.gamma, while __init__ builds
            # Coptidice with its default gamma=0.9; confirm the two agree.
            e_nu_lambda = (rewards.to(device) - lmbda * costs.to(device)) + \
                (self.env.gamma * (1.0 - dones.to(device)) * nu_s_next - nu_s)
            w_s_a = F.relu(f_prime_inv_fn(e_nu_lambda / alpha))
        return w_s_a.cpu().numpy()

    @staticmethod
    def _infinite_batches(loader):
        """Yield batches from ``loader`` forever, starting a fresh pass each time.

        A fresh pass lets a ``shuffle=True`` DataLoader reshuffle every epoch.
        ``itertools.cycle`` would instead replay the cached first pass.

        Raises:
            ValueError: if a full pass over ``loader`` yields nothing.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                raise ValueError('cannot draw batches from an empty DataLoader')

    def _current_w(self, w_model, lmbda):
        """Run ``optimal_w`` with ``w_model`` in eval mode, restoring train mode afterwards."""
        w_model.eval()
        try:
            return self.optimal_w(w_model, lmbda)
        finally:
            w_model.train()

    def train(self):
        """Run ``epochs`` updates with periodic evaluation and return the final policy.

        Returns:
            ``policy_extraction`` output for the most recently evaluated weights.
            If no evaluation ran (``epochs < test_ever``), the weights of the
            final model are used.
        """
        batches = self._infinite_batches(DataLoader(self.dataset, batch_size=32, shuffle=True))
        test_reward, test_cost = [], []
        w_s_a = None
        # Fallback multiplier, only used if the loop below never runs.
        lmbda = F.softplus(self.model.lmbda).item()
        with tqdm(total=self.epochs) as pbar:
            for epoch in range(self.epochs):
                batch = [b.to(self.device) for b in next(batches)]
                (w_model, _w_mean, _nu_loss, _td_error, _chi_loss,
                 _tau_loss, _lmbda_loss, lmbda, _tau) = self.model.update(batch)

                if (epoch + 1) % self.test_ever == 0:
                    w_s_a = self._current_w(w_model, lmbda)
                    policy = self.policy_extraction(w_s_a, behavior_policy_style=self.behavior_policy_style)
                    reward, cost = test(self.env, policy, self.test_episode)
                    test_reward.append(reward)
                    test_cost.append(cost)
                pbar.update(1)
        self.logger['test_reward'] = test_reward
        self.logger['test_cost'] = test_cost
        if w_s_a is None:
            # No evaluation happened; use the final model instead of failing on an unbound name.
            w_s_a = self._current_w(self.model.nu_network, lmbda)
        return self.policy_extraction(w_s_a, behavior_policy_style=self.behavior_policy_style)

    def policy_extraction(self, w_s_a, behavior_policy_style='real'):
        """Build ``pi(a|s)`` proportional to ``w(s, a) * beta(a|s)``, normalised per state.

        Args:
            w_s_a: ``state_size * action_size`` weights, state-major.
            behavior_policy_style: ``'real'`` selects ``real_behav_policy``; any
                other value selects ``behav_policy_prob``.

        Returns:
            ``(state_size, action_size)`` array with the dtype of ``w_s_a``.
            States whose weighted row sums to 0 get an all-zero row.
        """
        w = w_s_a.reshape((self.env.state_size, self.env.action_size))
        if behavior_policy_style == 'real':
            behav_policy = self.real_behav_policy
        else:
            behav_policy = self.behav_policy_prob
        weighted = w * np.asarray(behav_policy)
        row_sum = weighted.sum(axis=1, keepdims=True)
        pi_s_a = np.zeros_like(w)
        np.divide(weighted, row_sum, out=pi_s_a, where=row_sum != 0)
        return pi_s_a

    def get_logger(self):
        """Return the dict holding ``test_reward`` / ``test_cost`` lists after ``train``."""
        return self.logger
# if __name__ == '__main__':
#     env_nocost = FrozenLakeEnv_nocost(ncol=8, nrow=8)
#     exp = ValueIteration(env_nocost, 0.01, env_nocost.gamma)
#     exp.value_iteration()
#     # plot_policy(np.array(exp.pi), env_nocost, "w")
#     env = FrozenLakeEnv(ncol=8, nrow=8)
#     offline_dataset = datacollect(env, expert_pi=exp.pi, percent=0.5)
#     behavior_policy = get_behave_policy_prob(env, offline_dataset)
#     offline_dataset = onehot_encode(env, offline_dataset)
#     Encode_type = 'one_hot' #'None' 'binary'
    
#     action = torch.tensor(offline_dataset['action'], dtype=torch.float32)
#     observation = torch.tensor(offline_dataset['observation'], dtype=torch.float32)
#     new_observation = torch.tensor(offline_dataset['new_observation'], dtype=torch.float32)
#     cost = torch.tensor(offline_dataset['cost'], dtype=torch.float32).view(-1, 1)
#     reward = torch.tensor(offline_dataset['reward'], dtype=torch.float32).view(-1, 1)
#     goal = torch.tensor(offline_dataset['goal'], dtype=torch.float32).view(-1, 1)
#     hole = torch.tensor(offline_dataset['hole'], dtype=torch.float32).view(-1, 1)
#     is_init = torch.tensor(offline_dataset['is_init'], dtype=torch.float32).view(-1, 1)
    
#     dataset = torch.utils.data.TensorDataset(observation, action, reward, cost, new_observation, goal, hole, is_init)
#     w_s_a = Coptidice_trainer(env, dataset, Encode_type, behavior_policy)
#     plot_policy((w_s_a, behavior_policy), env, 'w', random=False)
#     importance_test(env, w_s_a, 100, behavior_policy)