import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.models as models
import torch.optim as optim
from functools import reduce
from models.ResUNet import ResUNet, ResUNetPre, ResUNetPrePre, ResnetUpSample
from models.StackedHourglass import *

# Number of leading modules of torchvision's VGG16 ``features`` kept by
# ``vgg16`` / ``vgg16_old`` (modules 0 .. n-1) for each truncation point.
vgg_layers = {'pool4': 24, 'pool5': 31}

# Downsampling factor from input image to backbone output, keyed by the
# ``net_type`` accepted by ``NetE2E``. Used to map image-space keypoints and
# object masks onto the feature grid.
net_stride = {'resnet50_up_vc': 8, 'vgg_pool4_up_vc': 8, 'vgg_pool4_vc': 16, 'vgg1_pool4_vc': 16, 'vgg_pool4_up': 8,
              'vgg_pool4': 16, 'vgg_pool5': 32, 'resnet50': 32, 'resnext50': 32, 'resnet50_pre': 16,
              'resnet50_prepre': 8, 'resunet': 2, 'resunetpre': 8, 'resunetprepre': 16, 'hg': 4, 'resnetupsample': 16,
              # ResNet-50 through layer4 (as 'resnet50') followed by a VC layer, no upsampling.
              'resnet50_vc': 32}

# Channel count of each backbone's output feature map, keyed by ``net_type``.
net_out_dimension = {'resnet50_up_vc': 512, 'vgg_pool4_up_vc': 512, 'vgg_pool4_vc': 512, 'vgg1_pool4_vc': 512,
                     'vgg_pool4_up': 512, 'vgg_pool4': 512, 'vgg_pool5': 512, 'resnet50': 2048, 'resnext50': 2048,
                     'resnet50_pre': 1024, 'resnet50_prepre': 512, 'resunet': 128, 'resunetpre': 256,
                     'resunetprepre': 512, 'hg': 256, 'resnetupsample': 2048,
                     # Same VC centre file as 'resnet50_up_vc', hence the same channel count.
                     'resnet50_vc': 512}


def hg(num_stacks, num_blocks):
    """Build a ``HourglassNet`` that uses ``Bottleneck`` blocks.

    Args:
        num_stacks: Forwarded to ``HourglassNet`` as ``num_stacks``.
        num_blocks: Forwarded to ``HourglassNet`` as ``num_blocks``.
    """
    return HourglassNet(Bottleneck, num_stacks=num_stacks, num_blocks=num_blocks)


def resunet(pretrain):
    """Return a ``ResUNet``, passing ``pretrain`` through as ``pretrained``."""
    return ResUNet(pretrained=pretrain)


def resunetpre(pretrain):
    """Return a ``ResUNetPre``, passing ``pretrain`` through as ``pretrained``."""
    return ResUNetPre(pretrained=pretrain)


def resunetprepre(pretrain):
    """Return a ``ResUNetPrePre``, passing ``pretrain`` through as ``pretrained``."""
    return ResUNetPrePre(pretrained=pretrain)


def resnetupsample(pretrain):
    """Return a ``ResnetUpSample``, passing ``pretrain`` through as ``pretrained``."""
    return ResnetUpSample(pretrained=pretrain)


