import torch
import torch.nn as nn
import torch.nn.functional as F

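'''
Model classes: MLPs for power indices (Banzhaf, Shapley) and the least core,
plus a multinomial logistic regression.
'''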

class MLPPowerIndex(nn.Module):
    '''
    MLP architecture used for Banzhaf and Shapley values.

    Stacks ``n_layers - 1`` hidden blocks of Linear -> ReLU -> Dropout,
    followed by a Linear output layer and a softmax, so each predicted
    payoff vector is non-negative and sums to 1.

    Args:
        input_size: number of input features (size of the last input dim).
        hidden_size: width of every hidden Linear layer.
        output_size: number of payoffs predicted per sample.
        n_layers: total number of Linear layers, including the output layer.
            ``n_layers <= 1`` yields only the output layer.
        drop_prob: dropout probability applied after each hidden ReLU.
    '''
    def __init__(self, input_size, hidden_size, output_size, n_layers, drop_prob):
        super().__init__()
        layers = []
        for _ in range(n_layers - 1):
            layers += [
                nn.Linear(input_size, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(drop_prob),
            ]
            # Subsequent layers consume the previous hidden width.
            input_size = hidden_size

        # Output layer. nn.Linear writes features to the last dim, so the
        # payoffs are normalized over dim=-1. This also accepts unbatched
        # 1-D input, which dim=1 rejected.
        layers += [
            nn.Linear(input_size, output_size),
            nn.Softmax(dim=-1),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        '''Map ``x`` of shape ``(*, input_size)`` to payoffs of shape ``(*, output_size)``.'''
        return self.layers(x)


class MLPLeastcore(nn.Module):
    '''
    Two-hidden-layer MLP used for the least core.

    A shared trunk (Linear -> ReLU -> Dropout -> Linear -> ReLU) feeds two
    heads: a payoff vector normalized by a softmax over the last dim, and a
    single epsilon value squashed into the unit interval by a sigmoid.

    Args:
        input_size: number of input features.
        hidden_size: width of both hidden Linear layers.
        out_size_payoffs: length of the predicted payoff vector.
        drop_prob: dropout probability, applied after the first hidden layer only.
    '''
    def __init__(self, input_size, hidden_size, out_size_payoffs, drop_prob):
        super().__init__()

        # Shared trunk. Creation order is kept so seeded initialization is unchanged.
        self.lin1 = nn.Linear(input_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, hidden_size)
        # Output heads: payoff logits and a scalar epsilon logit.
        self.out_payoffs = nn.Linear(hidden_size, out_size_payoffs)
        self.out_eps = nn.Linear(hidden_size, 1)
        # Parameter-free activations and regularizer.
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        '''Return ``(payoffs, eps)`` computed from the shared trunk features of ``x``.'''
        trunk = self.dropout(F.relu(self.lin1(x)))
        trunk = F.relu(self.lin2(trunk))
        payoffs = self.softmax(self.out_payoffs(trunk))
        eps = self.sigmoid(self.out_eps(trunk))
        return payoffs, eps


class MultinomialRegression(nn.Module):
    '''
    Multinomial logistic regression.

    A single affine map followed by a softmax over the last dim, producing
    one probability per output class.

    Args:
        input_size: number of input features.
        output_size: number of output classes.
    '''
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        '''Return class probabilities of shape ``(*, output_size)``.'''
        return self.softmax(self.linear(x))
