import torch 
import torch.distributions as D 
import torch.nn as nn 
import numpy as np 
import tqdm 
import matplotlib.pyplot as plt 
import stan 

from var_red_gfn.utils import Environment 

class LogReward(nn.Module):
    """Equal-weight Gaussian mixture used as a (normalized) log-reward.

    Every mixture component shares the standard deviation ``sigma`` in each
    coordinate. The number of components is ``mu.shape[0]``, and the last
    dimension of ``mu`` is treated as a single event via ``D.Independent(..., 1)``.
    This presumably means ``mu`` is ``(num_components, dim)``.

    Args:
        mu: tensor of component means; its device is reused for all tensors.
        sigma: standard deviation, broadcast against ``mu``.
    """

    def __init__(self, mu, sigma):
        super().__init__()
        self.mu = mu
        self.device = mu.device
        # Broadcast sigma (scalar or tensor) to one std-dev per mean entry.
        self.sigma = torch.ones_like(mu) * sigma

        n_components = mu.shape[0]
        # Uniform (unnormalized, all-ones) mixing weights over the components.
        mixing = D.Categorical(torch.ones((n_components,), device=self.device))
        components = D.Independent(D.Normal(self.mu, self.sigma), 1)
        self.gaussian_dist = D.MixtureSameFamily(mixing, components)

    @torch.no_grad()
    def forward(self, batch_state):
        """Return the mixture log-density of ``batch_state.state``."""
        points = batch_state.state
        return self.gaussian_dist.log_prob(points)

    @torch.no_grad()
    def sample(self, num_samples):
        """Draw ``num_samples`` points and return ``(points, log_density)``."""
        draws = self.gaussian_dist.sample(torch.Size([num_samples]))
        log_density = self.gaussian_dist.log_prob(draws)
        return draws, log_density
    