class VCAct(nn.Module):
    """Response of L2-normalised features to a fixed bank of centre vectors.

    ``forward`` L2-normalises the input along dim 1 (channels) and applies a
    bias-free 1x1 convolution whose kernel rows are the centres loaded from
    ``vc_centers_dir``. Each output channel is therefore the dot product of the
    normalised feature with one centre (a cosine similarity if the centres are
    unit-norm -- not checked here).

    Args:
        activate_funcs: Stored as ``self.acts``; not used by ``forward``.
        vc_centers_dir: Path readable by ``np.load`` holding a 2-D array of
            shape (num_centers, in_channels).

    Raises:
        ValueError: if ``vc_centers_dir`` is not provided.
    """

    def __init__(self, activate_funcs, vc_centers_dir=None):
        super(VCAct, self).__init__()
        from functools import partial

        # functools.partial (unlike a lambda) keeps the module picklable,
        # e.g. for torch.save(model) or multiprocessing data loaders.
        self.normalizer = partial(nn.functional.normalize, p=2, dim=1, eps=1e-12)

        if not vc_centers_dir:
            raise ValueError('VCAct requires vc_centers_dir (path to the VC centre array)')
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects -- only load centre files from trusted sources.
        vc_centers = torch.from_numpy(np.load(vc_centers_dir, allow_pickle=True).astype(np.float32))
        print(vc_centers_dir)
        self.conv = nn.Conv2d(in_channels=vc_centers.shape[1], out_channels=vc_centers.shape[0], kernel_size=1,
                              bias=False)
        self.acts = activate_funcs

        self.init_weight(vc_centers)

    def init_weight(self, vc_centers):
        """Copy a (num_centers, in_channels) centre matrix into the 1x1 kernel.

        Args:
            vc_centers: Tensor or array whose element count matches the conv
                weight; it is reshaped to (out, in, 1, 1).
        """
        centers = torch.as_tensor(vc_centers)
        with torch.no_grad():
            # One vectorised copy instead of a per-element Python loop.
            self.conv.weight.copy_(centers.reshape(self.conv.weight.shape))

    def forward(self, x):
        """Normalise ``x`` along channels, then project onto the centres."""
        x = self.normalizer(x)
        x = self.conv(x)
        return x


class ExpLayer(nn.Module):
    """Element-wise ``exp(vMF_kappa * x)``, optionally gated by ``x > thr``.

    Args:
        vMF_kappa: Scale applied to the input before exponentiation.
        binary: If True, the exponential is multiplied by the indicator
            ``x > thr`` (cast to ``x``'s dtype), zeroing the other entries.
        thr: Threshold used only when ``binary`` is True.
    """

    def __init__(self, vMF_kappa, binary=False, thr=0.55):
        super(ExpLayer, self).__init__()
        self.vMF_kappa = vMF_kappa
        self.binary = binary
        self.thr = thr

    def forward(self, x):
        activation = torch.exp(self.vMF_kappa * x)
        if not self.binary:
            return activation
        keep = (x > self.thr).type(x.dtype).to(x.device)
        return activation * keep


def vgg16(layer='pool4', additional=None):
    """Build a VGG16 feature extractor truncated at ``layer``.

    Torchvision's VGG16 is created with ``pretrained=True``. The first
    ``vgg_layers[layer]`` modules of its ``features`` are kept under their
    original indices inside a child named 'features'. ``additional``, if
    given, is appended as a child named 'additional'.

    Args:
        layer: Key into ``vgg_layers`` ('pool4' or 'pool5').
        additional: Optional module run after the truncated features.
    """
    backbone = models.vgg16(pretrained=True)
    cut = vgg_layers[layer]

    truncated = nn.Sequential()
    for idx in range(cut):
        truncated.add_module(str(idx), backbone.features[idx])

    model = nn.Sequential()
    model.add_module('features', truncated)
    if additional is not None:
        model.add_module('additional', additional)
    return model


def vgg16_old(init_path='./vgg.pth', layer='pool4', map_location='cpu'):
    """Build a truncated VGG16 extractor and fill it from a local checkpoint.

    The architecture comes from ``torchvision.models.vgg16(pretrained=False)``,
    truncated to the first ``vgg_layers[layer]`` modules of ``features``.
    Only checkpoint entries whose keys exist in the model's state dict are
    loaded; others are ignored. The model is put in eval mode.

    Args:
        init_path: Path to a checkpoint loadable by ``torch.load`` that maps
            state-dict keys (e.g. 'features.0.weight') to tensors.
        layer: Key into ``vgg_layers``.
        map_location: Where ``torch.load`` places the checkpoint tensors.
            Defaults to the CPU: the model is built on the CPU and
            ``load_state_dict`` copies values, so no GPU is needed (the
            previous hard-coded 'cuda:0' crashed on CPU-only machines).
    """
    net = models.vgg16(pretrained=False)
    model = nn.Sequential()
    features = nn.Sequential()
    for i in range(0, vgg_layers[layer]):
        features.add_module('{}'.format(i), net.features[i])
    model.add_module('features', features)

    model_dict = model.state_dict()
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(init_path, map_location=map_location)
    model.eval()

    update_dict = {k: checkpoint[k] for k in checkpoint if k in model_dict}
    model_dict.update(update_dict)
    model.load_state_dict(model_dict)
    return model


