import numpy as np
import torch
from torch import nn
from torch.autograd import Function, Variable


class NearestEmbedFunc(Function):
    """
    Replace every input vector with its nearest (L2) codebook column.

    Input:
    ------
    input - (batch_size, emb_dim, *)
        Last dimensions may be arbitrary
    emb - (emb_dim, num_emb), one embedding per column
    cnt - 1-D integer tensor with one entry per embedding; when ``training``
        is truthy it is incremented in place by the number of positions
        that selected each embedding
    training - whether to update ``cnt``
    masked_indices - optional mask of shape (num_emb,); nonzero entries get
        their distance overwritten with 1e9 so they are effectively excluded

    Output:
    -------
    result - (batch_size, emb_dim, *) the selected embeddings
    argmin - (batch_size, *) index of the selected embedding per position

    Gradients:
    ----------
    input receives grad_output unchanged (straight-through estimator);
    emb column k receives the mean of grad_output over the positions that
    selected k (zero if none did).
    """
    @staticmethod
    def forward(ctx, input, emb, cnt, training, masked_indices):
        if input.size(1) != emb.size(0):
            raise RuntimeError('invalid argument: input.size(1) ({}) must be equal to emb.size(0) ({})'.
                               format(input.size(1), emb.size(0)))

        # save sizes for backward
        ctx.batch_size = input.size(0)
        ctx.num_latents = int(np.prod(np.array(input.size()[2:])))
        ctx.emb_dim = emb.size(0)
        ctx.num_emb = emb.size(1)
        ctx.dims = list(range(input.dim()))

        # Broadcast input (B, D, *, 1) against emb (D, 1, ..., 1, K) -> (B, D, *, K).
        x_expanded = input.unsqueeze(-1)
        num_arbitrary_dims = len(ctx.dims) - 2
        if num_arbitrary_dims:
            emb_expanded = emb.view(emb.shape[0], *([1] * num_arbitrary_dims), emb.shape[1])
        else:
            emb_expanded = emb

        # L2 distance over the embedding dim (dim 1) -> (B, *, K).
        dist = torch.norm(x_expanded - emb_expanded, 2, 1)

        if masked_indices is not None:
            # Lift the (K,) mask to (1, ..., 1, K) matching dist's rank, so masking
            # works for any number of trailing input dims (not only 4-D dist).
            mask = masked_indices.reshape(*([1] * (dist.dim() - 1)), -1).expand_as(dist)
            dist = dist.masked_fill(mask.bool(), 1e9)

        _, argmin = dist.min(-1)

        if training:
            # `cnt[idx] += 1` is a non-accumulating index_put: an embedding chosen at
            # several positions would be incremented only once. bincount counts every
            # occurrence; .to(cnt) matches cnt's dtype and device.
            cnt += torch.bincount(argmin.flatten(), minlength=cnt.numel()).to(cnt)

        # Gather the chosen columns as (B, *, D) rows, then move D back to dim 1.
        shifted_shape = [input.shape[0], *list(input.shape[2:]), input.shape[1]]
        result = emb.t().index_select(0, argmin.view(-1)).view(shifted_shape).permute(0, ctx.dims[-1], *ctx.dims[1:-1])

        ctx.argmin = argmin
        return result.contiguous(), argmin

    @staticmethod
    def backward(ctx, grad_output, argmin=None):
        """Straight-through gradient for input; per-embedding averaged gradient for emb.

        ``argmin`` is the (ignored) incoming gradient for the integer index output.
        """
        grad_input = grad_emb = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output

        if ctx.needs_input_grad[1]:
            flat_argmin = ctx.argmin.reshape(-1)
            latent_indices = torch.arange(ctx.num_emb, device=flat_argmin.device)
            # One-hot (N, K) assignment matrix in grad_output's dtype/device.
            idx_choices = (flat_argmin.view(-1, 1) == latent_indices.view(1, -1)).to(grad_output)
            n_idx_choice = idx_choices.sum(0)
            # Unused embeddings: avoid division by zero (their column stays all-zero).
            n_idx_choice[n_idx_choice == 0] = 1
            idx_avg_choices = idx_choices / n_idx_choice
            flat_grad = grad_output.permute(0, *ctx.dims[2:], 1).reshape(
                ctx.batch_size * ctx.num_latents, ctx.emb_dim)
            # (D, N) @ (N, K) -> (D, K); same result as the broadcast-and-sum, but
            # without materializing an (N, D, K) temporary.
            grad_emb = flat_grad.t().mm(idx_avg_choices)
        return grad_input, grad_emb, None, None, None


def nearest_embed(x, emb, cnt, training, masked_indices):
    """Functional entry point for NearestEmbedFunc; returns its (result, argmin) outputs.

    ``apply`` is invoked on the class itself: the Function's forward/backward are
    static, so instantiating it is unnecessary (and deprecated in recent PyTorch).
    """
    return NearestEmbedFunc.apply(x, emb, cnt, training, masked_indices)


class NearestEmbed(nn.Module):
    """Learnable codebook with nearest-neighbour lookup.

    ``weight`` has shape (embeddings_dim, num_embeddings), one embedding per
    column, initialized uniformly in [0, 1). ``count`` is a 1-D long tensor of
    length num_embeddings, initialized to zero and handed to ``nearest_embed``,
    which updates it in place when ``training`` is truthy.
    """
    def __init__(self, num_embeddings, embeddings_dim):
        super(NearestEmbed, self).__init__()
        self.weight = nn.Parameter(torch.rand(embeddings_dim, num_embeddings))
        # A buffer moves with .to()/.cuda() alongside ``weight`` (the previous
        # hard-coded torch.cuda.LongTensor failed without CUDA). persistent=False
        # keeps state_dict keys unchanged, so existing checkpoints still load.
        self.register_buffer('count', torch.zeros(num_embeddings, dtype=torch.long), persistent=False)

    def forward(self, x, training=None, masked_indices=None, weight_sg=False):
        """Quantize ``x`` against the codebook.

        Input:
        ---------
        x - (batch_size, emb_size, *)
        training - whether ``count`` should be updated; ``None`` uses ``self.training``
        masked_indices - optional mask over the num_embeddings codebook entries
        weight_sg - if True, pass a detached ``weight`` (stop-gradient) so no
            gradient flows into the codebook

        Returns the ``(result, argmin)`` outputs of ``nearest_embed``.
        """
        if training is None:
            training = self.training
        weight = self.weight.detach() if weight_sg else self.weight
        return nearest_embed(x, weight, self.count, training, masked_indices)
