import torch
import torch.nn as nn
import numpy as np

class Style_Adaptation(nn.Module):
    """Re-style two source groups with statistics mixed in from a target group.

    ``forward`` takes a 4-D batch built by concatenating three equally sized
    groups along dim 0: ``[source1, source2, target]``. In training mode each
    source sample is normalised with its own per-channel mean/std (computed
    over dims 2 and 3). It is then rescaled with a random convex combination
    of:

    * its own statistics, and
    * a similarity-weighted average of the target statistics, taken at a
      randomly permuted row.

    The target group is passed through unchanged. In eval mode the input is
    returned as is.

    Notes:
        * Variance uses ``Tensor.var``'s default (Bessel-corrected) estimator.
          With a single spatial element (``H * W == 1``) it is NaN.
        * The statistics are not detached, so gradients flow through them.

    Args:
        eps: Added to the variance before the square root, so the
            normalisation never divides by zero.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def _stats(self, x):
        """Return per-sample, per-channel mean and std ``sqrt(var + eps)`` over dims 2, 3."""
        mean = x.mean((2, 3), keepdim=False)
        sig = (x.var((2, 3), keepdim=False) + self.eps).sqrt()
        return mean, sig

    @staticmethod
    def _target_weights(mean_s, sig_s, mean_t, sig_t):
        """Return a ``(n_source, n_target)`` matrix whose rows are softmax weights.

        The distance is ``(mu_s - mu_t)^2 + (sig_s - sig_t)^2``, averaged over
        channels. This is the squared 2-Wasserstein distance between 1-D
        Gaussians. It is algebraically identical to the expanded form
        ``(mu_s - mu_t)^2 + sig_s^2 + sig_t^2 - 2 sig_s sig_t``, but it cannot
        go (slightly) negative through floating-point cancellation. Closer
        targets receive larger weights.
        """
        dist = (
            (mean_s[:, None, :] - mean_t[None, :, :]) ** 2
            + (sig_s[:, None, :] - sig_t[None, :, :]) ** 2
        ).mean(2)
        return torch.softmax(torch.exp(1.0 / (1.0 + dist)), dim=1)

    @staticmethod
    def _restyle(x_s, mean_s, sig_s, mean_t, sig_t, gamma, alpha, idx_swap):
        """Normalise ``x_s`` with its own stats and re-scale with mixed stats."""
        # Weighted target statistics, then row-permuted so each source sample
        # is paired with a (generally) different weighted target row.
        mix_mu = alpha * mean_s + (1 - alpha) * torch.matmul(gamma, mean_t)[idx_swap]
        mix_sig = alpha * sig_s + (1 - alpha) * torch.matmul(gamma, sig_t)[idx_swap]
        normed = (x_s - mean_s[:, :, None, None]) / sig_s[:, :, None, None]
        return normed * mix_sig[:, :, None, None] + mix_mu[:, :, None, None]

    def forward(self, concate_x):
        """Apply the statistics mixing in training mode; identity in eval mode.

        Args:
            concate_x: 4-D tensor ``(N, C, H, W)``. In training mode ``N`` must
                be a multiple of 3 (``[source1, source2, target]``).

        Returns:
            Tensor of the same shape: ``[restyled source1, restyled source2,
            target]`` in training mode, otherwise ``concate_x`` itself.

        Raises:
            ValueError: in training mode, if ``N`` is not divisible by 3.
        """
        N, C, H, W = concate_x.size()
        if not self.training:
            return concate_x

        if N % 3 != 0:
            raise ValueError(
                f"Style_Adaptation expects a batch of 3 equal groups, got N={N}"
            )
        bs = N // 3
        x_s1 = concate_x[:bs]
        x_s2 = concate_x[bs:2 * bs]
        x_t = concate_x[2 * bs:]

        mean_s1, sig_s1 = self._stats(x_s1)
        mean_s2, sig_s2 = self._stats(x_s2)
        mean_t, sig_t = self._stats(x_t)

        gamma1 = self._target_weights(mean_s1, sig_s1, mean_t, sig_t)
        gamma2 = self._target_weights(mean_s2, sig_s2, mean_t, sig_t)

        # Same draw order as before (randperm, then rand, both on the CPU
        # generator) so seeded runs reproduce; then move to the input's device
        # instead of hard-coding CUDA.
        device = concate_x.device
        idx_swap = torch.randperm(bs).to(device)
        alpha = torch.rand(bs, 1).to(device)

        x_st1 = self._restyle(x_s1, mean_s1, sig_s1, mean_t, sig_t, gamma1, alpha, idx_swap)
        x_st2 = self._restyle(x_s2, mean_s2, sig_s2, mean_t, sig_t, gamma2, alpha, idx_swap)

        return torch.cat((x_st1, x_st2, x_t), dim=0)