import torch
import numpy as np
from torch.distributions.gamma import Gamma

# Gamma hyperparameters consumed by NTXentLoss.sample_u / NTXentLoss.sample_w.
# In each pair, ``a_*`` is passed as a Gamma concentration and ``b_*`` is added
# to the Gamma rate:
#   U       ~ Gamma(a_u,        b_u + sum_j w[i, j] * s[i, j])
#   w_plus  ~ Gamma(1 + a_plus, U * s_plus  + b_plus)
#   w_minus ~ Gamma(a_minus,    U * s_minus + b_minus)
#
# Earlier configuration, kept for reference:
#   a_u, b_u = 100000, 0
#   a_minus, b_minus = 1, 0
#   a_plus, b_plus = 1, 0

# Per-row scale U.
a_u, b_u = 1, 1
# Weights on negative-pair scores.
a_minus, b_minus = 10, 1
# Weights on positive-pair scores.
a_plus, b_plus = 5, 1

class NTXentLoss(torch.nn.Module):
    """NT-Xent-style contrastive loss with Gamma-sampled logit weights.

    ``forward`` does the following:

    1. Stacks the two views ``[zjs; zis]``.
    2. Forms ``exp(z @ z.T)`` scores.
    3. Rearranges each row into ``[positive, negatives...]``.
    4. Divides by ``temperature``.
    5. Runs a short alternating sampling loop that draws a per-row scale ``U``
       (``sample_u``) and per-entry weights (``sample_w``) from Gamma
       distributions. These use the module-level hyperparameters
       ``a_u``/``b_u``, ``a_plus``/``b_plus`` and ``a_minus``/``b_minus``.

    The loss is the negative mean log-softmax probability of the weighted
    positive entry (column 0).
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity,
                 num_gibbs_steps=2):
        """
        Args:
            device: Device on which masks and sampled tensors are placed.
            batch_size: Expected number of samples per view. It is used to
                pre-build the negative mask; ``forward`` also accepts other
                batch sizes (e.g. a smaller final batch).
            temperature: Divisor applied to the score matrix.
            use_cosine_similarity: Selects ``self.similarity_function``
                (cosine vs. dot product). NOTE: ``forward`` does not call
                ``self.similarity_function``.
            num_gibbs_steps: Number of alternating ``sample_u``/``sample_w``
                rounds per forward pass (default 2, the original fixed value).
        """
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.num_gibbs_steps = num_gibbs_steps
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
        # Most recently sampled weight matrix, shape (2B, 2B - 1). It is
        # overwritten on every forward pass.
        self.w = torch.ones(2 * batch_size, 2 * batch_size - 1)

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the cosine or dot-product similarity callable."""
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self, batch_size=None):
        """Return a (2B, 2B) bool mask that is True only at negative pairs.

        The main diagonal (self-similarity) and the diagonals at offsets +B and
        -B (each sample's positive partner) are False.

        Args:
            batch_size: B. Defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        n = 2 * batch_size
        diag = np.eye(n)
        l1 = np.eye(n, n, k=-batch_size)
        l2 = np.eye(n, n, k=batch_size)
        mask = torch.from_numpy(diag + l1 + l2)
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def sample_u(self, w_matrix, sim_matrix):
        """Sample one Gamma-distributed scale per row.

        Draws ``U[i] ~ Gamma(a_u, b_u + sum_j w[i, j] * s[i, j])``, given as
        (concentration, rate).

        Args:
            w_matrix: Current weights, same shape as ``sim_matrix``.
            sim_matrix: Score matrix. It is expected to be 2-D, as built in
                ``forward``.

        Returns:
            Tensor with one entry per row of ``sim_matrix``.
        """
        rate_param = b_u + (w_matrix * sim_matrix).sum(dim=1)
        u_dist = Gamma(torch.tensor(a_u).float().to(self.device),
                       rate_param.float())
        return u_dist.sample()

    def sample_w(self, U, s_matrix):
        """Sample Gamma weights for the positive and negative columns.

        Column 0 of ``s_matrix`` is each row's positive entry; the remaining
        columns are negatives (the layout produced by ``forward``):

            w[i, 0] ~ Gamma(1 + a_plus, U[i] * s[i, 0] + b_plus)
            w[i, j] ~ Gamma(a_minus,    U[i] * s[i, j] + b_minus),  j >= 1

        Args:
            U: Per-row scales with one entry per row (from ``sample_u``).
            s_matrix: 2-D score matrix of shape (rows, cols).

        Returns:
            Weight matrix with the same shape as ``s_matrix``.
        """
        s_plus = s_matrix[:, 0]
        s_minus = s_matrix[:, 1:]
        w_plus_dist = Gamma(torch.tensor(1 + a_plus).float().to(self.device),
                            U * s_plus + b_plus)
        # Broadcast each row's scale across that row's negative columns.
        w_minus_dist = Gamma(torch.tensor(a_minus).float().to(self.device),
                             U.unsqueeze(1) * s_minus + b_minus)
        w_plus = w_plus_dist.sample().unsqueeze(1)
        w_minus = w_minus_dist.sample()
        return torch.cat([w_plus, w_minus], dim=1)

    def forward(self, zis, zjs):
        """Compute the Gamma-reweighted contrastive loss for two views.

        Args:
            zis, zjs: Embeddings of the two views, each of shape (B, C).
                ``zis[k]`` and ``zjs[k]`` form a positive pair. B may differ
                from ``self.batch_size``; the negative mask is rebuilt when
                it does.

        Returns:
            Scalar tensor: negative mean log-probability of the positive entry.

        Raises:
            ValueError: If ``zis`` and ``zjs`` differ in shape or are empty.
        """
        if zis.shape != zjs.shape:
            raise ValueError(
                f"zis and zjs must have the same shape, got "
                f"{tuple(zis.shape)} and {tuple(zjs.shape)}"
            )
        batch_size = zis.shape[0]
        if batch_size < 1:
            raise ValueError("zis and zjs must contain at least one sample")
        n = 2 * batch_size

        representations = torch.cat([zjs, zis], dim=0)

        # NOTE(review): self.similarity_function is not used here. The scores
        # are exp of raw dot products, which overflow to inf in float32 once a
        # dot product exceeds about 88, unless embeddings are normalised upstream.
        similarity_matrix = torch.exp(torch.matmul(representations, representations.T))

        # Positive partners: row i pairs with column i + B (first half) and
        # row i + B pairs with column i (second half).
        l_pos = torch.diag(similarity_matrix, batch_size)
        r_pos = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([l_pos, r_pos]).view(n, 1)

        if batch_size == self.batch_size:
            neg_mask = self.mask_samples_from_same_repr
        else:
            neg_mask = self._get_correlated_mask(batch_size)
        # Each row has n - 2 negatives. This is given explicitly so that B == 1
        # (zero negatives) still reshapes cleanly.
        negatives = similarity_matrix[neg_mask].view(n, n - 2)

        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        # Alternating sampling of U and w. Distribution.sample() does not
        # propagate gradients, so the weights act as constants in the loss.
        weights = torch.ones_like(logits, device=self.device)
        for _ in range(self.num_gibbs_steps):
            U = self.sample_u(weights, logits)
            weights = self.sample_w(U, logits)
        self.w = weights

        weighted_logits = logits * weights
        log_probs = torch.nn.functional.log_softmax(weighted_logits, dim=1)
        # Column 0 holds each row's positive entry.
        loss = -log_probs[:, 0].mean()
        return loss
