import pytorch_lightning as pl
from torch import optim, nn
import os
import torch
import numpy as np
from typing import List
import pdb

def init_weights(m):
    """Re-draw the weights of selected convolution layers with Xavier-uniform.

    Meant to be passed to ``nn.Module.apply``. Only ``nn.Conv1d`` and
    ``nn.ConvTranspose2d`` modules are touched; their ``weight`` is
    re-initialised with ``xavier_uniform_`` using a very small gain (0.001).
    Every other module type (including ``nn.Conv2d`` and ``nn.Linear``) is
    left as-is, and biases are never modified.
    """
    targeted_types = (nn.Conv1d, nn.ConvTranspose2d)
    if not isinstance(m, targeted_types):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=0.001)
class EncoderDecoderModule(pl.LightningModule):
    """Lightning wrapper that trains an arbitrary ``net`` with a masked MSE loss.

    Each batch is an ``(x, y)`` pair. ``net(x)`` is compared with ``y`` only
    at positions where ``y`` is finite, so NaN and +/-inf targets are ignored.
    ``net(x)`` must therefore have the same shape as ``y``.

    NOTE(review): this class does not define ``configure_optimizers`` (an
    earlier version returned ``optim.Adam(self.parameters(),
    lr=self.learning_rate)``). The Lightning Trainer needs one, so it is
    presumably supplied by a subclass or by the caller — confirm.

    Args:
        net: the model being trained. On construction, ``init_weights`` is
            applied to every submodule, including those inside ``net``.
        batch_size: stored on ``self.batch_size``; not read in this class.
        learning_rate: stored on ``self.learning_rate``; not read in this class.
    """

    def __init__(self,
                 net: nn.Module,
                 batch_size=1000,
                 learning_rate=1e-3,
                 ):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.net = net
        # nn.Module.apply visits submodules recursively, so this reaches the
        # layers inside ``net`` as well.
        self.apply(init_weights)

    def forward(self, x):
        """Return the wrapped network's prediction ``net(x)``."""
        return self.net(x)

    def _masked_loss(self, batch):
        """Return the MSE between ``net(x)`` and ``y``, restricted to finite ``y``.

        NOTE(review): if every target in the batch is non-finite, the selection
        is empty and the resulting mean is presumably NaN — confirm handling.
        """
        x, y = batch
        yhat = self(x)
        # isfinite is False for NaN and for +/-inf, so both are excluded.
        finite = torch.isfinite(y)
        return nn.functional.mse_loss(yhat[finite], y[finite])

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the masked training loss."""
        loss = self._masked_loss(batch)
        # Sent to the Trainer's logger (TensorBoard by default).
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log (``val_loss``, shown in the progress bar) the masked loss."""
        loss = self._masked_loss(val_batch)
        self.log("val_loss", loss, prog_bar=True)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """No-op epoch-end hook; nothing is aggregated yet.

        NOTE(review): this is a PyTorch Lightning 1.x hook. Lightning 2.x
        removed it and rejects modules that define it, so drop it when
        upgrading.
        """
        return None

    def test_step(self, batch, batch_idx):
        """Compute and log (``test_loss``) the masked loss on a test batch."""
        loss = self._masked_loss(batch)
        self.log("test_loss", loss)
        
