import torchvision
import torch.nn as nn
import torch.nn.functional as F

from config import cfg

class VGG16GAP(nn.Module):
    """VGG16-BN classifier with a replaceable classification head.

    Wraps ``torchvision.models.vgg16_bn``. When ``num_classes`` is 1000 the
    stock torchvision classifier is kept unchanged; for any other value it is
    replaced by a single ``nn.Linear(512 * 7 * 7, num_classes)`` layer (e.g.
    12 classes for the imagenet-subset and imagenet-mini datasets).

    NOTE(review): despite the "GAP" in the class name, no global average
    pooling is added here -- the replacement head consumes the flattened
    512 * 7 * 7 feature vector, as its input size shows. Confirm whether a
    GAP head was intended; changing it would alter the architecture and
    break existing checkpoints.
    """

    def __init__(self, num_classes=1000, *, pretrained=None):
        """Build the backbone and, if needed, a new classification head.

        Args:
            num_classes: Number of output scores. Any value other than 1000
                replaces the torchvision classifier with one linear layer.
            pretrained: Whether to load torchvision's pretrained weights.
                ``None`` (the default) falls back to ``cfg.model.pretrained``,
                which is what the constructor always used before.
        """
        super().__init__()
        if pretrained is None:
            pretrained = cfg.model.pretrained
        self.backbone = torchvision.models.vgg16_bn(pretrained=pretrained)
        if num_classes != 1000:
            # 512 * 7 * 7 matches the flattened feature vector torchvision's
            # VGG feeds to its classifier. The nn.Sequential wrapper is kept
            # so state_dict keys stay "classifier.0.*" for existing checkpoints.
            self.backbone.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, num_classes),
            )

    def forward(self, x):
        """Return the backbone's output scores for the input batch ``x``."""
        return self.backbone(x)

    def freeze_conv(self):
        """Freeze every backbone parameter except those of the classifier.

        Only ``requires_grad`` is toggled. NOTE: BatchNorm layers in the
        frozen part still update their running statistics while the module is
        in training mode; if they must stay fixed, put those layers (exposed
        by torchvision's VGG under ``self.backbone.features``) into ``eval()``
        after every call to ``train()``.
        """
        self.backbone.requires_grad_(False)
        self.backbone.classifier.requires_grad_(True)

    def unfreeze(self):
        """Make every backbone parameter trainable again."""
        self.backbone.requires_grad_(True)
