import torch
import torchvision

from models import DeepMILAttModel, CAMIL


def _att_pool_kwargs(args):
    """Collect the attention-pooling options shared by the ABMIL variants.

    Built lazily (only when an attention-pooling model is requested) so that
    other model types do not require these attributes to exist on ``args``.
    """
    return {
        'att_dim': args.pool_att_dim,
        'alpha': args.alpha,
        'smooth_mode': args.smooth_mode,
        'smooth_where': args.smooth_where,
        'spectral_norm': args.spectral_norm,
    }


def build_MIL_model(args, pos_weight=None):
    """Instantiate the MIL model selected by ``args.model_name``.

    Args:
        args: Namespace-like object. ``model_name`` picks the architecture
            (``'abmil'``, ``'transformer_abmil'`` or ``'camil'``); the other
            attributes that are read depend on the chosen branch
            (``input_shape``, ``feat_ext_name``, ``pool_*``/``alpha``/
            ``smooth_*``/``spectral_norm`` and, for the transformer variant,
            the ``transf_*`` options).
        pos_weight: Optional value forwarded to ``BCEWithLogitsLoss`` as its
            ``pos_weight`` argument.

    Returns:
        The constructed model, configured with a mean-reduced
        ``BCEWithLogitsLoss`` as its ``ce_criterion``.

    Raises:
        ValueError: If ``'camil'`` is requested with an ``input_shape`` that
            has more than one dimension.
        NotImplementedError: If ``args.model_name`` is not a supported name.
    """
    ce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)

    if args.model_name == 'abmil':
        return DeepMILAttModel(
            input_shape=args.input_shape,
            feat_ext_name=args.feat_ext_name,
            pool_name='att',
            pool_kwargs=_att_pool_kwargs(args),
            ce_criterion=ce_criterion,
        )

    if args.model_name == 'transformer_abmil':
        return DeepMILAttModel(
            input_shape=args.input_shape,
            feat_ext_name=args.feat_ext_name,
            pool_name='att',
            transformer_encoder_kwargs={
                'att_dim': args.transf_att_dim,
                'num_heads': args.transf_num_heads,
                'num_layers': args.transf_num_layers,
                'use_ff': args.transf_use_ff,
                'dropout': args.transf_dropout,
                'smooth_steps': args.transf_smooth_steps,
            },
            pool_kwargs=_att_pool_kwargs(args),
            ce_criterion=ce_criterion,
        )

    if args.model_name == 'camil':
        if len(args.input_shape) > 1:
            raise ValueError('CAMIL only supports 1D input')
        num_features = args.input_shape[0]
        return CAMIL(num_features, ce_criterion=ce_criterion)

    # Name the offending value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unknown model_name {args.model_name!r}; "
        "expected one of 'abmil', 'transformer_abmil', 'camil'"
    )

class ResNetTrunk(torchvision.models.resnet.ResNet):
    """ResNet whose final fully-connected layer has been removed.

    ``forward`` runs the convolutional stem, the four residual stages and the
    average pool, then returns the result flattened to ``(batch, -1)``.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent ResNet with the given arguments, then drop ``fc``."""
        super().__init__(*args, **kwargs)
        del self.fc  # remove FC layer

    def forward(self, x):
        """Return the pooled features of ``x`` flattened per sample."""
        stages = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )
        for stage in stages:
            x = stage(x)

        # Collapse everything after the batch dimension into one axis.
        batch_size = x.size(0)
        return x.view(batch_size, -1)

def get_pretrained_url(key):
    """Return the download URL of the pretrained weights registered as ``key``.

    Args:
        key: Registry name of the checkpoint: ``"BT"``, ``"MoCoV2"`` or
            ``"SwAV"``.

    Returns:
        The full URL of the corresponding weights file.

    Raises:
        ValueError: If ``key`` is not in the registry. Failing here avoids
            building a URL ending in ``/None`` that would only fail later,
            at download time.
    """
    URL_PREFIX = "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
    model_zoo_registry = {
        "BT": "bt_rn50_ep200.torch",
        "MoCoV2": "mocov2_rn50_ep200.torch",
        "SwAV": "swav_rn50_ep200.torch",
    }
    try:
        filename = model_zoo_registry[key]
    except KeyError:
        raise ValueError(
            f"Unknown pretrained weights key {key!r}; "
            f"expected one of {sorted(model_zoo_registry)}"
        ) from None
    return f"{URL_PREFIX}/{filename}"