class ConvEncoderDecoder(nn.Module):
    """1-D convolutional encoder feeding a 2-D convolutional decoder.

    Encoder: a stack of ``Conv1d(k=3, p=1) -> BatchNorm1d -> LeakyReLU``
    blocks applied to the first ``in_length`` positions of the last axis of
    ``x``. The output is flattened and linearly projected to ``latent_dim``.

    Decoder: the latent vector is linearly projected and reshaped to a
    ``(d_in_channels, d_in_size, d_in_size)`` feature map, then refined by
    ``conv -> BatchNorm2d -> LeakyReLU`` blocks. Stride-2 blocks use a
    ``ConvTranspose2d`` that doubles the spatial size; any other stride uses
    a ``Conv2d`` with that stride. A final ``Conv2d`` maps to
    ``out_channels``.

    Args:
        e_hidden_dims: output channels of each encoder block.
        e_strides: stride of each encoder block. It is paired with
            ``e_hidden_dims`` via ``zip``, so the shorter list decides the
            depth.
        d_in_channels: channels of the decoder's initial feature map.
        d_in_size: height/width of the decoder's initial feature map.
        d_hidden_dims: output channels of each decoder block.
        d_strides: stride of each decoder block (paired via ``zip``).
        latent_dim: size of the latent code.
        output_kernel_size: kernel size of the final output ``Conv2d``.
            Padding is ``output_kernel_size // 2``.
        in_channels: number of input channels (size of dim 1 of ``x``).
        in_length: number of leading positions of ``x``'s last axis consumed
            by the encoder.
        out_channels: number of channels in the decoded output.
    """

    def __init__(
        self,
        e_hidden_dims: List[int] = [24, 24, 12, 12],
        e_strides: List[int] = [1, 2, 1, 2],
        d_in_channels: int = 16,
        d_in_size: int = 5,
        d_hidden_dims: List[int] = [16, 8, 4],
        d_strides: List[int] = [2, 2, 1],
        latent_dim=200,
        output_kernel_size=1,
        in_channels: int = 11,
        in_length: int = 200,
        out_channels: int = 2,
    ):
        super().__init__()
        self.d_in_channels = d_in_channels
        self.d_in_size = d_in_size
        self.in_length = in_length

        # Build encoder. Layers are created and registered in the same order
        # as before, so seeded initialisation and parameter order (which
        # optimizer state depends on) are unchanged.
        modules = []
        channels = in_channels
        length = in_length
        for h_dim, stride in zip(e_hidden_dims, e_strides):
            modules.append(
                nn.Sequential(
                    nn.Conv1d(channels, out_channels=h_dim,
                              kernel_size=3, stride=stride, padding=1),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU())
            )
            channels = h_dim
            # Conv1d with k=3, p=1 yields ceil(length / stride) positions.
            length = (length + stride - 1) // stride

        # Size from the channels the loop actually produced. This equals
        # e_hidden_dims[-1] when both lists have the same length, and stays
        # correct when zip() truncates a mismatched pair.
        self.encoder_linear = nn.Linear(channels * length, latent_dim)
        self.encoder = nn.Sequential(*modules)

        # Build decoder: latent -> (d_in_channels, d_in_size, d_in_size) -> convs.
        self.decoder_linear = nn.Linear(latent_dim, d_in_channels * d_in_size * d_in_size)
        modules = []
        channels = d_in_channels
        for h_dim, stride in zip(d_hidden_dims, d_strides):
            if stride == 2:
                # k=3, s=2, p=1, output_padding=1 gives exact 2x upsampling:
                # (H - 1) * 2 - 2 * 1 + 3 + 1 = 2H.
                conv_layer = nn.ConvTranspose2d(
                    channels,
                    h_dim,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    output_padding=stride // 2,
                )
            else:
                conv_layer = nn.Conv2d(
                    channels,
                    h_dim,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                )
            modules.append(
                nn.Sequential(
                    conv_layer,
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            channels = h_dim

        # Use the running channel count so an empty d_hidden_dims still works.
        modules.append(nn.Conv2d(channels, out_channels, kernel_size=output_kernel_size,
                                 stride=1, padding=output_kernel_size // 2))
        self.decoder = nn.Sequential(*modules)

    def encode(self, x):
        """Encode the first ``in_length`` positions of ``x``'s last axis to a latent vector.

        Returns a ``(batch, latent_dim)`` tensor.
        """
        bs = x.size(0)
        features = self.encoder(x[..., :self.in_length])
        # reshape, not view: with an empty encoder the slice may be non-contiguous.
        return self.encoder_linear(features.reshape(bs, -1))

    def decode(self, z):
        """Decode a ``(batch, latent_dim)`` latent tensor to an image-like output."""
        bs = z.size(0)
        grid = self.decoder_linear(z).view(bs, self.d_in_channels, self.d_in_size, self.d_in_size)
        return self.decoder(grid)

    def forward(self, x):
        """Encode ``x`` and decode the latent code back to the output grid."""
        return self.decode(self.encode(x))
    

class ConvEncoderLinDecoder(nn.Module):
    """1-D convolutional encoder with a fully connected decoder.

    Encoder: ``Conv1d(k=3, p=1) -> BatchNorm1d -> LeakyReLU`` blocks applied
    to the first ``in_length`` positions of the last axis of ``x``, flattened
    and linearly projected to ``latent_dim``.

    Decoder: bias-free ``Linear -> LeakyReLU`` blocks of widths
    ``d_hidden_dims``, then a bias-free ``Linear`` whose output is reshaped
    to ``(out_channels, out_size, out_size)``.

    Args:
        e_hidden_dims: output channels of each encoder block.
        e_strides: stride of each encoder block. It is paired with
            ``e_hidden_dims`` via ``zip``, so the shorter list decides the
            depth.
        d_hidden_dims: widths of the decoder's hidden layers (may be empty).
        latent_dim: size of the latent code.
        batch_size: stored on ``self.batch_size``; not read in this class.
        learning_rate: stored on ``self.learning_rate``; not read in this class.
        in_channels: number of input channels (size of dim 1 of ``x``).
        in_length: number of leading positions of ``x``'s last axis consumed
            by the encoder.
        out_channels: number of channels in the decoded output.
        out_size: height/width of each decoded output channel.
    """

    def __init__(self,
                 e_hidden_dims: List[int] = [24, 24, 12, 12],
                 e_strides: List[int] = [1, 2, 1, 2],
                 d_hidden_dims: List[int] = [300, 500, 500],
                 latent_dim=300,
                 batch_size=1000,
                 learning_rate=1e-3,
                 in_channels: int = 11,
                 in_length: int = 200,
                 out_channels: int = 2,
                 out_size: int = 20,
                 ):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.e_hidden_dims = e_hidden_dims
        self.e_strides = e_strides
        self.d_hidden_dims = d_hidden_dims
        self.in_length = in_length
        self.out_channels = out_channels
        self.out_size = out_size

        # Build encoder. Construction and registration order are unchanged, so
        # seeded initialisation and parameter order stay the same.
        modules = []
        channels = in_channels
        length = in_length
        for h_dim, stride in zip(e_hidden_dims, e_strides):
            modules.append(
                nn.Sequential(
                    nn.Conv1d(channels, out_channels=h_dim,
                              kernel_size=3, stride=stride, padding=1),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU())
            )
            channels = h_dim
            # Conv1d with k=3, p=1 yields ceil(length / stride) positions.
            length = (length + stride - 1) // stride

        # Size from the channels the loop actually produced. This equals
        # e_hidden_dims[-1] when both lists have the same length, and stays
        # correct when zip() truncates a mismatched pair.
        self.encoder_linear = nn.Linear(channels * length, latent_dim)
        self.encoder = nn.Sequential(*modules)

        # Build decoder: latent -> hidden widths -> flattened output grid.
        modules = []
        in_dim = latent_dim
        for h_dim in self.d_hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_dim, h_dim, bias=False),
                    nn.LeakyReLU(),
                )
            )
            in_dim = h_dim

        modules.append(nn.Linear(in_dim, out_channels * out_size * out_size, bias=False))
        self.decoder = nn.Sequential(*modules)

    def encode(self, x):
        """Encode the first ``in_length`` positions of ``x``'s last axis to a latent vector.

        Returns a ``(batch, latent_dim)`` tensor.
        """
        bs = x.size(0)
        features = self.encoder(x[..., :self.in_length])
        # reshape, not view: with an empty encoder the slice may be non-contiguous.
        return self.encoder_linear(features.reshape(bs, -1))

    def decode(self, z):
        """Decode a latent tensor to ``(batch, out_channels, out_size, out_size)``."""
        bs = z.size(0)
        return self.decoder(z).view(bs, self.out_channels, self.out_size, self.out_size)

    def forward(self, x):
        """Encode ``x`` and decode the latent code to the output grid."""
        return self.decode(self.encode(x))
    
class MLP(pl.LightningModule):
    """Fully connected regressor trained with a masked MSE loss.

    Each sample of ``x`` is flattened to a vector of length ``in_size`` and
    passed through ``Linear -> LeakyReLU`` blocks of widths ``hidden_dims``.
    A final ``Linear`` produces an output reshaped to
    ``(num_at, out_size, out_size)``. The loss ignores target positions that
    are NaN or +/-inf.

    NOTE(review): this class does not define ``configure_optimizers`` (an
    earlier version returned ``optim.Adam(self.parameters(),
    lr=self.learning_rate)``). The Lightning Trainer needs one, so it is
    presumably supplied by a subclass or by the caller — confirm.

    Args:
        learning_rate: stored on ``self.learning_rate``; not read in this class.
        batch_size: stored on ``self.batch_size``; not read in this class.
        hidden_dims: widths of the hidden layers. It may be empty, in which
            case the model is a single linear layer.
        in_size: number of features per sample after flattening.
        num_at: number of output channels.
        out_size: height/width of each output channel.
    """

    def __init__(
        self,
        learning_rate=5e-5,
        batch_size=16,
        hidden_dims: List[int] = [256, 256, 256, 256],
        in_size: int = 231 * 11,
        num_at: int = 2,
        out_size: int = 20,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.hidden_dims = hidden_dims
        self.in_size = in_size
        self.num_at = num_at
        self.out_size = out_size

        layers = []
        cur_dim = self.in_size
        for h in hidden_dims:
            layers.append(nn.Sequential(nn.Linear(cur_dim, h), nn.LeakyReLU()))
            cur_dim = h
        # Use the running width rather than hidden_dims[-1], so an empty
        # hidden_dims gives a plain linear model instead of an IndexError.
        layers.append(nn.Linear(cur_dim, self.num_at * out_size * out_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Predict a ``(batch, num_at, out_size, out_size)`` tensor from ``x``."""
        bs = x.size(0)
        # reshape, not view, so non-contiguous inputs are accepted too.
        flat = x.reshape(bs, -1)
        return self.model(flat).view(bs, self.num_at, self.out_size, self.out_size)

    def _masked_loss(self, batch):
        """Return the MSE between the prediction and ``y``, restricted to finite ``y``.

        NOTE(review): if every target in the batch is non-finite, the selection
        is empty and the resulting mean is presumably NaN — confirm handling.
        """
        x, y = batch
        yhat = self(x)
        # isfinite is False for NaN and for +/-inf, so both are excluded.
        finite = torch.isfinite(y)
        return nn.functional.mse_loss(yhat[finite], y[finite])

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the masked training loss."""
        loss = self._masked_loss(batch)
        # Sent to the Trainer's logger (TensorBoard by default).
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log (``val_loss``, shown in the progress bar) the masked loss."""
        loss = self._masked_loss(val_batch)
        self.log("val_loss", loss, prog_bar=True)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """No-op epoch-end hook; nothing is aggregated yet.

        NOTE(review): this is a PyTorch Lightning 1.x hook. Lightning 2.x
        removed it and rejects modules that define it, so drop it when
        upgrading.
        """
        return None

    def test_step(self, batch, batch_idx):
        """Compute and log (``test_loss``) the masked loss on a test batch."""
        loss = self._masked_loss(batch)
        self.log("test_loss", loss)