import torch
import torch.nn.functional as F


class NCMClassifier(torch.nn.Module):
    """Nearest-class-mean classifier with momentum-updated prototypes.

    One prototype (mean feature vector) is stored per class.  Predictions
    are a softmax over negative Euclidean distances to the prototypes, and
    every prediction nudges the winning prototype towards the sample with
    an exponential moving average.

    NOTE: the prototypes live in the plain dict ``prototype_dict`` (keys
    ``'prototypes'`` and ``'labels'``), not in registered buffers, so they
    are not part of ``state_dict()`` and are not moved by ``Module.to()``.

    Args:
        gpu: Where the prototypes are stored.  An int is a CUDA device
            index; ``None`` means the current CUDA device when CUDA is
            available and the CPU otherwise; a ``torch.device`` or device
            string is used as given.
        momentum: Weight kept on the old prototype in the moving-average
            update performed by ``predict``.
    """

    def __init__(self, gpu, momentum=0.9):
        super().__init__()
        self.gpu = gpu
        self.momentum = momentum
        self._device = self._resolve_device(gpu)
        self.prototype_dict = {
            'prototypes': torch.tensor([]).to(self._device),
            'labels': torch.tensor([], dtype=torch.long).to(self._device),
        }

    @staticmethod
    def _resolve_device(gpu):
        """Translate the ``gpu`` constructor argument into a torch.device."""
        if isinstance(gpu, torch.device):
            return gpu
        if isinstance(gpu, str):
            return torch.device(gpu)
        if gpu is None:
            if torch.cuda.is_available():
                return torch.device('cuda', torch.cuda.current_device())
            return torch.device('cpu')
        return torch.device('cuda', gpu)

    def add_class(self, features, labels):
        """Create a prototype for every label not seen before.

        Each new prototype is the mean of the rows of ``features`` carrying
        that label.  Labels that already have a prototype are left alone,
        and a batch without unseen labels is a no-op.  New prototypes are
        appended in order of first appearance in ``labels``.

        Args:
            features: 2-D tensor of shape (N, D), one feature row per sample.
            labels: N integer class labels aligned with ``features``.
        """
        with torch.no_grad():
            features = features.detach().to(self._device)
            labels = torch.as_tensor(labels, device=self._device).long().reshape(-1)

            # De-duplicate while preserving first-appearance order; a batch
            # normally repeats each label many times.
            seen = set(self.prototype_dict['labels'].tolist())
            fresh = []
            for label in labels.tolist():
                if label not in seen:
                    seen.add(label)
                    fresh.append(label)
            if not fresh:
                return

            # Stacking (D,) means always yields (num_new, D), even for a
            # single new class, so no squeeze is needed.
            new_prototypes = torch.stack(
                [features[labels == label].mean(dim=0) for label in fresh],
                dim=0,
            )
            new_labels = torch.tensor(fresh, dtype=torch.long, device=self._device)

            stored = self.prototype_dict['prototypes']
            if stored.numel() == 0:
                self.prototype_dict['prototypes'] = new_prototypes
            else:
                self.prototype_dict['prototypes'] = torch.cat(
                    [stored, new_prototypes], dim=0
                )
            self.prototype_dict['labels'] = torch.cat(
                [self.prototype_dict['labels'].long(), new_labels], dim=0
            )

    def predict(self, features):
        """Classify ``features`` and adapt the winning prototypes.

        Args:
            features: 2-D tensor of shape (N, D).

        Returns:
            Tuple ``(probs, outputs)``: ``probs`` is the (N, num_prototypes)
            softmax over negative distances; ``outputs`` holds, per sample,
            the ROW INDEX of the nearest prototype.  Map it through
            ``prototype_dict['labels']`` to obtain the class label.

        Raises:
            RuntimeError: if no prototype has been added yet.
        """
        prototypes = self.prototype_dict['prototypes']
        if prototypes.numel() == 0:
            raise RuntimeError(
                'NCMClassifier has no prototypes; call add_class() first'
            )

        features = features.to(prototypes.device)
        distances = torch.cdist(features, prototypes)
        probs = torch.softmax(-distances, dim=1)
        outputs = torch.argmax(probs, dim=1)

        # Moving-average update of the winning prototypes.  It runs without
        # autograd and on a copy: the tensor used by cdist above must stay
        # untouched so that a later backward() through ``probs`` is valid,
        # and the stored prototypes must never join the autograd graph.
        with torch.no_grad():
            updated = prototypes.clone()
            # Samples are applied one after another, so several samples
            # assigned to the same prototype compound their updates.
            for row, sample in zip(outputs.tolist(), features.detach()):
                updated[row] = (
                    self.momentum * updated[row] + (1 - self.momentum) * sample
                )
            self.prototype_dict['prototypes'] = updated

        return probs, outputs

    def forward(self, features, labels=None):
        """Optionally register new classes from ``labels``, then predict.

        Args:
            features: 2-D tensor of shape (N, D).
            labels: Optional class labels for ``features``; when given,
                unseen classes get a prototype before prediction.

        Returns:
            The ``(probs, outputs)`` tuple produced by ``predict``.
        """
        if labels is not None:
            self.add_class(features, labels)
        return self.predict(features)

