import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from copy import deepcopy
from torchvision import models
from fedsd2c.fedsd2c_utils import lr_cosine_policy, clip
from torch.utils.data import DataLoader


class OutputHook:
    """Capture the most recent forward output of a module via a forward hook.

    The hook is registered on construction. Every forward pass of ``module``
    overwrites ``r_feature`` with that pass's output. Call ``close()`` to
    remove the hook, or use the instance as a context manager so the hook is
    removed even if an exception is raised::

        with OutputHook(model.layer4) as hook:
            model(x)
            feats = hook.r_feature
    """

    def __init__(self, module):
        # Output of the latest forward pass; None until the module has run.
        self.r_feature = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        # The output is stored as-is (no detach/clone). For conv stages it is
        # presumably (B, C, H, W), but that depends on the hooked module.
        self.r_feature = output

    def close(self):
        """Remove the forward hook from the module (safe to call twice)."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Do not suppress exceptions raised inside the ``with`` body.
        return False


def feat_match_comp(model, x, factor, iteration, dataset, device):
    """Learn a downsampled image batch whose upsampled version matches x's layer4 features.

    ``x`` is average-pooled with kernel/stride ``int(sqrt(factor))`` to
    initialise a compressed tensor. That tensor is optimised with AdamW under
    ``lr_cosine_policy``. Each step upsamples it bilinearly back to ``x``'s
    spatial size and minimises the summed MSE between its ResNet ``layer4``
    features and those of ``x``. After every step the compressed tensor is
    passed through ``clip(..., dataset=dataset)``.

    Args:
        model: torchvision ResNet. A frozen, eval-mode deep copy is used, so the
            caller's model is not modified.
        x: image batch; presumably (IPC * factor, C, H, W) — TODO confirm.
        factor: downsampling factor; pooling kernel/stride is int(sqrt(factor)).
        iteration: number of optimisation steps.
        dataset: forwarded to ``clip``.
        device: device on which the optimisation runs.

    Returns:
        (ret_x, losses): the optimised compressed images upsampled back to
        ``x``'s H x W (detached), and the per-step loss values (``[0]`` when
        ``iteration`` is 0).

    Raises:
        NotImplementedError: if ``model`` is not a torchvision ResNet.
    """
    s = x.shape
    kernel = int(math.sqrt(factor))
    downsampler = nn.AvgPool2d(kernel, stride=kernel)

    # Detach so the optimised tensor is always a fresh leaf, even when ``x``
    # is part of an autograd graph (setting requires_grad on a non-leaf raises).
    compressed_x = downsampler(x.detach()).clone().to(device)
    compressed_x.requires_grad_(True)

    model = deepcopy(model)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    if not isinstance(model, models.resnet.ResNet):
        raise NotImplementedError("Only support ResNet now.")
    output_hook = OutputHook(model.layer4)

    criterion = nn.MSELoss(reduction="sum")
    optimizer = torch.optim.AdamW([compressed_x], lr=1e-2, betas=(0.5, 0.9), weight_decay=1e-4)
    lr_scheduler = lr_cosine_policy(1e-2, 0, iteration)

    losses = []
    try:
        # Target features of the real images (fixed for the whole optimisation).
        with torch.no_grad():
            model = model.to(device)
            model(x.detach().to(device))
            real_features = output_hook.r_feature

        for i in range(iteration):
            lr_scheduler(optimizer, i, i)
            img_syn = F.interpolate(compressed_x, size=(s[2], s[3]), mode='bilinear')
            model(img_syn)  # the hook captures layer4 output
            loss = criterion(output_hook.r_feature, real_features)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Presumably projects pixels back into the dataset's valid range —
            # see fedsd2c_utils.clip.
            compressed_x.data = clip(compressed_x.data, dataset=dataset)
            losses.append(loss.item())
    finally:
        output_hook.close()

    if not losses:
        losses.append(0)
    ret_x = F.interpolate(compressed_x.detach().clone(), size=(s[2], s[3]), mode='bilinear')
    return ret_x, losses


def feat_mean_comp(model, dataset, device, pos="conv1", batch_size=128, num_workers=4):
    """Average a ResNet intermediate feature map over every sample of ``dataset``.

    A forward hook on a deep copy of ``model`` (eval mode) captures:
    ``pos="conv1"`` -> output of ``model.maxpool`` (the stem),
    ``pos="layer4"`` -> output of ``model.layer4``.
    Features are summed over all samples and divided by ``len(dataset)``.

    Args:
        model: torchvision ResNet; the caller's model is not modified.
        dataset: map-style dataset whose items unpack as ``(input, _)``.
        device: device used for the forward passes.
        pos: which stage to hook, ``"conv1"`` or ``"layer4"``.
        batch_size: DataLoader batch size (default 128, as before).
        num_workers: DataLoader worker count (default 4, as before).

    Returns:
        Tensor with the shape of a single sample's feature map.

    Raises:
        ValueError: if ``dataset`` is empty.
        NotImplementedError: for an unsupported model type or ``pos``.
    """
    if len(dataset) == 0:
        raise ValueError("feat_mean_comp requires a non-empty dataset")

    model = deepcopy(model)
    model.eval()
    if isinstance(model, models.resnet.ResNet):
        if pos == "conv1":
            print("conv1 feat mean")
            output_hook = OutputHook(model.maxpool)
        elif pos == "layer4":
            print("layer4 feat mean")
            output_hook = OutputHook(model.layer4)
        else:
            raise NotImplementedError(f"{pos}")
    else:
        raise NotImplementedError("Only support ResNet now.")

    feat_sum = 0
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    try:
        with torch.no_grad():
            model = model.to(device)
            for x, _ in loader:
                model(x.to(device))
                # Sum over the batch dimension; divided by the sample count below.
                feat_sum += output_hook.r_feature.sum(dim=0)
    finally:
        output_hook.close()
    return feat_sum / len(dataset)
        

def mse_sim(model, src, target):
    """Pairwise mean-squared distance between ResNet stem features of two batches.

    Both batches go through ``conv1 -> bn1 -> relu -> maxpool`` of ``model``.
    ``model.eval()`` is called on the caller's model, so its mode changes.
    Entry ``[i, j]`` is the mean of squared differences between the stem
    features of ``src[i]`` and ``target[j]``. The matrix is then divided by
    its maximum, unless that maximum is zero. If the result is square, its
    diagonal is set to 1.

    Returns:
        CPU tensor of shape (len(src), len(target)).

    Raises:
        NotImplementedError: if ``model`` is not a torchvision ResNet.
    """
    model.eval()
    if not isinstance(model, models.resnet.ResNet):
        raise NotImplementedError("Only support ResNet now.")

    def _stem(images):
        # ResNet forward up to and including maxpool.
        return model.maxpool(model.relu(model.bn1(model.conv1(images))))

    with torch.no_grad():
        feat_src = _stem(src).cpu()  # (N1, C, H, W)
        feat_tgt = _stem(target).cpu().unsqueeze(0)  # (1, N2, C, H, W)

    # One source row at a time: broadcasting every (i, j) pair at once would
    # materialise an (N1, N2, C, H, W) tensor, which can exhaust memory.
    rows = [
        torch.mean((row.unsqueeze(1) - feat_tgt) ** 2, dim=(2, 3, 4))  # (1, N2)
        for row in feat_src.split(1)
    ]
    sim = torch.cat(rows, dim=0) if rows else feat_src.new_zeros((0, feat_tgt.shape[1]))

    peak = sim.max() if sim.numel() > 0 else None
    # Skip normalisation when all distances are zero (0 / 0 would give NaN).
    if peak is not None and peak > 0:
        sim = sim / peak
    if sim.shape[0] == sim.shape[1]:
        # NOTE(review): a square shape is used as a proxy for "src is target";
        # two different batches of equal size also get their diagonal set to 1.
        sim.fill_diagonal_(1)
    return sim


def cos_sim(src, target, mean_x, eps=1e-8):
    """Pairwise cosine *distance* (1 - cosine similarity) of mean-centred samples.

    ``mean_x`` is subtracted from both ``src`` and ``target`` (broadcasting).
    Each sample is then flattened to a vector. Entry ``[i, j]`` is
    ``1 - cos(src_i - mean_x, target_j - mean_x)``, so values lie in [0, 2].

    Vector norms are clamped to at least ``eps`` (as in
    ``torch.nn.functional.cosine_similarity``). A sample equal to ``mean_x``
    therefore gets cosine 0 (distance 1) instead of NaN.

    When the result is square, the diagonal is set to 1.

    Returns:
        Tensor of shape (len(src), len(target)).
    """
    centred_src = src - mean_x
    centred_src = centred_src.reshape(centred_src.size(0), -1)
    centred_tgt = target - mean_x
    centred_tgt = centred_tgt.reshape(centred_tgt.size(0), -1)

    norms_src = centred_src.norm(dim=1).clamp_min(eps).view(-1, 1)
    norms_tgt = centred_tgt.norm(dim=1).clamp_min(eps).view(1, -1)
    sim = 1 - centred_src @ centred_tgt.t() / (norms_src @ norms_tgt)

    if sim.shape[0] == sim.shape[1]:
        # NOTE(review): square shape is treated as "src is target"; two different
        # batches of equal size also get their diagonal overwritten.
        sim.fill_diagonal_(1)
    return sim


def feat_extract(model, x, device, pos="conv1"):
    """Return one ResNet intermediate feature map for ``x`` using a temporary forward hook.

    ``pos="conv1"`` captures the output of ``model.maxpool`` (the stem);
    ``pos="layer4"`` captures the output of ``model.layer4``.

    Side effects: ``model.to(device)`` moves the caller's model in place.
    The hook is attached to the caller's model (not a copy) and is always
    removed, even if the forward pass raises.

    NOTE(review): unlike the other helpers, this does not call ``model.eval()``.
    If the model is in train mode, BatchNorm uses batch statistics and updates
    its running stats. Confirm that callers put the model in eval mode.

    Raises:
        NotImplementedError: for an unsupported model type or ``pos``.
    """
    if not isinstance(model, models.resnet.ResNet):
        raise NotImplementedError("Only support ResNet now.")
    if pos == "conv1":
        hooked_module = model.maxpool
    elif pos == "layer4":
        hooked_module = model.layer4
    else:
        raise NotImplementedError(f"{pos}")

    output_hook = OutputHook(hooked_module)
    try:
        with torch.no_grad():
            model = model.to(device)
            model(x.to(device))
            feat = output_hook.r_feature
    finally:
        # Always detach the hook so a failed forward pass does not leave it
        # registered on the caller's model.
        output_hook.close()
    return feat