import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


############################
# Multi-scale model with 3 scales: Org, Body, and Part (head)
############################
class VGG19_Norm_Triplet_3Scale_LateFusion(nn.Module):
    """Three-branch VGG19 network whose embedding heads are fused by summation.

    Each of the three inputs (``org``, ``body``, ``part``) goes through its own
    ``VGG19_CONV_AVGPool`` backbone and then through two independent linear
    heads that map 512 features to ``att_ndim``:

    * ``FC_emb_*`` -- the three outputs are summed into one embedding
      ("late fusion").
    * ``FC_cls_*`` -- the output is L2-normalised row-wise and compared, by
      squared Euclidean distance, with that branch's learnable class agents.

    Args:
        train_ncls: Number of rows (one per class agent) in each
            ``class_agents_*`` parameter.
        att_ndim: Output width of every linear head and width of every class
            agent.
        debug: Accepted for backward compatibility; not used by this class.
    """

    def __init__(self, train_ncls, att_ndim, debug=False):
        super(VGG19_Norm_Triplet_3Scale_LateFusion, self).__init__()

        # Not read anywhere in this class.
        # NOTE(review): presumably the triplet-loss margin used by the
        # training code -- confirm against the caller.
        self._margin = 0.8  # 0.8

        # One backbone per input scale; weights are not shared.
        self.vgg19_org = VGG19_CONV_AVGPool()
        self.vgg19_body = VGG19_CONV_AVGPool()
        self.vgg19_part = VGG19_CONV_AVGPool()

        # The heads expect 512 features per sample from each backbone.
        self.FC_emb_org = nn.Linear(512, att_ndim)
        self.FC_cls_org = nn.Linear(512, att_ndim)

        self.FC_emb_body = nn.Linear(512, att_ndim)
        self.FC_cls_body = nn.Linear(512, att_ndim)

        self.FC_emb_part = nn.Linear(512, att_ndim)
        self.FC_cls_part = nn.Linear(512, att_ndim)

        # Re-initialise the heads only after all of them exist, so the order in
        # which random numbers are drawn stays the same for a fixed seed.
        for fc_layer in (
            self.FC_emb_org, self.FC_cls_org,
            self.FC_emb_body, self.FC_cls_body,
            self.FC_emb_part, self.FC_cls_part,
        ):
            self.Linear_Init(fc_layer)

        # Gaussian samples projected onto the unit ball, one matrix per branch.
        self.class_agents_org = self._new_class_agents(train_ncls, att_ndim)
        self.class_agents_body = self._new_class_agents(train_ncls, att_ndim)
        self.class_agents_part = self._new_class_agents(train_ncls, att_ndim)

    def _new_class_agents(self, train_ncls, att_ndim):
        """Return a trainable (train_ncls, att_ndim) parameter with unit-norm rows."""
        agents = self.L2_Normalization(torch.randn(train_ncls, att_ndim), epsilon=1e-12)
        return nn.Parameter(agents, requires_grad=True)

    def Linear_Init(self, FC_layer):
        """Set ``FC_layer`` weights to N(0, 0.01) samples and its bias to zero."""
        nn.init.normal_(FC_layer.weight, 0, 0.01)
        nn.init.constant_(FC_layer.bias, 0)

    def L2_Normalization(self, data, epsilon=1e-6):
        """Scale ``data`` so each slice along dim 1 has (approximately) unit L2 norm.

        Args:
            data: Tensor with at least two dimensions; normalised along dim 1.
            epsilon: Added to the squared norm before the reciprocal square
                root, which keeps all-zero rows finite.

        Returns:
            Tensor of the same shape as ``data``.
        """
        # keepdim=True lets broadcasting do the scaling without a tiled copy.
        inv_norm = torch.rsqrt(data.pow(2).sum(dim=1, keepdim=True) + epsilon)
        return data * inv_norm

    def L2_InterDistance(self, x1, x2):
        """Return all pairwise squared Euclidean distances between two row sets.

        Args:
            x1: Tensor of shape (x1_num, dim).
            x2: Tensor of shape (x2_num, dim).

        Returns:
            Tensor of shape (x1_num, x2_num) whose ``[i, j]`` entry is
            ``sum((x1[i] - x2[j]) ** 2)``.

        Raises:
            ValueError: If either input is not 2-D or the feature widths differ.
        """
        # Explicit check: with broadcasting, a width of 1 on one side would
        # otherwise be accepted silently.
        if x1.dim() != 2 or x2.dim() != 2 or x1.shape[1] != x2.shape[1]:
            raise ValueError(
                "L2_InterDistance expects 2-D inputs with equal feature width, "
                "got shapes {} and {}".format(tuple(x1.shape), tuple(x2.shape))
            )
        # (x1_num, 1, dim) - (1, x2_num, dim) -> (x1_num, x2_num, dim)
        diff = x1.unsqueeze(1) - x2.unsqueeze(0)
        return diff.pow(2).sum(dim=2)

    def forward(self, im_org, im_body, im_part, extract_embed=False):
        """Dispatch to ``forward_net`` (default) or ``extract_embed``.

        Args:
            im_org: Input for the ``org`` backbone.
            im_body: Input for the ``body`` backbone.
            im_part: Input for the ``part`` backbone.
            extract_embed: When True, use the dropout-free extraction path
                (marked "for testing" in the original code).

        Returns:
            The return value of ``forward_net`` or ``extract_embed``.
        """
        if extract_embed:
            # this is for testing
            return self.extract_embed(im_org, im_body, im_part)
        return self.forward_net(im_org, im_body, im_part)

    def _branch_heads(self, backbone, fc_emb, fc_cls, image, use_dropout):
        """Run one branch and return its (embedding, class) head outputs."""
        feat = backbone(image)
        if use_dropout:
            # Only active while the module is in training mode.
            feat = F.dropout(feat, p=0.5, training=self.training)
        return fc_emb(feat), fc_cls(feat)

    def _run_branches(self, im_org, im_body, im_part, use_dropout):
        """Run all three branches; return (summed embedding, cls_org, cls_body, cls_part).

        Branches are evaluated in the order org, body, part so that dropout
        draws random numbers in a fixed order.
        """
        emb_org, cls_org = self._branch_heads(
            self.vgg19_org, self.FC_emb_org, self.FC_cls_org, im_org, use_dropout)
        emb_body, cls_body = self._branch_heads(
            self.vgg19_body, self.FC_emb_body, self.FC_cls_body, im_body, use_dropout)
        emb_part, cls_part = self._branch_heads(
            self.vgg19_part, self.FC_emb_part, self.FC_cls_part, im_part, use_dropout)
        return emb_org + emb_body + emb_part, cls_org, cls_body, cls_part

    def forward_net(self, im_org, im_body, im_part):
        """Full forward pass with dropout (p=0.5) on each backbone output.

        Returns:
            A 10-tuple:

            0. summed embedding of the three ``FC_emb_*`` heads;
            1-3. L2-normalised ``FC_cls_*`` outputs for org, body, part;
            4-6. L2-normalised class agents for org, body, part;
            7-9. pairwise squared distances between each normalised class
                 output and that branch's normalised class agents, for org,
                 body, part.
        """
        x_emb_comb, x_cls_org, x_cls_body, x_cls_part = self._run_branches(
            im_org, im_body, im_part, use_dropout=True)

        x_cls_normalized_org = self.L2_Normalization(x_cls_org)
        class_agents_normalized_org = self.L2_Normalization(self.class_agents_org)

        x_cls_normalized_body = self.L2_Normalization(x_cls_body)
        class_agents_normalized_body = self.L2_Normalization(self.class_agents_body)

        x_cls_normalized_part = self.L2_Normalization(x_cls_part)
        class_agents_normalized_part = self.L2_Normalization(self.class_agents_part)

        return (
            x_emb_comb,
            x_cls_normalized_org,
            x_cls_normalized_body,
            x_cls_normalized_part,
            class_agents_normalized_org,
            class_agents_normalized_body,
            class_agents_normalized_part,
            self.L2_InterDistance(x_cls_normalized_org, class_agents_normalized_org),
            self.L2_InterDistance(x_cls_normalized_body, class_agents_normalized_body),
            self.L2_InterDistance(x_cls_normalized_part, class_agents_normalized_part),
        )

    def extract_embed(self, im_org, im_body, im_part):
        """Forward pass without dropout, regardless of training mode.

        Returns:
            A 2-tuple ``(x_emb_comb, x_cls_cat)``: the summed embedding of the
            three ``FC_emb_*`` heads, and the three L2-normalised ``FC_cls_*``
            outputs concatenated along dim 1 in the order org, body, part.
        """
        x_emb_comb, x_cls_org, x_cls_body, x_cls_part = self._run_branches(
            im_org, im_body, im_part, use_dropout=False)

        x_cls_cat = torch.cat(
            (
                self.L2_Normalization(x_cls_org),
                self.L2_Normalization(x_cls_body),
                self.L2_Normalization(x_cls_part),
            ),
            dim=1,
        )
        return x_emb_comb, x_cls_cat


class VGG19_CONV_AVGPool(nn.Module):
    """VGG19 ``features`` stack (minus its last layer) followed by average pooling.

    The output of ``forward`` is flattened to one row per sample.

    Args:
        pretrained: Forwarded to ``torchvision.models.vgg19``. Defaults to
            True, the value that was previously hard-coded; pass False to build
            the backbone without requesting pretrained weights (for example
            when a checkpoint is loaded right afterwards).
    """

    def __init__(self, pretrained=True):
        super(VGG19_CONV_AVGPool, self).__init__()
        original_model = models.vgg19(pretrained=pretrained).features
        # Keep every child of ``features`` except the last one.
        self.features = nn.Sequential(
            *list(original_model.children())[:-1]
        )
        # Fixed 14x14 window: the pooled map is 1x1 only when the feature map
        # is 14x14. NOTE(review): presumably inputs are 224x224 -- confirm.
        self.avgpool = nn.AvgPool2d(14, stride=1)
        # NOTE: nothing is frozen here; all feature weights stay trainable.

    def forward(self, x):
        """Return the pooled features of ``x`` flattened to (batch, -1)."""
        x = self.features(x)  # 14 * 14 * 512
        x = self.avgpool(x)
        return x.view(x.size(0), -1)