# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import torchvision.models as models
import sys
sys.path.append('../')
import utils
from models.backbones.resnet_mlp import resnet50_mlp1, resnet50_mlp2, resnet50_mlp3, resnet50_mlp4, resnet50_mlp5, resnet50_mlp6, resnet50_mlp10, resnet50_mlp100, resnet50_mlp1000
from models.backbones.resnet_mlp_norelu import resnet50_mlp1_norelu, resnet50_mlp2_norelu, resnet50_mlp3_norelu, resnet50_mlp4_norelu, resnet50_mlp5_norelu, resnet50_mlp6_norelu, resnet50_mlp10_norelu, resnet50_mlp20_norelu, resnet50_mlp100_norelu, resnet50_mlp256_norelu, resnet50_mlp512_norelu, resnet50_mlp1000_norelu, resnet50_mlp2000_norelu, resnet50_mlp2048_norelu
from models.backbones.resnet_mlp_nobn import resnet50_mlp2_nobn
from models.backbones.resnet_mlp_norelu_nobias import resnet50_mlp10_norelu_nobias
from models.backbones.resnet_linear import resnet50_linear1000
from models.backbones import resnet18_cifar_variant1
from models.backbones.cifar_resnet_1_mlp_norelu import resnet18_cifar_variant1_mlp1000_norelu, resnet18_cifar_variant1_mlp512_norelu, resnet18_cifar_variant1_mlp256_norelu, resnet18_cifar_variant1_mlp128_norelu, resnet18_cifar_variant1_mlp64_norelu, resnet18_cifar_variant1_mlp32_norelu, resnet18_cifar_variant1_mlp16_norelu, resnet18_cifar_variant1_mlp8_norelu, resnet18_cifar_variant1_mlp4_norelu, resnet18_cifar_variant1_mlp2_norelu, resnet50_cifar_variant1_mlp8_norelu, resnet50_cifar_variant1_mlp512_norelu
from models.backbones.resnet_mlp_norelu_3layer import resnet50_mlp2048_norelu_3layer
from models.backbones.resnet_mlp_norelu_4layer import resnet50_mlp4096_norelu_4layer