def resnet50(pretrain, additional=None):
    """Build a ResNet-50 feature extractor (stem + layer1..layer4).

    The torchvision submodules are registered in an ``nn.Sequential`` under
    the names '0'..'7'; ``additional``, if given, is appended as '8'. The
    module tables list this backbone ('resnet50') with stride 32 and 2048
    output channels.

    Args:
        pretrain: Forwarded to ``torchvision.models.resnet50`` as ``pretrained``.
        additional: Optional module run after layer4.
    """
    backbone = models.resnet50(pretrained=pretrain)
    stages = [backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
              backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4]
    if additional is not None:
        stages.append(additional)

    extractor = nn.Sequential()
    for position, stage in enumerate(stages):
        extractor.add_module(str(position), stage)
    return extractor


def resnet50_pre(pretrain):
    """Build a ResNet-50 extractor that stops after layer3.

    Submodules are registered as '0'..'6' (conv1, bn1, relu, maxpool,
    layer1, layer2, layer3). The module tables list 'resnet50_pre' with
    stride 16 and 1024 output channels.

    Args:
        pretrain: Forwarded to ``torchvision.models.resnet50`` as ``pretrained``.
    """
    backbone = models.resnet50(pretrained=pretrain)
    stage_names = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3')

    extractor = nn.Sequential()
    for position, attr in enumerate(stage_names):
        extractor.add_module(str(position), getattr(backbone, attr))
    return extractor


def resnet50_prepre(pretrain):
    """Build a ResNet-50 extractor that stops after layer2.

    Submodules are registered as '0'..'5' (conv1, bn1, relu, maxpool,
    layer1, layer2). The module tables list 'resnet50_prepre' with stride 8
    and 512 output channels.

    Args:
        pretrain: Forwarded to ``torchvision.models.resnet50`` as ``pretrained``.
    """
    backbone = models.resnet50(pretrained=pretrain)
    stage_names = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2')

    extractor = nn.Sequential()
    for position, attr in enumerate(stage_names):
        extractor.add_module(str(position), getattr(backbone, attr))
    return extractor


def resnext50(pretrain):
    """Build a ResNeXt-50 (32x4d) extractor (stem + layer1..layer4).

    Submodules are registered as '0'..'7'. The module tables list
    'resnext50' with stride 32 and 2048 output channels.

    Args:
        pretrain: Forwarded to ``torchvision.models.resnext50_32x4d`` as
            ``pretrained``.
    """
    backbone = models.resnext50_32x4d(pretrained=pretrain)
    stage_names = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4')

    extractor = nn.Sequential()
    for position, attr in enumerate(stage_names):
        extractor.add_module(str(position), getattr(backbone, attr))
    return extractor


def keypoints_to_pixel_index(keypoints, downsample_rate, original_img_size=(480, 640)):
    """Map image-space keypoints to flat indices on a downsampled grid.

    Each keypoint ``(row, col)`` is floor-divided by ``downsample_rate`` and
    flattened as ``row * line_size + col``, where
    ``line_size = original_img_size[1] // downsample_rate``.

    NOTE(review): ``line_size`` is a floor estimate of the grid width; a
    backbone whose output width rounds up (image width not a multiple of the
    stride) has a wider grid, and these indices will not line up with it.

    Example:
        keypoints = torch.tensor([[(36, 40), (90, 80)]]), downsample_rate = 32,
        original_img_size = (224, 300) -> line_size = 9, result [[10, 20]]
        (e.g. (36, 40) -> (1, 1) -> 1 * 9 + 1 = 10).

    Args:
        keypoints: Tensor indexed as ``keypoints[:, :, 0]`` (row) and
            ``keypoints[:, :, 1]`` (col), i.e. (N, K, 2).
        downsample_rate: Stride between image pixels and grid cells.
        original_img_size: (height, width) of the image; only width is used.

    Returns:
        Tensor of shape (N, K) with the flat grid indices.
    """
    width_cells = original_img_size[1] // downsample_rate
    row_cells = keypoints[:, :, 0] // downsample_rate
    col_cells = keypoints[:, :, 1] // downsample_rate
    return row_cells * width_cells + col_cells


