from models import register_model
import torch.nn as nn
import torch
from models.base_model import BaseModel
from models.transformer_encoder_input import TransformerEncoderInput
import torch.nn.functional as F

@register_model("cnn_wav_model")
class CNNWavModel(BaseModel):
    """1-D residual CNN mapping a fixed-length waveform window to one value.

    The network is a stack of up to five stride-2 ``Conv1d`` layers (each
    followed by ``BatchNorm1d`` and ReLU). From the third downsampling step
    on, every stride-2 layer is preceded by a two-convolution residual block.
    The resulting feature map is flattened and passed through two linear
    layers with dropout in between.

    Layers are created in ``build_model``; ``__init__`` only initialises the
    base class.
    """

    def __init__(self):
        super(CNNWavModel, self).__init__()

    @staticmethod
    def _downsampled_length(length, n_convs):
        """Return the temporal length after ``n_convs`` stride-2 convolutions.

        Every downsampling convolution in this model uses
        ``padding == (kernel_size - 1) // 2`` and ``stride == 2``, for which
        the Conv1d output length is ``floor((L - 1) / 2) + 1``.
        """
        for _ in range(n_convs):
            length = (length - 1) // 2 + 1
        return length

    def build_model(self, cfg, input_dim=None):
        """Create all layers of the network.

        Args:
            cfg: Configuration object; stored on ``self.cfg`` but not
                otherwise read here (all hyper-parameters are hardcoded).
            input_dim: Accepted for interface compatibility and ignored; the
                model always uses a single input channel.

        Raises:
            ValueError: If ``n_downsample_convs`` is outside ``1..5``.
        """
        self.cfg = cfg
        self.n_downsample_convs = 5  # NOTE hardcode params.n_downsample_convs
        self.output_dim = 1
        self.dropout_p = 0.1  # NOTE hardcode
        self.input_dim = 1
        if not 1 <= self.n_downsample_convs <= 5:
            raise ValueError(
                "n_downsample_convs must be between 1 and 5, got "
                f"{self.n_downsample_convs}"
            )

        # NOTE: every layer is instantiated regardless of n_downsample_convs,
        # and in this exact order, so parameter names and seeded
        # initialisation stay stable across depths.

        # Downsample 1 and 2: plain strided convolutions.
        self.conv11_down2 = nn.Conv1d(in_channels=self.input_dim, out_channels=128, kernel_size=5, padding=2, stride=2)
        self.bn11 = nn.BatchNorm1d(128)

        self.conv12_down2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=2)
        self.bn12 = nn.BatchNorm1d(256)

        # Residual block at 256 channels, then downsample 3.
        self.conv21 = nn.Conv1d(in_channels=256, out_channels=256, kernel_size=3, padding=1, stride=1)
        self.bn21 = nn.BatchNorm1d(256)
        self.conv22 = nn.Conv1d(in_channels=256, out_channels=256, kernel_size=3, padding=1, stride=1)
        self.bn22 = nn.BatchNorm1d(256)
        self.conv3_down2 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=3, padding=1, stride=2)
        self.bn3 = nn.BatchNorm1d(512)

        # Residual block at 512 channels, then downsample 4.
        self.conv41 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1)
        self.bn41 = nn.BatchNorm1d(512)
        self.conv42 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1)
        self.bn42 = nn.BatchNorm1d(512)

        self.conv5_down2 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=2)
        self.bn5 = nn.BatchNorm1d(512)

        # Residual block at 512 channels, then downsample 5.
        self.conv61 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1)
        self.bn61 = nn.BatchNorm1d(512)
        self.conv62 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1)
        self.bn62 = nn.BatchNorm1d(512)

        self.conv7_down2 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=2)
        self.bn7 = nn.BatchNorm1d(512)

        # Channel count of the feature map after the k-th downsampling conv.
        in_channels_map = {1: 128, 2: 256, 3: 512, 4: 512, 5: 512}
        in_c = in_channels_map[self.n_downsample_convs]

        # NOTE: hardcode assuming 512 input samples at the top. The flattened
        # length is derived from the depth so that fc1 matches the feature
        # map for every supported n_downsample_convs (16 when depth is 5).
        expected_input_len = 512
        flat_len = self._downsampled_length(expected_input_len, self.n_downsample_convs)
        self.fc1 = nn.Linear(in_features=in_c * flat_len, out_features=128)

        self.linear_out = nn.Linear(in_features=128, out_features=self.output_dim)
        self.dropout = nn.Dropout(p=self.dropout_p)

    def final_layer(self, inp):
        """Flatten a ``[batch, channels, len]`` feature map and project it.

        Applies ``fc1``, dropout, then ``linear_out``; there is no
        nonlinearity between the two linear layers.

        Returns:
            Tensor of shape ``[batch, output_dim]``.
        """
        flat = inp.view(inp.size(0), -1)
        hidden = self.dropout(self.fc1(flat))
        return self.linear_out(hidden)

    def forward(self, inp):
        """Run the network on a batch of waveforms.

        Args:
            inp: Tensor of shape ``[batch, len]``. ``fc1`` is sized for
                ``len == 512``; other lengths fail at the linear layer.

        Returns:
            Tensor of shape ``[batch, output_dim]``.
        """
        depth = self.n_downsample_convs

        x = inp.unsqueeze(1)  # [batch, channels, len]
        x = F.relu(self.bn11(self.conv11_down2(x)))
        if depth <= 1:
            return self.final_layer(x)

        x = F.relu(self.bn12(self.conv12_down2(x)))
        if depth <= 2:
            return self.final_layer(x)

        # Residual block (skip added after the second ReLU), then downsample.
        res = x
        x = F.relu(self.bn21(self.conv21(x)))
        x = F.relu(self.bn22(self.conv22(x)))
        x = x + res
        x = F.relu(self.bn3(self.conv3_down2(x)))
        if depth <= 3:
            return self.final_layer(x)

        res = x
        x = F.relu(self.bn41(self.conv41(x)))
        x = F.relu(self.bn42(self.conv42(x)))
        x = x + res
        x = F.relu(self.bn5(self.conv5_down2(x)))
        if depth <= 4:
            return self.final_layer(x)

        res = x
        x = F.relu(self.bn61(self.conv61(x)))
        x = F.relu(self.bn62(self.conv62(x)))
        x = x + res
        x = F.relu(self.bn7(self.conv7_down2(x)))

        return self.final_layer(x)