def resnet50_ssl(pretrained, progress, key, **kwargs):
    """Build a ``ResNetTrunk`` with Bottleneck blocks in a [3, 4, 6, 3] layout.

    Args:
        pretrained: When truthy, download the weights registered under
            ``key`` and load them into the model.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url``.
        key: Registry key passed to ``get_pretrained_url``; only used when
            ``pretrained`` is truthy.
        **kwargs: Extra arguments forwarded to the ``ResNetTrunk`` constructor.

    Returns:
        The constructed (and optionally weight-loaded) model.
    """
    trunk = ResNetTrunk(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return trunk

    weights_url = get_pretrained_url(key)
    state_dict = torch.hub.load_state_dict_from_url(weights_url, progress=progress)
    load_report = trunk.load_state_dict(state_dict)
    # Print what load_state_dict returned so the outcome of the load is visible.
    print(load_report)
    return trunk

def get_ssl_transforms():
    """Return a transform that converts to a tensor, then normalizes.

    The normalization uses the same fixed per-channel mean/std that this file
    pairs with the ``resnet50_ssl`` backbones.
    """
    channel_mean = [0.70322989, 0.53606487, 0.66096631]
    channel_std = [0.21716536, 0.26081574, 0.20723464]

    to_tensor = torchvision.transforms.ToTensor()
    normalize = torchvision.transforms.Normalize(mean=channel_mean, std=channel_std)
    return torchvision.transforms.Compose([to_tensor, normalize])

# torchvision builder name -> weights identifier used when ImageNet
# pretraining is requested.
_IMAGENET_BACKBONE_WEIGHTS = {
    # NOTE(review): the original comment here said n_feat = 1280, but
    # torchvision's MobileNetV2 appears to apply its global average pool
    # functionally in forward() rather than as a child module, so the stripped
    # model may emit 1280*H*W features for inputs larger than the network
    # stride. Behaviour is left unchanged; confirm against callers.
    'mobilenet_v2': 'IMAGENET1K_V2',
    'mobilenet_v3_large': 'IMAGENET1K_V2',  # original comment: n_feat = 960
    'resnet18': 'IMAGENET1K_V1',  # original comment: n_feat = 512
    'resnet50': 'IMAGENET1K_V1',
}

# Model name -> key of the self-supervised checkpoint passed to resnet50_ssl.
_SSL_BACKBONE_KEYS = {
    'resnet50_bt': 'BT',
    'resnet50_mocov2': 'MoCoV2',
    'resnet50_swav': 'SwAV',
}


def _strip_last_child(model):
    """Return ``model`` minus its last child module, followed by a Flatten."""
    return torch.nn.Sequential(*list(model.children())[:-1], torch.nn.Flatten())


def _imagenet_normalize():
    """Return the Normalize transform used for the ImageNet backbones."""
    return torchvision.transforms.Normalize(
        mean=torch.tensor([0.485, 0.456, 0.406]),
        std=torch.tensor([0.229, 0.224, 0.225]),
    )


def _ssl_normalize():
    """Return the Normalize transform used for the ``resnet50_ssl`` backbones."""
    return torchvision.transforms.Normalize(
        mean=[0.70322989, 0.53606487, 0.66096631],
        std=[0.21716536, 0.26081574, 0.20723464],
    )


def load_torchvision_model(model_name, use_imagenet_pretrained_weights=True):
    """Build a feature-extractor backbone and its input normalization.

    Args:
        model_name: One of ``'mobilenet_v2'``, ``'mobilenet_v3_large'``,
            ``'resnet18'``, ``'resnet50'``, ``'resnet50_bt'``,
            ``'resnet50_mocov2'`` or ``'resnet50_swav'``.
        use_imagenet_pretrained_weights: For the four torchvision backbones,
            whether to load ImageNet weights (``True``) or leave the model
            randomly initialised (``False``). Ignored by the ``resnet50_*``
            self-supervised variants, which always load their checkpoint.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is a single
        ``Normalize`` transform (no tensor conversion or resizing).

    Raises:
        NotImplementedError: If ``model_name`` is not supported.
    """
    if model_name in _IMAGENET_BACKBONE_WEIGHTS:
        weights = (
            _IMAGENET_BACKBONE_WEIGHTS[model_name]
            if use_imagenet_pretrained_weights
            else None
        )
        # The table keys double as the torchvision.models builder names.
        builder = getattr(torchvision.models, model_name)
        model = _strip_last_child(builder(weights=weights))
        return model, _imagenet_normalize()

    if model_name in _SSL_BACKBONE_KEYS:
        model = resnet50_ssl(
            pretrained=True, progress=False, key=_SSL_BACKBONE_KEYS[model_name]
        )
        return model, _ssl_normalize()

    supported = sorted([*_IMAGENET_BACKBONE_WEIGHTS, *_SSL_BACKBONE_KEYS])
    raise NotImplementedError(
        f"Unknown model_name {model_name!r}; expected one of {supported}"
    )