def get_noise_pixel_index(keypoints, max_size, n_samples, obj_mask=None):
    """Sample ``n_samples`` flat indices per row that avoid the keypoints.

    A (N, max_size) float32 weight matrix of ones is created on
    ``keypoints.device``. The columns listed in ``keypoints`` are set to 0,
    and if ``obj_mask`` is given the weights are multiplied by it. Indices
    are then drawn with ``torch.multinomial`` (without replacement, the
    default), so each row needs at least ``n_samples`` non-zero weights.

    Args:
        keypoints: (N, K) flat indices to exclude.
        max_size: Number of candidate positions per row.
        n_samples: Indices to draw per row.
        obj_mask: Optional weights broadcastable to (N, max_size).

    Returns:
        Long tensor of shape (N, n_samples).
    """
    weights = torch.ones((keypoints.shape[0], max_size), dtype=torch.float32, device=keypoints.device)
    weights = weights.scatter(1, keypoints.type(torch.long), 0.)
    if obj_mask is not None:
        weights *= obj_mask
    return torch.multinomial(weights, n_samples)


class GlobalLocalConverter(nn.Module):
    """Turn a feature map into one neighbourhood vector per location.

    With ``local_size = (kh, kw)`` the (N, C, H, W) input is zero-padded by a
    total of ``kh - 1`` rows and ``kw - 1`` columns and unfolded with stride 1,
    giving (N, C * kh * kw, H * W).
    """

    def __init__(self, local_size):
        super(GlobalLocalConverter, self).__init__()
        self.local_size = local_size
        # F.pad lists the last axis first: [left, right, top, bottom]. Odd
        # kernels pad symmetrically; even kernels put the extra cell after.
        pads = []
        for k in reversed(local_size):
            pads.extend([k - 1 - k // 2, k // 2])
        self.padding = pads

    def forward(self, X):
        batch, channels, height, width = X.shape  # input must be 4-D
        padded = F.pad(X, self.padding)
        return F.unfold(padded, kernel_size=self.local_size)


class MergeReduce(nn.Module):
    """Collapse a local neighbourhood of features by mean or max.

    ``forward`` works on the unfolded layout (N, C * local_h * local_w, L) and
    reduces the neighbourhood axis, returning a (N, C, L) tensor.
    ``forward_test`` applies the corresponding stride-1 pooling directly to a
    dense (N, C, H, W) map. ``register_local_size`` must be called before
    either forward method.

    Args:
        reduce_method: 'mean' or 'max'.

    Raises:
        ValueError: for any other ``reduce_method``.
    """

    _METHODS = ('mean', 'max')

    def __init__(self, reduce_method='mean'):
        super(MergeReduce, self).__init__()
        if reduce_method not in self._METHODS:
            raise ValueError('reduce_method must be one of {}, got {!r}'.format(self._METHODS, reduce_method))
        self.reduce_method = reduce_method
        self.local_size = -1

    def register_local_size(self, local_size):
        """Record the (h, w) neighbourhood and build the pooling for ``forward_test``."""
        self.local_size = local_size[0] * local_size[1]
        # Pad each axis by half of its own kernel extent so odd kernels keep the
        # spatial size, including non-square kernels such as (1, 3).
        padding = (local_size[0] // 2, local_size[1] // 2)
        if self.reduce_method == 'mean':
            self.foo_test = torch.nn.AvgPool2d(local_size, stride=1, padding=padding)
        else:
            self.foo_test = torch.nn.MaxPool2d(local_size, stride=1, padding=padding)

    def forward(self, X):
        """Reduce (N, C * local, L) -> (N, C, L) over the neighbourhood axis."""
        X = X.view(X.shape[0], -1, self.local_size, X.shape[2])
        if self.reduce_method == 'mean':
            return torch.mean(X, dim=2)
        # torch.max(..., dim=...) returns a (values, indices) pair; keep the values
        # so callers always receive a tensor.
        return torch.max(X, dim=2).values

    def forward_test(self, X):
        """Apply the registered stride-1 pooling to a dense feature map."""
        return self.foo_test(X)


def batched_index_select(t, dim, inds):
    """Pick entries of ``t`` per batch element with ``t.gather``.

    ``inds`` (B, K) is broadcast across the last axis of ``t`` (B, E, F) and
    passed to ``t.gather(dim, ...)``. With ``dim=1`` the result is (B, K, F)
    and ``out[b, k] == t[b, inds[b, k]]``.
    """
    index = inds.unsqueeze(2).expand(-1, -1, t.size(2))
    return t.gather(dim, index)


class NetE2E(nn.Module):
    """Backbone + optional projection that extracts per-keypoint features.

    Two stages:

    * ``forward_step0`` runs the backbone selected by ``net_type`` and, if an
      output layer exists, applies ``out_layer.weight`` as a bias-free 1x1
      convolution, optionally L2-normalising along dim 1.
    * ``forward_step1`` gathers feature vectors from such a map at keypoint
      locations, plus ``n_noise_points`` randomly sampled other locations.

    Args:
        pretrain: Forwarded to the backbone constructors that accept it.
        net_type: Backbone key; must also be a key of the module-level
            ``net_stride`` / ``net_out_dimension`` tables.
        local_size: (h, w) neighbourhood gathered around each location by
            ``GlobalLocalConverter``.
        output_dimension: Projection size, or -1 for no projection.
        reduce_function: Optional ``MergeReduce``-like module collapsing the
            neighbourhood; when given, the projection input is not multiplied
            by the neighbourhood size.
        n_noise_points: Extra non-keypoint locations sampled per image.
        num_stacks, num_blocks: Only used by the 'hg' backbone.
        noise_on_mask: Restrict noise sampling with ``obj_mask`` when given.
        **kwargs: ``binary`` (default False) is used by the VC backbones.

    Raises:
        ValueError: for an unknown ``net_type``.
    """

    def __init__(self, pretrain, net_type, local_size, output_dimension, reduce_function=None, n_noise_points=0,
                 num_stacks=8, num_blocks=1, noise_on_mask=True, **kwargs):
        super(NetE2E, self).__init__()
        # Only read by the VC backbones; matches ExpLayer's own default.
        binary = kwargs.get('binary', False)

        if net_type == 'vgg_pool4':
            self.net = vgg16('pool4')
        elif net_type == 'vgg_pool4_up':
            self.net = vgg16('pool4', additional=nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
        elif net_type == 'vgg_pool4_up_vc':
            self.net = vgg16('pool4', additional=nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                VCAct([], 'models/dictionary_pool4.pickle'),
                ExpLayer(30, binary=binary, thr=0.55),
            ))
        elif net_type == 'vgg_pool4_vc':
            self.net = vgg16('pool4', additional=nn.Sequential(
                VCAct([], 'models/dictionary_pool4.pickle'),
                ExpLayer(30, binary=binary, thr=0.55),
            ))
        elif net_type == 'vgg1_pool4_vc':
            self.net = nn.Sequential(vgg16_old(init_path='models/vgg.pth', layer='pool4'),
                                     VCAct([], 'models/vc_centers_pool4.npy'),
                                     ExpLayer(30, binary=binary, thr=0.55))
        elif net_type == 'vgg_pool5':
            self.net = vgg16('pool5')
        elif net_type == 'resnet50':
            self.net = resnet50(pretrain)
        elif net_type == 'resnext50':
            self.net = resnext50(pretrain)
        elif net_type == 'resnet50_pre':
            self.net = resnet50_pre(pretrain)
        elif net_type == 'resnet50_prepre':
            self.net = resnet50_prepre(pretrain)
        elif net_type == 'resnet50_up_vc':
            self.net = resnet50(pretrain, additional=nn.Sequential(
                VCAct([], 'models/vc_centers_resnet.npy'),
                ExpLayer(30, binary=binary, thr=0.3),
                nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
            ))
        elif net_type == 'resnet50_vc':
            self.net = resnet50(pretrain, nn.Sequential(
                VCAct([], 'models/vc_centers_resnet.npy'),
                ExpLayer(30, binary=binary, thr=0.3),
            ))
        elif net_type == 'resunet':
            self.net = resunet(pretrain)
        elif net_type == 'hg':
            self.net = hg(num_stacks, num_blocks)
        elif net_type == 'resunetpre':
            self.net = resunetpre(pretrain)
        elif net_type == 'resunetprepre':
            self.net = resunetprepre(pretrain)
        elif net_type == 'resnetupsample':
            self.net = resnetupsample(pretrain)
        else:
            raise ValueError('Unknown net_type: {!r}'.format(net_type))

        self.size_number = local_size[0] * local_size[1]
        self.output_dimension = output_dimension
        if reduce_function:
            reduce_function.register_local_size(local_size)
            # The neighbourhood is collapsed before projection, so the
            # projection sees a single location's channels.
            self.size_number = 1

        self.reduce_function = reduce_function
        self.net_type = net_type
        self.net_stride = net_stride[net_type]
        self.converter = GlobalLocalConverter(local_size)
        self.noise_on_mask = noise_on_mask

        if self.output_dimension == -1:
            self.out_layer = None
        else:
            # Weight shape (output_dimension, C * size_number); also reused as a
            # 1x1 conv kernel in forward_step0 / forward_test.
            self.out_layer = nn.Linear(net_out_dimension[net_type] * self.size_number,
                                       self.output_dimension)

        self.n_noise_points = n_noise_points

    def forward_test(self, X, return_features=False, do_normalize=True):
        """Compute a dense descriptor map for inference.

        Args:
            X: Input batch for the backbone.
            return_features: Also return the raw backbone output.
            do_normalize: L2-normalise the result along dim 1.

        Returns:
            The descriptor map, or ``(descriptor_map, backbone_output)`` when
            ``return_features`` is True (now also when there is no out_layer).
        """
        m = self.net(X)

        # reduce_function is not used in the current training setup.
        X = self.reduce_function.forward_test(m) if self.reduce_function else m

        if self.out_layer is not None:
            # NOTE(review): out_layer.bias is not applied here -- confirm intended.
            if self.size_number == 1:
                X = F.conv2d(X, self.out_layer.weight.unsqueeze(2).unsqueeze(3))
            elif self.size_number > 1:
                in_dim = net_out_dimension[self.net_type]
                kernel = (self.out_layer.weight
                          .view(self.output_dimension, in_dim, self.size_number)
                          .permute(2, 0, 1)
                          .reshape(self.size_number * self.output_dimension, in_dim)
                          .unsqueeze(2).unsqueeze(3))
                X = F.conv2d(X, kernel)

        if do_normalize:
            X = F.normalize(X, p=2, dim=1)

        if return_features:
            return X, m
        return X

    def forward_step0(self, X, do_normalize=True, return_features=False, **kwargs):
        """Run the backbone and the optional 1x1 projection.

        Args:
            X: Input batch for the backbone.
            do_normalize: L2-normalise the output along dim 1.
            return_features: Also return the raw backbone output.
            **kwargs: Ignored (lets ``forward`` share kwargs between steps).

        Returns:
            The feature map, or ``(feature_map, backbone_output)``.
        """
        m = self.net(X)

        if self.out_layer is not None:
            # NOTE(review): out_layer.bias is not applied here -- confirm intended.
            X = F.conv2d(m, self.out_layer.weight.unsqueeze(2).unsqueeze(3))
        else:
            X = m

        if do_normalize:
            X = F.normalize(X, p=2, dim=1)

        if return_features:
            return X, m
        return X

    def forward_step1(self, X, keypoint_positions, img_shape, obj_mask=None, **kwargs):
        """Gather feature vectors at keypoint (and optional noise) locations.

        Args:
            X: Feature map (N, C, H, W), typically from ``forward_step0``.
            keypoint_positions: (N, K, 2) ``(row, col)`` positions in input
                image pixels.
            img_shape: Spatial size of the input image. Kept for interface
                compatibility; the flattening width now comes from ``X``.
            obj_mask: Optional (N, H_img, W_img) mask used to restrict where
                noise points are sampled.
            **kwargs: Ignored (lets ``forward`` share kwargs between steps).

        Returns:
            Tensor (N, K + n_noise_points, D) for the default ``local_size``
            of (1, 1).
        """
        n = X.shape[0]
        stride = self.net_stride
        # Flatten (row, col) with the real grid width so indices agree with the
        # H * W axis below; img_shape[1] // stride can differ from it when the
        # image size is not a multiple of the stride.
        feature_width = X.shape[-1]

        # N, C, H, W -> N, C * local_size0 * local_size1, H * W
        X = self.converter(X)

        keypoint_idx = (keypoint_positions[:, :, 0] // stride * feature_width
                        + keypoint_positions[:, :, 1] // stride).type(torch.long)

        # reduce_function is not used in the current training setup.
        if self.reduce_function:
            X = self.reduce_function(X)

        if self.n_noise_points == 0:
            keypoint_all = keypoint_idx
        else:
            if obj_mask is not None:
                # Downsample the pixel mask onto the feature grid.
                obj_mask = F.max_pool2d(obj_mask.unsqueeze(dim=1),
                                        kernel_size=stride,
                                        stride=stride,
                                        padding=(stride - 1) // 2)
                obj_mask = obj_mask.view(obj_mask.shape[0], -1)
                assert obj_mask.shape[1] == X.shape[2], 'mask_: ' + str(obj_mask.shape) + ' fearture_: ' + str(X.shape)
            keypoint_noise = get_noise_pixel_index(keypoint_idx,
                                                   max_size=X.shape[2],
                                                   n_samples=self.n_noise_points,
                                                   obj_mask=obj_mask if self.noise_on_mask else None)
            keypoint_all = torch.cat((keypoint_idx, keypoint_noise), dim=1)

        # N, C * local, H * W -> N, H * W, C * local
        X = torch.transpose(X, 1, 2)

        # N, H * W, C * local -> N, n_points, C * local
        X = batched_index_select(X, dim=1, inds=keypoint_all)

        if self.out_layer is None:
            X = X.view(n, -1, net_out_dimension[self.net_type])
        else:
            X = X.view(n, -1, self.out_layer.weight.shape[0])
        return X

    def forward(self, *args, mode=-1, **kwargs):
        """Dispatch to one or both pipeline stages.

        * ``mode=0``: ``forward_step0(*args, **kwargs)``.
        * ``mode=1``: ``forward_step1(*args, **kwargs)``.
        * ``mode=-1``: run both. The image comes from ``X=`` or the first
          positional argument. Remaining positional arguments follow the
          feature map into ``forward_step1``, so ``net(img, keypoints)`` binds
          ``keypoints`` to ``keypoint_positions``. Pass ``obj_mask`` and
          step-0 options by keyword; all keywords go to both steps.

        Raises:
            ValueError: for an unsupported ``mode`` or a missing image.
        """
        if mode == 0:
            return self.forward_step0(*args, **kwargs)
        if mode == 1:
            return self.forward_step1(*args, **kwargs)
        if mode != -1:
            raise ValueError('Unsupported mode: {}'.format(mode))

        if 'X' in kwargs:
            image = kwargs.pop('X')
        elif args:
            image, args = args[0], args[1:]
        else:
            raise ValueError('forward(mode=-1) requires an input image X')

        img_shape = image.shape[2:]
        feature_map = self.forward_step0(image, **kwargs)
        return self.forward_step1(feature_map, *args, img_shape=img_shape, **kwargs)

    def cuda(self, device=None):
        """Move the backbone and projection layer to the GPU; returns self."""
        self.net.cuda(device=device)

        if self.out_layer is not None:
            self.out_layer.cuda(device=device)

        return self


if __name__ == '__main__':
    # Smoke test: randomly initialised backbone (no weight download), one
    # 224x300 image and two (row, col) keypoints.
    a = torch.ones(1, 3, 224, 300)
    kps = torch.tensor([[(36, 40), (90, 80)]])
    net = NetE2E(pretrain=False, net_type='resnet50_pre', local_size=(1, 1), output_dimension=128,
                 reduce_function=MergeReduce('mean'))

    print(net(a, keypoint_positions=kps).shape)

# kps
