import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class resnet18_vfl(nn.Module):
    """ResNet-18 feature extractor for one party of a vertical-FL setup.

    ``forward`` returns the raw output of the convolutional trunk; task heads
    are built on demand through :meth:`classifier`.

    Args:
        out_dim: Number of outputs produced by ``clf`` and by the heads
            returned from :meth:`classifier`.
    """

    def __init__(self, out_dim=10):
        super().__init__()
        self.out_dim = out_dim

        self.backbone = models.resnet18(pretrained=False)
        self.num_ftrs = self.backbone.fc.in_features

        # Registered on the module but not used by ``forward``.
        self.clf = nn.Linear(self.num_ftrs, self.out_dim)

        # Every backbone child except the last two (for torchvision's ResNet
        # these are the global average pool and the fc layer).
        trunk = list(self.backbone.children())[:-2]
        self.extractor = nn.Sequential(*trunk)

    def classifier(self, concat):
        """Build a fresh ``Flatten -> Linear`` classification head.

        Args:
            concat: Only the literal ``True`` selects the double-width head
                (``2 * num_ftrs`` inputs); any other value yields a head with
                ``num_ftrs`` inputs.

        Returns:
            A newly constructed ``nn.Sequential``; it is not registered on
            this module.
        """
        in_features = self.num_ftrs * 2 if concat is True else self.num_ftrs
        return nn.Sequential(nn.Flatten(), nn.Linear(in_features, self.out_dim))

    def forward(self, x):
        """Return the extractor output for ``x`` (no flattening, no head)."""
        return self.extractor(x)

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` viewed as ``(x.shape[0], -1)``."""
        batch = x.shape[0]
        return x.view(batch, -1)

class resnet18_vfl_split(nn.Module):
    """ResNet-18 split into an encoder plus several classifier heads.

    The torchvision ResNet-18 children are partitioned by position:

    * ``encoder`` - everything except the last four children,
    * ``clf``     - the three children preceding the final fc layer,
    * ``clf_adv`` / ``clf_adv_sim`` - MLP / linear heads applied to the
      flattened encoder output.

    ``forward`` runs the encoder only; the heads are invoked by the caller.

    Args:
        out_dim: Output width of ``clf_adv`` and ``clf_adv_sim``.
    """

    def __init__(self, out_dim=10):
        super(resnet18_vfl_split, self).__init__()
        net = models.resnet18(pretrained=False)
        # NOTE: this replacement fc is the last child, so neither slice below
        # includes it. It is kept so construction (and RNG consumption during
        # initialisation) stays identical to the previous behaviour.
        net.fc = nn.Linear(net.fc.in_features, out_dim)

        children = list(net.children())
        self.encoder = nn.Sequential(*children[:-4])
        self.clf = nn.Sequential(*children[-4:-1])

        # Both heads expect 1024 flattened encoder features.
        # NOTE(review): presumably 128 channels x 4 x 2 for a 32x16 input
        # slice - confirm against the data pipeline.
        self.clf_adv = nn.Sequential(
            Flatten(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
        )
        self.clf_adv_sim = nn.Sequential(Flatten(), nn.Linear(1024, out_dim))

    def forward(self, x):
        """Return the encoder features for ``x``.

        The previous implementation computed the features but fell off the
        end of the function, so every call returned ``None``.
        """
        return self.encoder(x)


class resnet18(nn.Module):
    """ResNet-18 trunk followed by a single linear classifier.

    Args:
        out_dim: Number of output logits.

    Note:
        ``self.backbone`` stays registered (including its original fc layer)
        so existing ``state_dict`` keys remain loadable.
    """

    def __init__(self, out_dim=10):
        super(resnet18, self).__init__()
        self.backbone = models.resnet18(pretrained=False)
        self.num_ftrs = self.backbone.fc.in_features
        self.out_dim = out_dim
        self.clf = nn.Linear(self.num_ftrs, self.out_dim)

        # All backbone children except the last two (for torchvision's ResNet
        # these are the global average pool and the fc layer).
        self.extractor = nn.Sequential(*list(self.backbone.children())[:-2])

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)`` for an image batch.

        The extractor omits the backbone's average pool, so its output keeps
        whatever spatial size the input implies. Previously that map was
        flattened directly, which only matched ``clf`` when it was exactly
        1x1; larger inputs crashed with a shape mismatch. Pooling to 1x1
        first is an identity on 1x1 maps, so results for inputs that already
        worked are unchanged while other sizes are now supported.
        """
        features = self.extractor(x)
        features = F.adaptive_avg_pool2d(features, 1)
        features = torch.flatten(features, 1)
        return self.clf(features)


class lenet5_vfl(nn.Module):
    """LeNet-5 style feature extractor for one party of a vertical-FL setup.

    ``forward`` stops at the 84-unit hidden layer; classification heads are
    built on demand through :meth:`classifier`.

    Args:
        out_dim: Number of outputs produced by ``clf`` and by the heads
            returned from :meth:`classifier`.
    """

    def __init__(self, out_dim=10):
        super().__init__()
        self.out_dim = out_dim
        self.num_ftrs = 84

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # fc1 expects 80 flattened features.
        # NOTE(review): presumably 16 channels x 5 x 1 for a 32x16 input
        # slice - confirm against the data pipeline.
        self.fc1 = nn.Linear(in_features=80, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        # Registered on the module but not used by ``forward``.
        self.clf = nn.Linear(self.num_ftrs, self.out_dim)

    def classifier(self, concat):
        """Build a fresh ``Flatten -> Linear`` classification head.

        Args:
            concat: Only the literal ``True`` selects the double-width head
                (``2 * num_ftrs`` inputs); any other value yields a head with
                ``num_ftrs`` inputs.

        Returns:
            A newly constructed ``nn.Sequential``; it is not registered on
            this module.
        """
        in_features = self.num_ftrs * 2 if concat is True else self.num_ftrs
        return nn.Sequential(nn.Flatten(), nn.Linear(in_features, self.out_dim))

    def forward(self, x):
        """Return the 84-dimensional hidden features for ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        hidden = x.view(x.size(0), -1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))
        return hidden

class lenet5(nn.Module):
    """LeNet-5 style classifier for 3-channel images.

    Two conv/ReLU/max-pool stages feed three fully connected layers; the
    last one (``clf``) emits the logits without an activation.

    Args:
        out_dim: Number of output logits.
    """

    def __init__(self, out_dim=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # fc1 expects 400 flattened features.
        # NOTE(review): presumably 16 channels x 5 x 5 for a 32x32 input -
        # confirm against the data pipeline.
        self.fc1 = nn.Linear(in_features=400, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.clf = nn.Linear(in_features=84, out_features=out_dim)

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)`` for ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        hidden = x.view(x.size(0), -1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))
        return self.clf(hidden)




class InvalidBackboneError(ValueError):
    """Raised when an unsupported backbone name is requested.

    Defined here because the name was previously referenced without being
    defined or imported, so the error path raised ``NameError`` instead.
    """


class ResNetSimCLR_vfl(nn.Module):
    """SimCLR-style ResNet with an MLP projection head, adapted for VFL.

    The backbone's final fc layer is replaced by
    ``Linear(dim_mlp, dim_mlp) -> ReLU -> original fc``.

    Args:
        base_model: Backbone name, ``"resnet18"`` or ``"resnet50"``.
        out_dim: Output width of the projection head (passed to torchvision
            as ``num_classes``).
        class_dim: Number of outputs of the heads built by :meth:`classifier`.

    Raises:
        InvalidBackboneError: If ``base_model`` is not a supported name.
    """

    def __init__(self, base_model, out_dim, class_dim):
        super(ResNetSimCLR_vfl, self).__init__()

        self.class_dim = class_dim

        # NOTE: both candidates are built eagerly and kept in this plain dict
        # (it is not a registered submodule container). Left as-is so the
        # attribute and the initialisation order stay unchanged.
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        self.dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        self.backbone.fc = nn.Sequential(
            nn.Linear(self.dim_mlp, self.dim_mlp),
            nn.ReLU(),
            self.backbone.fc,
        )

    def _get_basemodel(self, model_name):
        """Return the backbone registered under ``model_name``.

        Raises:
            InvalidBackboneError: If ``model_name`` is not in ``resnet_dict``.
        """
        try:
            return self.resnet_dict[model_name]
        except KeyError:
            # ``from None``: the KeyError adds nothing to the message.
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50"
            ) from None

    def classifier(self, concat):
        """Build a fresh ``Flatten -> Linear`` classification head.

        Args:
            concat: Only the literal ``True`` selects the double-width head
                (``2 * dim_mlp`` inputs); any other value yields a head with
                ``dim_mlp`` inputs.

        Returns:
            A newly constructed ``nn.Sequential``; it is not registered on
            this module.
        """
        if concat is True:
            return nn.Sequential(nn.Flatten(), nn.Linear(self.dim_mlp*2, self.class_dim))
        else:
            return nn.Sequential(nn.Flatten(), nn.Linear(self.dim_mlp, self.class_dim))

    def extractor(self):
        """Return the backbone without its last child (the projection head).

        The returned ``nn.Sequential`` wraps the backbone's own layers, so
        its parameters are shared with ``self.backbone``.
        """
        return nn.Sequential(*list(self.backbone.children())[:-1])

    def load(self, path):
        """Load weights from the ``'state_dict'`` entry of a checkpoint file.

        Tensors are mapped to CPU while reading so a checkpoint saved on a
        GPU also loads on CPU-only hosts; ``load_state_dict`` then copies the
        values into the parameters on whatever device they live on.
        """
        checkpoint = torch.load(path, map_location="cpu")
        self.load_state_dict(checkpoint['state_dict'])

    def forward(self, x):
        """Return the projection-head output of the backbone for ``x``."""
        return self.backbone(x)

#net = ResNetSimCLR_vfl("resnet18", 128, 10)
#net = resnet18_vfl()
#path = '/home/js905/intern2022/SimCLR/runs/May31_21-32-01_hl279-cmp-03.egr.duke.edu/checkpoint_0200.pth.tar'
#checkpoint = torch.load(path)
#net.load_state_dict(checkpoint['state_dict'])

#net.load('/home/js905/intern2022/SimCLR/runs/May31_21-32-01_hl279-cmp-03.egr.duke.edu/checkpoint_0200.pth.tar')
#extractor = net.extractor()
#for k,v in extractor.named_parameters():
#    print(k)