class LogRewardBanana(nn.Module):
    """Log-density of a "banana" distribution obtained by warping a 2-D Gaussian.

    A point ``x`` is mapped to ``y = (x1, x2 + x1**2 + 1)`` and scored with a
    multivariate normal ``N(mu, sigma)`` (see :meth:`forward`). Reference
    samples are drawn with Stan from the program stored in ``stan_filename``.

    Args:
        mu: Gaussian mean; used as-is if already a tensor, otherwise converted
            onto ``device``.
        sigma: Gaussian covariance matrix; same conversion rule as ``mu``.
        device: device for tensors built from non-tensor ``mu``/``sigma`` and
            for samples returned by :meth:`sample`.
        stan_filename: path to the Stan program, which receives ``mu`` and
            ``sigma`` as data.
    """

    def __init__(self, mu=(0., 0.), sigma=((1., .9), (.9, 1.)), device='cpu',
                 stan_filename='banana.stan'):
        super(LogRewardBanana, self).__init__()
        # Immutable tuple defaults avoid the shared-mutable-default pitfall;
        # torch.tensor accepts tuples exactly like lists.
        self.mu = mu if torch.is_tensor(mu) else torch.tensor(mu, device=device)
        self.sigma = sigma if torch.is_tensor(sigma) else torch.tensor(sigma, device=device)
        self.device = device
        self.gaussian_dist = D.MultivariateNormal(
            loc=self.mu,
            covariance_matrix=self.sigma,
        )

        with open(stan_filename, 'r') as stream:
            stan_code = stream.read()
        # Compile after the file is closed; the Stan build can be slow.
        self.posterior = stan.build(
            stan_code, data={'mu': self.mu.cpu().tolist(), 'sigma': self.sigma.cpu().tolist()}
        )

    @torch.no_grad()
    def forward(self, batch_state):
        """Return the Gaussian log-density of the warped ``batch_state.state``.

        Column 0 is kept and column 1 becomes ``x2 + x1**2 + 1``. The input
        tensor is cloned first, so it is not modified.
        """
        s = batch_state.state.clone()
        s[:, 1] = s[:, 1] + s[:, 0] * s[:, 0] + 1
        return self.gaussian_dist.log_prob(s)

    @torch.no_grad()
    def sample(self, num_samples, num_chains=4):
        """Draw exactly ``num_samples`` points from the banana distribution with Stan.

        Returns:
            ``(samples, log_prob)``. ``samples`` holds Stan's ``x`` draws.
            ``log_prob`` is the Gaussian log-density of Stan's transformed
            draws ``y``. Both are on ``self.device`` with ``mu``'s dtype.
        """
        # Stan's ``num_samples`` is per chain. Round up so that at least
        # ``num_samples`` draws exist, then truncate to the exact count.
        # Plain floor division would silently return fewer points.
        per_chain = -(-num_samples // num_chains)
        fit = self.posterior.sample(num_chains=num_chains, num_samples=per_chain)
        frame = fit.to_frame()

        # pandas yields float64. Match the distribution's dtype so log_prob
        # does not mix dtypes and callers receive tensors like forward's inputs.
        samples = torch.tensor(
            frame[['x.1', 'x.2']].values[:num_samples], dtype=self.mu.dtype, device=self.device
        )
        transformed_samples = torch.tensor(
            frame[['y.1', 'y.2']].values[:num_samples], dtype=self.mu.dtype, device=self.device
        )

        return samples, self.gaussian_dist.log_prob(transformed_samples)
     
class GaussianMixture(Environment):
    """Environment that builds a ``dim``-dimensional point one coordinate at a time.

    ``state`` starts as zeros of shape ``(batch_size, dim)``. Each :meth:`apply`
    writes the actions into the next coordinate, in order 0, 1, ..., dim - 1.
    The batch is marked stopped once all ``dim`` coordinates are set. Because
    coordinates are always filled in the same order, every terminal state is
    reached by exactly one trajectory.
    """

    def __init__(self, dim, batch_size, log_reward, device='cpu'):
        self.dim = dim
        # NOTE(review): ``dim + 1`` is forwarded as the base class's second
        # argument; presumably the trajectory length counting the initial state.
        super(GaussianMixture, self).__init__(batch_size, self.dim + 1, log_reward, device)

        self.state = torch.zeros((self.batch_size, self.dim), device=self.device)
        self.curr_idx = 0  # index of the next coordinate to be written

    @torch.no_grad()
    def apply(self, actions):
        """Write ``actions`` into the next coordinate of every state in the batch."""
        # Swap in a modified copy rather than mutating ``self.state`` in place,
        # so tensors previously taken from ``self.state`` keep their values.
        state = self.state.clone()
        state[:, self.curr_idx] = actions
        self.state = state
        self.curr_idx += 1
        self.stopped[:] = (self.curr_idx == self.dim)

    @torch.no_grad()
    def backward(self, actions=None):
        """Undo the most recently written coordinate and return its values.

        ``actions`` is accepted for interface compatibility and is not used.
        """
        self.curr_idx -= 1
        curr_state = self.state[:, self.curr_idx].clone()
        # Copy-then-assign (as in ``apply``). Zeroing in place would also zero
        # any tensor a caller assigned to ``env.state``.
        state = self.state.clone()
        state[:, self.curr_idx] = 0.
        self.state = state
        self.stopped[:] = 0
        self.is_initial = (self.state == 0).all(dim=1)
        return curr_state

    @staticmethod
    @torch.no_grad()
    def plot_samples(samples, log_reward):
        """Scatter the first two coordinates of ``samples`` and save to ``gmms.png``.

        Both axes are limited to ``[min(mu) - 0.5, max(mu) + 0.5]`` of
        ``log_reward.mu``.
        """
        samples = samples.cpu()
        plt.scatter(samples[:, 0], samples[:, 1])
        lim = (log_reward.mu.min().cpu() - .5, log_reward.mu.max().cpu() + .5)
        plt.xlim(*lim)
        plt.ylim(*lim)
        plt.savefig('gmms.png')

    @staticmethod
    def logsumexp(p, q):
        """Return ``log((exp(p) + exp(q)) / 2)`` elementwise.

        Given log-densities ``p`` and ``q`` of P and Q, this is the log-density
        of the equal-weight mixture ``M = (P + Q) / 2``. ``torch.logaddexp`` is
        stabilised per element. A single batch-wide shift underflows to
        ``-inf`` for entries far below the batch maximum.
        """
        return torch.logaddexp(p, q) - np.log(2)

    @staticmethod
    @torch.no_grad()
    def estimate_js(gfn, create_env, num_batches, log_reward,
                    use_progress_bar=False, plot=False):
        """Monte-Carlo estimate of the Jensen-Shannon divergence between the
        GFlowNet's sampling distribution Q and the target P.

        ``JS(P, Q) = (KL(Q || M) + KL(P || M)) / 2`` with ``M = (P + Q) / 2``.
        KL(Q || M) is averaged over ``num_batches`` batches sampled from the
        GFlowNet. KL(P || M) is averaged over ``batch_size * num_batches``
        points drawn with ``log_reward.sample``. If ``plot`` is set, both
        sample sets are scattered and saved to ``gmms.png``.

        Raises:
            ValueError: if ``num_batches`` is smaller than 1.
        """
        if num_batches < 1:
            raise ValueError(f'num_batches must be at least 1, got {num_batches}')

        # KL(Q || M): points sampled by the GFlowNet.
        samples = list()
        log_p = list()  # target log-density
        log_q = list()  # proposal (GFlowNet) log-density
        for _ in tqdm.trange(num_batches, disable=not use_progress_bar):
            env = create_env()
            traj_stats, F_traj, last_idx = gfn._sample_traj(env)
            samples.append(env.state)
            # presumably the log-reward at each terminal step — TODO confirm
            log_p.append(F_traj[env.batch_ids, last_idx])
            # Summed per-step terms. Each terminal state has a unique trajectory
            # here, so if these are forward log-probs (TODO confirm) the sum is
            # the marginal log-prob under Q.
            log_q.append(traj_stats[0].sum(dim=1))
            batch_size = env.batch_size

        samples = torch.vstack(samples)
        log_p = torch.hstack(log_p)
        log_q = torch.hstack(log_q)
        kl_q_m = (log_q - GaussianMixture.logsumexp(log_p, log_q)).mean()

        if plot:
            plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), label='learned')

        # KL(P || M): points sampled from the target.
        samples, log_p = log_reward.sample(batch_size * num_batches)
        log_q = list()
        for idx in tqdm.trange(num_batches, disable=not use_progress_bar):
            # Clone so the environment cannot modify the shared ``samples``
            # tensor; ``.to`` matches the environment state's dtype/device.
            batch_samples = samples[idx * batch_size:(idx + 1) * batch_size].clone()
            env = gfn.sample(create_env())
            env.state = batch_samples.to(env.state)
            log_q.append(gfn.marginal_prob(env))

        log_q = torch.hstack(log_q)
        kl_p_m = (log_p - GaussianMixture.logsumexp(log_p, log_q)).mean()

        if plot:
            plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), label='targeted', alpha=1e-1)
            plt.legend()
            plt.savefig('gmms.png')

        # Jensen-Shannon divergence: average of the two KL terms.
        return (kl_q_m + kl_p_m) / 2