def get_model(out_dim, is_moco=False, arch='resnet50', print_prog=True, **kwargs):
    '''
    Create and return either a MoCo network or a single encoder network.

    arch is looked up first in torchvision.models and then among the names
    available at module level in this file (e.g. the imported custom
    backbones such as resnet50_mlp1000). A KeyError naming arch is raised if
    it is found in neither place.

    if is_moco:
        The Moco network is initialized and returned. out_dim is NOT used in
        this branch: all kwargs are forwarded unchanged to MoCo, so they must
        be MoCo keyword arguments (dim, K, m, T, mlp, distributed, model_init).
        In particular, mlp selects the MLP head from SimCLR/Moco V2,
        distributed tells Moco whether to gather across processes, and
        model_init is a state dict loaded into the query encoder (and then
        copied into the key encoder).
    else:
        A Resnet encoder (a Resnet50 with the final fc layer stripped) is returned, along with a
        (possibly transformed) fc layer. The architecture is:
            [encoder: in_dim -> 2048, fc: 2048 -> out_dim].
        If fc_transform is not supplied, the unmodified Resnet is returned. Otherwise:
            identity: returns solely the encoder, with a "dummy" fc layer that does nothing
            mlp: returns the encoder with an MLP head (almost like a Moco network,
                but with just one encoder rather than query/key encoders). The architecture is:
                [encoder: in_dim -> 2048, linear: 2048 -> 2048, ReLU, fc: 2048 -> out_dim]
            mlp_plus_linear: returns the encoder, MLP head, and final linear layer. Useful for training
                all encoder + MLP weights to desired initializations. The architecture is:
                [encoder: in_dim -> 2048, linear: 2048 -> 2048, ReLU, fc1: 2048 -> mlp_dim, ReLU,
                    fc2: mlp_dim -> out_dim]
                Note: mlp_dim *should* be > out_dim (it defaults to the fc input width)
        Any other fc_transform value is ignored and the model's own fc is kept.
        This branch requires the built model to expose a Linear `fc`; its
        input width (rep_dim) is read from fc.weight.
    out_dim is the dimensionality of the network's output, and should correspond to the
    number of classes in the current dataset.
    If print_prog is set, a progress line and the model's repr are printed.
    '''

    if print_prog:
        print("=> creating model '{}' ({} moco)".format(arch, 'is' if is_moco else 'not'))
    try:
        base_model = models.__dict__[arch]
    except KeyError:
        # Not a torchvision architecture: fall back to names defined or
        # imported at module level (the custom backbones).
        try:
            base_model = globals()[arch]
        except KeyError:
            raise KeyError(
                "unknown architecture '{}': not found in torchvision.models "
                "or in this module".format(arch)) from None

    if is_moco:
        model = MoCo(base_model, **kwargs)
    else:
        model = base_model(num_classes=out_dim)
        rep_dim = model.fc.weight.shape[1]  # should be 2048
        fc_transform = kwargs.get('fc_transform', None)
        if fc_transform == 'identity':
            model.fc = nn.Identity()
        elif fc_transform == 'mlp':
            model.fc = nn.Sequential(nn.Linear(rep_dim, rep_dim), nn.ReLU(), nn.Linear(rep_dim, out_dim))
        elif fc_transform == 'mlp_plus_linear':
            mlp_dim = kwargs.get('mlp_dim', rep_dim)
            model.fc = nn.Sequential(
                nn.Linear(rep_dim, rep_dim), nn.ReLU(),
                nn.Linear(rep_dim, mlp_dim), nn.ReLU(),
                nn.Linear(mlp_dim, out_dim))
    if print_prog:
        print(model)
    return model


def load_checkpoint(model, state_dict, fname, load_pretrained_head=False, args=None, nomlp=False):
    '''
    Load backbone weights from a (moco) checkpoint state dict into `model`.

    model should be a plain resnet50 (optionally with replaced fc).
    state_dict is from moco. Only entries whose key starts with 'backbone.'
    are kept, with that prefix stripped; every other entry is discarded.

    Args:
        model: module to load the weights into.
        state_dict: checkpoint state dict. It is first passed through
            utils.fix_dataparallel_keys, and the dict that call returns is
            rewritten in place.
        fname: checkpoint name; used only in the progress messages.
        load_pretrained_head: if True, also keep the head entries (see nomlp)
            and load strictly. If False, drop them and load non-strictly.
        args: optional namespace; if given, args.start_epoch is reset to 0.
        nomlp: if True, the head entries dropped when load_pretrained_head is
            False also include 'backbone.proj_resnet_layer1*' and
            'backbone.proj_resnet_layer2*', in addition to 'backbone.fc*'.

    Raises:
        RuntimeError: from model.load_state_dict on a strict-load key mismatch.
        AssertionError: on a non-strict load into a model whose fc is not
            nn.Identity, when the missing keys are not exactly fc.weight/fc.bias.
    '''
    print("=> loading checkpoint '{}'".format(fname))
    state_dict = utils.fix_dataparallel_keys(state_dict)

    prefix = 'backbone.'
    # Head entries, skipped unless load_pretrained_head is set.
    head_prefixes = ('backbone.fc',)
    if nomlp:
        head_prefixes += ('backbone.proj_resnet_layer1', 'backbone.proj_resnet_layer2')

    # Build the renamed mapping first, then swap it in. Renaming and deleting
    # in a single pass over the same dict is order-dependent: with both
    # 'backbone.x' and 'x' present, the freshly renamed 'x' could be deleted
    # later when the loop reached the original 'x'.
    renamed = {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix) and (load_pretrained_head or not k.startswith(head_prefixes))
    }
    # Rewrite in place (as before) rather than rebinding, in case callers
    # rely on the mutation.
    state_dict.clear()
    state_dict.update(renamed)

    if args is not None:
        args.start_epoch = 0
    if load_pretrained_head:
        model.load_state_dict(state_dict, strict=True)  # shouldn't have missing keys
    else:
        msg = model.load_state_dict(state_dict, strict=False)
        if not isinstance(model.fc, nn.Identity):
            assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, \
                "unexpected missing keys: {}".format(msg.missing_keys)
    print("=> loaded pre-trained model '{}'".format(fname))


