import torch
import torch.nn as nn
from copy import deepcopy
from queue import Queue
from torch.autograd import Variable
from einops import rearrange
from .utils import MyQueue

class YearbookNetwork(nn.Module):
    """Four-stage convolutional classifier.

    The encoder stacks four ``conv_block`` stages, each 32 channels wide.
    Its output is averaged over dimensions 2 and 3 and passed to a linear
    layer that produces ``num_classes`` scores.

    Args:
        args: Configuration object; stored on the instance, not read here.
        num_input_channels: Channel count of the input fed to the first
            convolution.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, args, num_input_channels, num_classes):
        super().__init__()
        self.args = args
        self.hid_dim = 32

        # Channel widths at each stage boundary: input -> 32 -> 32 -> 32 -> 32.
        widths = [num_input_channels] + [self.hid_dim] * 4
        stages = [
            self.conv_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.enc = nn.Sequential(*stages)
        self.classifier = nn.Linear(self.hid_dim, num_classes)

    def conv_block(self, in_channels, out_channels):
        """Build one stage: 3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x``, average over dimensions 2 and 3, and return the logits."""
        pooled = self.enc(x).mean(dim=(2, 3))
        return self.classifier(pooled)









class YearbookNetwork_for_Ours(nn.Module):
    """Four-stage CNN that returns both pooled features and class logits.

    The encoder is four ``conv_block`` stages of 32 channels each. Its output
    is averaged over dimensions 2 and 3 to give a ``feature_dim``-wide feature
    vector, which a bias-free linear layer maps to ``num_classes`` logits.

    The instance also owns two ``MyQueue`` buffers (``knowledge_pool`` and
    ``DM_trainsample_pool``) sized from ``args.trainer``. Only
    ``knowledge_pool`` is written to in this class (see ``memorize``).

    Args:
        args: Configuration object; must expose ``args.trainer.len_queue``
            and ``args.trainer.len_DM_pool``.
        num_input_channels: Channel count of the input fed to the first
            convolution.
        num_classes: Output size of the linear classifier.
    """

    def __init__(self, args, num_input_channels, num_classes):
        super(YearbookNetwork_for_Ours, self).__init__()
        self.args = args
        self.enc = nn.Sequential(self.conv_block(num_input_channels, 32), self.conv_block(32, 32),
                                 self.conv_block(32, 32), self.conv_block(32, 32))
        self.feature_dim = 32
        self.classifier = nn.Linear(self.feature_dim, num_classes, bias=False)
        # Bounded buffers; capacity comes from the trainer configuration.
        self.knowledge_pool = MyQueue(maxsize=args.trainer.len_queue)
        self.DM_trainsample_pool = MyQueue(maxsize=args.trainer.len_DM_pool)
        # Not read inside this class; kept for external users of the model.
        self.eps = 1e-6

    def conv_block(self, in_channels, out_channels):
        """Build one stage: 3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def memorize(self, W):
        """Store ``W`` in ``knowledge_pool``.

        ``W`` is expected to have shape ``[C, D]`` (presumably classes x
        feature_dim, e.g. classifier weights -- confirm against callers).
        """
        self.knowledge_pool.put_item(W)

    def forward_encoder(self, x):
        """Return the encoder output averaged over dimensions 2 and 3."""
        f = self.enc(x)
        return torch.mean(f, dim=(2, 3))

    def forward(self, x):
        """Run the full network.

        Defining ``forward`` (rather than only the historical ``foward``)
        lets ``model(x)`` dispatch through ``nn.Module.__call__`` instead of
        raising ``NotImplementedError``.

        Returns:
            A tuple ``(f, logits)``: the pooled features and the classifier
            output computed from them.
        """
        f = self.forward_encoder(x)
        logits = self.classifier(f)
        return f, logits

    def foward_encoder(self, x):
        """Backward-compatible alias for ``forward_encoder`` (original misspelling)."""
        return self.forward_encoder(x)

    def foward(self, x):
        """Backward-compatible alias for ``forward`` (original misspelling)."""
        return self.forward(x)

    def get_parameters(self, lr):
        """Return optimizer parameter groups for the encoder and the classifier.

        Both groups currently use the base learning rate ``lr`` with a
        multiplier of 1.
        """
        return [
            {"params": self.enc.parameters(), 'lr': 1 * lr},
            {"params": self.classifier.parameters(), 'lr': 1 * lr},
        ]







