import torch
import torch.nn as nn

from torchvision.models import resnet18, resnet34, resnet50
from torchvision.models.mobilenet import mobilenet_v2

class net_agg(nn.Module):
    """Concatenate the feature vectors produced by several networks.

    Each wrapped network is run on a copy of the input moved to that
    network's ``device`` attribute, asked for its features
    (``feature = True``), and the resulting ``'feature'`` tensors are brought
    back to the CPU and concatenated along dimension 1.

    Args:
        nets: networks to wrap. Each must expose a ``device`` attribute and a
            ``forward(input, feature=True)`` returning a dict with a
            ``'feature'`` tensor.
        hidden_size: input width of ``self.linear``. NOTE(review): presumably
            the total width of the concatenated features — confirm with callers.
    """

    def __init__(self, nets, hidden_size):
        super().__init__()
        self.device = "cpu"
        self.nets = nn.ModuleList(nets)
        # NOTE(review): forward() never uses this head; the aggregated result
        # carries no 'output' entry.
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': concatenated CPU features of all wrapped nets}``.

        The ``feature`` argument is accepted for interface compatibility with
        the other models but is ignored: features are always returned.
        """
        collected = [
            member(input.to(member.device), feature = True)['feature'].cpu()
            for member in self.nets
        ]
        return {'feature': torch.cat(collected, dim = 1)}

# fully connected neural network architecture for tabular datasets
class FNN(nn.Module):
    """Fully connected regression network for tabular datasets.

    The feature extractor is ``Linear(input_dim, hidden_size) + ReLU``
    followed by ``layers - 1`` blocks of ``Linear(hidden_size, hidden_size) +
    ReLU``; ``self.linear`` maps the last hidden activation to one output.

    Args:
        layers: number of ``Linear + ReLU`` stages in the feature extractor.
            The first stage is always built, so values below 1 behave like 1.
        hidden_size: width of every hidden layer.
        data: dataset name used to look up the input width when ``input_dim``
            is not given (``"toy3"`` -> 3 features, ``"energy"`` -> 28).
        input_dim: optional explicit number of input features; when given it
            overrides the lookup by ``data`` so other tabular data can be used.

    Raises:
        ValueError: if ``input_dim`` is None and ``data`` is not a known name.
    """

    # Input width (number of tabular features) for each known dataset name.
    _INPUT_DIMS = {"toy3": 3, "energy": 28}

    def __init__(self, layers, hidden_size = 32, data = "toy3", input_dim = None):
        super(FNN, self).__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        if input_dim is None:
            try:
                input_dim = self._INPUT_DIMS[data]
            except KeyError:
                # Without an input projection the stack would start with a
                # ReLU and a hidden_size -> hidden_size layer, so fail loudly.
                raise ValueError(
                    f"unknown dataset {data!r}; expected one of "
                    f"{sorted(self._INPUT_DIMS)} or an explicit input_dim"
                ) from None
        self.feature_layers = nn.ModuleList()
        self.feature_layers.append(nn.Linear(input_dim, hidden_size))
        self.feature_layers.append(nn.ReLU())
        for _ in range(layers - 1):
            self.feature_layers.append(nn.Linear(hidden_size, hidden_size))
            self.feature_layers.append(nn.ReLU())
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, feature = False):
        """Run the network.

        Args:
            input: tensor whose last dimension is the input width.
            feature: if True, also return the last hidden activation.

        Returns:
            dict with ``'output'`` (shape ``(..., 1)``) and, when ``feature``
            is True, ``'feature'`` (shape ``(..., hidden_size)``), listed first.
        """
        output = input
        for layer in self.feature_layers:
            output = layer(output)
        returns = {}
        if feature:
            returns['feature'] = output
        returns['output'] = self.linear(output)
        return returns

# convolutional neural network architectures for UTKFace (small-scale)
class CNN1_UTK(nn.Module):
    """Small two-stage CNN feature extractor with a scalar regression head.

    Each stage is a 3x3/stride-2 convolution, batch norm, ReLU and 2x2 max
    pooling (channels 3 -> 16 -> 32). The flattened map is projected to
    ``hidden_size`` features and ``self.linear`` maps those to one output.

    NOTE(review): the projection width 2048 = 32 * 8 * 8 matches an 8x8
    final map, i.e. square 128x128 RGB inputs — confirm with the data
    pipeline; other sizes fail in the projection.
    """

    def __init__(self, hidden_size = 64):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"

        def stage(in_ch, out_ch):
            # Strided conv + 2x2 pool: each stage shrinks the spatial size 4x.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size = 3, stride = 2, padding = 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size = 2, stride = 2),
            ]

        stack = []
        for in_ch, out_ch in ((3, 16), (16, 32)):
            stack += stage(in_ch, out_ch)
        stack += [nn.Flatten(), nn.Linear(32 * 8 * 8, hidden_size)]
        self.feature_layers = nn.ModuleList(stack)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the ``hidden_size`` features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class CNN2_UTK(nn.Module):
    """VGG-style CNN feature extractor with a scalar regression head.

    Two stages, each made of two 3x3/stride-1 ``Conv -> BatchNorm -> ReLU``
    units followed by 2x2 max pooling (channels 3 -> 16 -> 16, then
    16 -> 32 -> 32). The flattened map is projected to ``hidden_size``
    features and ``self.linear`` maps those to one output.

    NOTE(review): the projection width 32768 = 32 * 32 * 32 matches a 32x32
    final map, i.e. square 128x128 RGB inputs — confirm with the data
    pipeline; other sizes fail in the projection.
    """

    def __init__(self, hidden_size = 64):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"

        def conv_unit(in_ch, out_ch):
            # Size-preserving 3x3 convolution followed by batch norm and ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        stack = []
        for in_ch, out_ch in ((3, 16), (16, 32)):
            stack += conv_unit(in_ch, out_ch)
            stack += conv_unit(out_ch, out_ch)
            stack.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
        stack += [nn.Flatten(), nn.Linear(32 * 32 * 32, hidden_size)]
        self.feature_layers = nn.ModuleList(stack)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the ``hidden_size`` features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class CNN3_UTK(nn.Module):
    """Wider two-stage CNN feature extractor with a scalar regression head.

    Each stage is a 3x3/stride-2 convolution, batch norm, ReLU and 2x2 max
    pooling (channels 3 -> 32 -> 64). The flattened map is projected to
    ``hidden_size`` features and ``self.linear`` maps those to one output.

    NOTE(review): the projection width 4096 = 64 * 8 * 8 matches an 8x8
    final map, i.e. square 128x128 RGB inputs — confirm with the data
    pipeline; other sizes fail in the projection.
    """

    def __init__(self, hidden_size = 64):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"

        def stage(in_ch, out_ch):
            # Strided conv + 2x2 pool: each stage shrinks the spatial size 4x.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size = 3, stride = 2, padding = 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size = 2, stride = 2),
            ]

        stack = []
        for in_ch, out_ch in ((3, 32), (32, 64)):
            stack += stage(in_ch, out_ch)
        stack += [nn.Flatten(), nn.Linear(64 * 8 * 8, hidden_size)]
        self.feature_layers = nn.ModuleList(stack)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the ``hidden_size`` features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class CNN4_UTK(nn.Module):
    """Wider VGG-style CNN feature extractor with a scalar regression head.

    Two stages, each made of two 3x3/stride-1 ``Conv -> BatchNorm -> ReLU``
    units followed by 2x2 max pooling (channels 3 -> 32 -> 32, then
    32 -> 64 -> 64). The flattened map is projected to ``hidden_size``
    features and ``self.linear`` maps those to one output.

    NOTE(review): the projection width 65536 = 64 * 32 * 32 matches a 32x32
    final map, i.e. square 128x128 RGB inputs — confirm with the data
    pipeline; other sizes fail in the projection.
    """

    def __init__(self, hidden_size = 64):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"

        def conv_unit(in_ch, out_ch):
            # Size-preserving 3x3 convolution followed by batch norm and ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        stack = []
        for in_ch, out_ch in ((3, 32), (32, 64)):
            stack += conv_unit(in_ch, out_ch)
            stack += conv_unit(out_ch, out_ch)
            stack.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
        stack += [nn.Flatten(), nn.Linear(64 * 32 * 32, hidden_size)]
        self.feature_layers = nn.ModuleList(stack)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the ``hidden_size`` features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

# ResNets and MobileNetV2 for RotatedMNIST (large-scale), adapted to single-channel input
class ResNet18_MNIST(nn.Module):
    """ResNet-18 feature extractor for single-channel images, with a scalar head.

    Built from ``resnet18()``: its first convolution is swapped for a
    1-input-channel 7x7/stride-2 convolution, its last child (the ``fc``
    classifier) is dropped, and a ``Flatten`` turns the pooled map into the
    512 features that ``self.linear`` maps to one output.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone_children = list(resnet18().children())
        stem = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # New stem, then every original child between the stem and the classifier.
        self.feature_layers = nn.ModuleList([stem, *backbone_children[1:-1], nn.Flatten()])
        self.linear = nn.Linear(512, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 512 flattened features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class ResNet34_MNIST(nn.Module):
    """ResNet-34 feature extractor for single-channel images, with a scalar head.

    Built from ``resnet34()``: its first convolution is swapped for a
    1-input-channel 7x7/stride-2 convolution, its last child (the ``fc``
    classifier) is dropped, and a ``Flatten`` turns the pooled map into the
    512 features that ``self.linear`` maps to one output.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone_children = list(resnet34().children())
        stem = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # New stem, then every original child between the stem and the classifier.
        self.feature_layers = nn.ModuleList([stem, *backbone_children[1:-1], nn.Flatten()])
        self.linear = nn.Linear(512, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 512 flattened features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class ResNet50_MNIST(nn.Module):
    """ResNet-50 feature extractor for single-channel images, with a scalar head.

    Built from ``resnet50()``: its first convolution is swapped for a
    1-input-channel 7x7/stride-2 convolution, its last child (the ``fc``
    classifier) is dropped, and a ``Flatten`` turns the pooled map into the
    2048 features that ``self.linear`` maps to one output.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone_children = list(resnet50().children())
        stem = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # New stem, then every original child between the stem and the classifier.
        self.feature_layers = nn.ModuleList([stem, *backbone_children[1:-1], nn.Flatten()])
        self.linear = nn.Linear(2048, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 2048 flattened features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class MobileNetv2_MNIST(nn.Module):
    """MobileNetV2 feature extractor for single-channel images, with a scalar head.

    Uses the ``features`` stage of ``mobilenet_v2()`` with its stem
    convolution replaced by a 1-input-channel 3x3/stride-2 convolution,
    followed by global average pooling and a ``Flatten`` that yield 1280
    features; ``self.linear`` maps those to one output.
    """

    def __init__(self):
        super(MobileNetv2_MNIST, self).__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone = mobilenet_v2()
        # Swap the first conv of the first feature block for a 1-channel one.
        backbone.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.feature_layers = backbone.features
        # Parameter-free global average pooling: an exact no-op when the
        # backbone's final map is already 1x1, and it reduces larger maps to
        # 1x1 so Flatten always yields the 1280 features self.linear expects.
        self.feature_layers.append(nn.AdaptiveAvgPool2d(1))
        self.feature_layers.append(nn.Flatten())
        self.linear = nn.Linear(1280, 1)

    def forward(self, input, feature = False):
        """Run the network.

        Args:
            input: batch of single-channel images (NOTE(review): presumably
                ``(batch, 1, H, W)`` — confirm with the data pipeline).
            feature: if True, also return the 1280 pooled features.

        Returns:
            dict with ``'output'`` (``self.linear`` applied to the features)
            and, when ``feature`` is True, ``'feature'`` listed first.
        """
        output = input
        for layer in self.feature_layers:
            output = layer(output)
        returns = {}
        if feature:
            returns['feature'] = output
        returns['output'] = self.linear(output)

        return returns

# ResNets and MobileNetV2 for IMDB (large-scale)
class ResNet18_IMDB(nn.Module):
    """ResNet-18 feature extractor for RGB images, with a scalar head.

    Keeps every child of ``resnet18()`` except the last (the ``fc``
    classifier) and appends a ``Flatten``, giving the 512 features that
    ``self.linear`` maps to one output.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone_children = list(resnet18().children())
        self.feature_layers = nn.ModuleList(backbone_children[:-1] + [nn.Flatten()])
        self.linear = nn.Linear(512, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 512 flattened features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class ResNet34_IMDB(nn.Module):
    """ResNet-34 feature extractor for RGB images, with a scalar head.

    Keeps every child of ``resnet34()`` except the last (the ``fc``
    classifier) and appends a ``Flatten``, giving the 512 features that
    ``self.linear`` maps to one output.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone_children = list(resnet34().children())
        self.feature_layers = nn.ModuleList(backbone_children[:-1] + [nn.Flatten()])
        self.linear = nn.Linear(512, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 512 flattened features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class ResNet50_IMDB(nn.Module):
    """ResNet-50 feature extractor for RGB images, with a scalar head.

    Keeps every child of ``resnet50()`` except the last (the ``fc``
    classifier) and appends a ``Flatten``, giving the 2048 features that
    ``self.linear`` maps to one output.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone_children = list(resnet50().children())
        self.feature_layers = nn.ModuleList(backbone_children[:-1] + [nn.Flatten()])
        self.linear = nn.Linear(2048, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 2048 flattened features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result

class MobileNetv2_IMDB(nn.Module):
    """MobileNetV2 feature extractor for RGB images, with a scalar head.

    Uses the ``features`` stage of ``mobilenet_v2()`` unchanged, followed by
    a 2x2 average pool and a ``Flatten``; ``self.linear`` maps the resulting
    1280 features to one output.

    NOTE(review): ``Linear(1280, 1)`` only matches when the pooled map is
    1x1, i.e. when the backbone emits a 2x2 or 3x3 map — confirm the
    expected input resolution.
    """

    def __init__(self):
        super().__init__()
        self.HSIC = 0.
        self.data_ratio = 0.
        self.device = "cpu"
        backbone = mobilenet_v2()
        trunk = backbone.features
        trunk.append(nn.AvgPool2d(2))
        trunk.append(nn.Flatten())
        self.feature_layers = trunk
        self.linear = nn.Linear(1280, 1)

    def forward(self, input, feature = False):
        """Return ``{'feature': ..., 'output': ...}``.

        ``'output'`` is ``self.linear`` applied to the 1280 pooled features;
        ``'feature'`` is included (first) only when ``feature`` is True.
        """
        hidden = input
        for module in self.feature_layers:
            hidden = module(hidden)
        result = {'feature': hidden} if feature else {}
        result['output'] = self.linear(hidden)
        return result