class MeanTeacher(nn.Module):
    """
    Teacher network whose parameters track an exponential moving average
    (EMA) of a student network's parameters.

    Args:
        model: the teacher network; forward() simply delegates to it.
        momentum: EMA decay. Each update() sets, per parameter,
            teacher = momentum * teacher + (1 - momentum) * student.
    """
    def __init__(self, model, momentum=0.999):
        super(MeanTeacher, self).__init__()
        self.model = model
        self.m = momentum

    def forward(self, x):
        return self.model.forward(x)

    def update(self, student):
        """
        Blend `student`'s parameters into the teacher in place.

        Parameters are paired positionally via zip(), so the student must have
        the same parameter layout as the teacher. Only parameters are averaged;
        buffers (e.g. BatchNorm running statistics) are left untouched.
        """
        for param_t, param_s in zip(self.model.parameters(), student.parameters()):
            # `alpha=` keyword replaces the deprecated add_(Number, Tensor) overload.
            param_t.data.mul_(self.m).add_(param_s.data, alpha=1 - self.m)


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    With distributed=False no torch.distributed calls are made and the labels
    are created on the logits' device, so the model also runs on a single
    device (including CPU).
    """
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False,
            distributed=True, model_init=None):
        """
        base_encoder: encoder constructor, called as base_encoder(num_classes=dim)
            once for the query encoder and once for the key encoder
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: if True, prepend Linear(d, d) + ReLU to each encoder's `fc`, where d
            is that fc's input width (the encoder must expose a Linear `fc`)
        distributed: if True, keys are all-gathered and key batches shuffled
            across processes; requires torch.distributed to be initialized
        model_init: optional state dict loaded strictly into the query encoder
            (after the optional MLP replacement); the key encoder is then
            initialized as a copy of the query encoder
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)
        self.distributed = distributed

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        # Load weights for the whole query encoder (backbone + head). This runs
        # after the MLP replacement, so model_init must match that layout
        # exactly (strict=True). To load encoder-only weights instead, drop the
        # fc entries from model_init and load non-strictly.
        if model_init is not None:
            self.encoder_q.load_state_dict(model_init, strict=True)

        # The key encoder starts as an exact copy of the query encoder and is
        # afterwards updated only by momentum, never by gradients.
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: K random columns of size dim, each L2-normalized
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        param_k = m * param_k + (1 - m) * param_q
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest batch_size queue columns with `keys` (NxC) and
        advance the ring-buffer pointer.
        """
        # gather keys before updating queue
        if self.distributed:
            keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # With K a multiple of batch_size the slice below never runs past K,
        # so no wrap-around handling is needed.
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        Returns this process's shuffled slice and the index that undoes it.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus so every rank uses the same permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits: Nx(1+K), column 0 is the positive pair, the rest are the
                queue negatives, all divided by T
            labels: N zeros (the positive is always at index 0), created on
                the same device as logits
        Side effect: momentum-updates the key encoder and enqueues the keys.
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            if self.distributed:
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            if self.distributed:
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators. Allocated on the logits' device
        # (not hard-coded .cuda()) so the non-distributed path works anywhere.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Collect `tensor` from every process in the default group and concatenate
    the copies along dim 0 (entry i of the gather comes from rank i).

    Requires torch.distributed to be initialized.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank; all_gather overwrites each of them.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
