import torch
from torch import nn
from torch.nn import Linear

def count_pars(net):
    """Return the total number of scalar entries in the trainable parameters of *net*.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in net.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class MLP(torch.nn.Module):

    """
    Fully connected feed-forward network.

    The network is a stack of linear layers of widths
    ``input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1] -> output_dim``.
    A ReLU follows every linear layer except the last one, and
    ``final_nonlinearity`` is applied to the output of the last linear layer.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Sequence of hidden layer widths. May be empty, in which
            case the network is a single linear layer.
        output_dim: Size of the last dimension of the output.
        final_nonlinearity: Callable/module applied to the output of the last
            linear layer. ``None`` (the default) means a fresh ``nn.Identity()``.
    """

    def __init__(self,
                 input_dim=64,
                 hidden_dims=(32, 32),
                 output_dim=1,
                 final_nonlinearity=None):

        super().__init__()

        # A single ReLU instance is reused at every position; it is stateless,
        # and keeping it as an attribute preserves the module layout.
        self.relu = nn.ReLU()

        widths = [input_dim, *hidden_dims, output_dim]
        n_linear = len(widths) - 1
        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            # No activation after the output layer.
            if idx < n_linear - 1:
                stack.append(self.relu)
        self._layers = nn.Sequential(*stack)

        # Create the default per instance rather than sharing one module
        # object across every MLP built with default arguments.
        if final_nonlinearity is None:
            final_nonlinearity = nn.Identity()
        self._final_nl = final_nonlinearity

    def forward(self, x):
        """Apply the layer stack followed by the final nonlinearity to *x*."""
        return self._final_nl(self._layers(x))

class CNN(torch.nn.Module):

    """
    Convolutional neural network for images/data on grids.

    The network applies two ``nn.Conv2d`` layers back to back (no activation
    or pooling in between), flattens everything except the batch dimension,
    and finishes with a linear layer followed by ``tanh``.

    The linear layer is sized for square ``N x N`` inputs: each convolution
    uses the ``nn.Conv2d`` defaults for stride and padding, so a kernel of
    size ``k`` reduces a side of length ``n`` to ``n - k + 1``.

    Args:
        N: Side length of the (square) input grid used to size ``fc1``.
        n_channels: Number of input channels of the first convolution.
        hidden_layer_channels: Number of output channels of the first
            convolution.
        conv_kernel_size: Kernel size of the first convolution.
        pool_kernel_size: Kernel size of the *second convolution*. No pooling
            layer is used; the name is kept for interface compatibility.
        final_ff: Sequence whose first entry is the output width of the
            final linear layer. Further entries are currently unused.
        device: Device on which the convolution and linear layers are created.
    """

    def __init__(self,
                 N=100,
                 n_channels=1,
                 hidden_layer_channels=3,
                 conv_kernel_size=4,
                 pool_kernel_size=2,
                 final_ff=(32, 16),
                 device="cpu"):

        super().__init__()

        # Side length of the feature map after each convolution.
        conv1_out_dim = N - conv_kernel_size + 1
        conv2_out_dim = conv1_out_dim - pool_kernel_size + 1

        self.conv1 = nn.Conv2d(n_channels,
                               hidden_layer_channels,
                               conv_kernel_size, device=device)
        # The second convolution collapses to one channel, so the flattened
        # feature vector has conv2_out_dim ** 2 entries.
        self.conv2 = nn.Conv2d(hidden_layer_channels,
                               1,
                               pool_kernel_size, device=device)
        self.fc1 = nn.Linear(conv2_out_dim ** 2,
                             final_ff[0], device=device)
        # `relu` is not used in forward(); it is kept so the module's
        # attributes stay the same.
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Run both convolutions, flatten, and apply ``tanh(fc1(.))``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        return self.tanh(self.fc1(x))

class RNN(nn.Module):

    """
    Recurrent layer (GRU or Elman RNN) whose output sequence is passed
    through a final module.

    ``flavour == 'gru'`` selects ``nn.GRU``; any other value selects
    ``nn.RNN`` with the given ``nonlinearity``. Both are built with
    ``batch_first=True``. The trainable parameter counts of the recurrent
    layer and of ``final_ff`` are stored in ``_rnn_n_pars`` and
    ``_fff_n_pars``.
    """

    def __init__(self,
                 input_size=3,
                 hidden_size=32,
                 num_layers=1,
                 final_ff=nn.Identity(),
                 nonlinearity='tanh',
                 flavour='gru'):

        super().__init__()

        self.hdim = hidden_size

        wants_gru = flavour == 'gru'
        if not wants_gru:
            # `nonlinearity` only applies to the Elman variant.
            recurrent = nn.RNN(input_size, hidden_size, num_layers,
                               nonlinearity=nonlinearity,
                               batch_first=True)
        else:
            recurrent = nn.GRU(input_size, hidden_size, num_layers,
                               batch_first=True)
        self._rnn = recurrent
        self._fff = final_ff

        # Book-keeping of trainable parameter counts for both parts.
        self._rnn_n_pars = count_pars(recurrent)
        self._fff_n_pars = count_pars(final_ff)

    def forward(self, x, h=None):
        """
        Run the recurrent layer on *x* (optionally starting from hidden
        state *h*) and apply the final module to its output sequence.

        The final hidden state returned by the recurrent layer is discarded.
        """
        # Passing h=None is the same as omitting the initial hidden state.
        sequence_out, _unused_hidden = self._rnn(x, h)
        return self._fff(sequence_out)
