import os
import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config

from tqdm import tqdm

import warnings
import math

from src.components.transformer import TransformerCLF
from src.components.blocks.mamba_encoder import MambaModel
from src.components.blocks.dss_encoder import DSS
from src.components.blocks.retnet_encoder import Retnet
from src.components.transformer_utils import AbsolutePositionalEncoding


import pdb




def build_clf(conf):
	"""Instantiate the sequence classifier selected by ``conf.model``.

	Supported values of ``conf.model``:

	* ``"san"``            -> ``TransformerModelCLF`` (also reads ``conf.n_positions``)
	* ``"mysan"``          -> ``TransformerCLF``
	* ``"mamba"``          -> ``MambaCLF``
	* ``"retnet"``         -> ``RetnetCLF``
	* ``"dss"``            -> ``DSSCLF``
	* ``"lstm"`` / ``"gru"`` -> ``LSTM`` with ``rnn_type=conf.model``

	Every model reads ``conf.n_words``, ``conf.n_embd``, ``conf.n_layer`` and
	``conf.n_out``; ``san``, ``mysan`` and ``retnet`` additionally read
	``conf.n_head``.

	Args:
		conf: any object exposing the attributes listed above.

	Returns:
		The constructed model instance.

	Raises:
		NotImplementedError: if ``conf.model`` is not one of the names above.
	"""
	model_name = conf.model

	if model_name == "san":
		return TransformerModelCLF(
			n_words=conf.n_words,
			n_positions=conf.n_positions,
			n_embd=conf.n_embd,
			n_layer=conf.n_layer,
			n_head=conf.n_head,
			n_out=conf.n_out,
		)

	if model_name == "mysan":
		return TransformerCLF(
			n_words=conf.n_words,
			d_model=conf.n_embd,
			n_layer=conf.n_layer,
			n_head=conf.n_head,
			n_out=conf.n_out,
		)

	if model_name == "mamba":
		return MambaCLF(
			n_words=conf.n_words,
			d_model=conf.n_embd,
			n_layer=conf.n_layer,
			n_out=conf.n_out,
		)

	if model_name == "retnet":
		return RetnetCLF(
			n_words=conf.n_words,
			d_model=conf.n_embd,
			n_layer=conf.n_layer,
			n_head=conf.n_head,
			n_out=conf.n_out,
		)

	if model_name == "dss":
		return DSSCLF(
			n_words=conf.n_words,
			d_model=conf.n_embd,
			n_layer=conf.n_layer,
			n_out=conf.n_out,
		)

	if model_name in ("lstm", "gru"):
		return LSTM(
			n_words=conf.n_words,
			n_embd=conf.n_embd,
			n_layer=conf.n_layer,
			rnn_type=model_name,
			n_out=conf.n_out,
		)

	# Name the offending value so a typo in the config is obvious.
	raise NotImplementedError(
		f"Unknown model {model_name!r}; expected one of "
		"'san', 'mysan', 'mamba', 'retnet', 'dss', 'lstm', 'gru'"
	)





class TransformerModelCLF(nn.Module):
	"""Classifier on top of a Hugging Face ``GPT2Model`` backbone.

	Token ids go through a local ``nn.Embedding`` and are handed to the backbone
	as ``inputs_embeds``; the linear read-out of the final position is the
	prediction. Residual, embedding and attention dropout are all set to 0.0 and
	``use_cache`` is disabled in the backbone config.

	Args:
		n_words: vocabulary size of the local token embedding.
		n_positions: base sequence length; the backbone's position table is
			sized to ``2 * n_positions``.
		n_embd: hidden size shared by the embedding, backbone and read-out.
		n_layer: number of GPT-2 blocks.
		n_head: number of attention heads per block.
		n_out: output size of the linear read-out.
	"""

	def __init__(self, n_words, n_positions, n_embd=128, n_layer=12, n_head=4, n_out = 1):
		super().__init__()

		self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"
		self.n_positions = n_positions
		self.n_words = n_words
		self.n_embd = n_embd

		backbone_config = GPT2Config(
			n_positions=2 * n_positions,
			n_embd=n_embd,
			n_layer=n_layer,
			n_head=n_head,
			resid_pdrop=0.0,
			embd_pdrop=0.0,
			attn_pdrop=0.0,
			use_cache=False,
		)

		self.encoder = nn.Embedding(n_words, n_embd)
		# NOTE(review): embeddings are supplied via ``inputs_embeds``, so the
		# backbone's own token-embedding table appears to be unused, yet it is
		# still registered (and counted below) — confirm this is intended.
		self._backbone = GPT2Model(backbone_config)
		self._read_out = nn.Linear(n_embd, n_out)

		print('Transformer Normal Training')

		# Counts every registered parameter; there is no requires_grad filter
		# despite the "trainable" wording of the message.
		n_params = 0
		for tensor in self.parameters():
			n_params += tensor.numel()
		print(f'Total trainable parameters: {n_params:,}')

	def forward(self, x):
		"""Return the read-out of the last position.

		Args:
			x: token ids, (batch_size, length).

		Returns:
			(batch_size,) when ``n_out == 1``, otherwise (batch_size, n_out).
		"""
		hidden_states = self._backbone(inputs_embeds=self.encoder(x)).last_hidden_state
		logits = self._read_out(hidden_states)
		return logits[:, -1, 0] if logits.size(-1) == 1 else logits[:, -1]






class LSTM(nn.Module):
	"""Recurrent (LSTM or GRU) classifier that predicts from the last time step.

	Token ids are embedded, optionally passed through
	``AbsolutePositionalEncoding``, run through a stacked ``nn.LSTM`` or
	``nn.GRU`` whose hidden size equals ``n_embd``, and the linear read-out of
	the final position is returned.

	Args:
		n_words: vocabulary size of the input embedding.
		n_embd: embedding size, also used as the recurrent hidden size.
		n_layer: number of stacked recurrent layers.
		rnn_type: ``'lstm'`` or ``'gru'`` (compared case-insensitively).
		n_out: output size of the linear read-out.
		pos_encode: whether to apply ``AbsolutePositionalEncoding`` to the embeddings.

	Raises:
		ValueError: if ``rnn_type`` is neither ``'lstm'`` nor ``'gru'``.
	"""

	def __init__(self, n_words, n_embd=128, n_layer=2, rnn_type= 'lstm', n_out =1, pos_encode = True):
		super(LSTM, self).__init__()

		# Fail fast: an unsupported rnn_type used to leave ``self._backbone``
		# unset, so the mistake only surfaced as an AttributeError in forward().
		rnn_kind = rnn_type.lower()
		if rnn_kind not in ('lstm', 'gru'):
			raise ValueError(f"Unsupported rnn_type {rnn_type!r}; expected 'lstm' or 'gru'")

		self.name = f"embd={n_embd}_layer={n_layer}"

		self.drop = nn.Dropout(0.0)  # p=0.0, so currently an identity op
		self.rnn_type = rnn_type
		self.n_words = n_words
		self.n_embd = n_embd
		self.n_layer = n_layer

		self.pos_encode = pos_encode

		if pos_encode:
			print('Using Positional Encoding')
			self.pos_encoder = AbsolutePositionalEncoding(n_embd)

		# Default (seq_len, batch, feature) layout — forward() transposes around it.
		if rnn_kind == 'lstm':
			self._backbone = nn.LSTM(n_embd, n_embd, n_layer)
		else:
			self._backbone = nn.GRU(n_embd, n_embd, n_layer)

		self.encoder = nn.Embedding(n_words, n_embd)
		self._read_out = nn.Linear(n_embd, n_out)

		print('All {} parameters are tunable'.format(self.rnn_type))

		# Counts every parameter (no requires_grad filter), despite the label.
		total_params = sum(p.numel() for p in self.parameters())
		print(f'Total trainable parameters: {total_params:,}')

	def forward(self, x,  hidden=None):
		"""Return the read-out of the last position.

		Args:
			x: token ids, (batch_size, length).
			hidden: optional initial recurrent state; zeros from ``init_hidden``
				when omitted. For LSTM this is an ``(h_0, c_0)`` tuple.

		Returns:
			(batch_size,) when ``n_out == 1``, otherwise (batch_size, n_out).
		"""
		if hidden is None:
			hidden = self.init_hidden(x.size(0))

		embeds = self.encoder(x)
		if self.pos_encode:
			embeds = self.pos_encoder(embeds)  # (batch_size, seq_len, d_model)

		# The recurrent module is not batch_first: (B, L, E) -> (L, B, E).
		output, hidden = self._backbone(embeds.transpose(0, 1), hidden)
		output = self.drop(output).transpose(0, 1)  # back to (B, L, E)
		prediction = self._read_out(output)

		if prediction.size(-1) == 1:
			return prediction[:, -1, 0]
		return prediction[:, -1]

	def init_hidden(self, bsz):
		"""Zero initial state of shape (n_layer, bsz, n_embd).

		Returns an ``(h_0, c_0)`` tuple for LSTM and a single tensor for GRU.
		The dtype and device are taken from the module's first parameter.
		"""
		weight = next(self.parameters())
		if self.rnn_type.lower() == 'lstm':
			return (weight.new_zeros(self.n_layer, bsz, self.n_embd),
					weight.new_zeros(self.n_layer, bsz, self.n_embd))
		else:
			return weight.new_zeros(self.n_layer, bsz, self.n_embd)






class DSSCLF(nn.Module):
	"""Last-position classifier built on the project's ``DSS`` encoder.

	Token ids are embedded, encoded by ``DSS`` (feed-forward width fixed at
	``4 * d_model``), projected by a linear read-out, and the final position's
	output is returned.

	Args:
		n_words: vocabulary size of the token embedding.
		d_model: width shared by the embedding, encoder and read-out.
		n_layer: number of layers passed to ``DSS`` as ``n_layers``.
		n_out: output size of the linear read-out.
		dropout: accepted for signature compatibility; not used by this class.
	"""

	def __init__(self, n_words, d_model, n_layer, n_out =1, dropout=0.0):
		super().__init__()
		self.model_type = 'DSS'
		self.name = f"DSS_model={d_model}_layer={n_layer}"
		self.d_model = d_model
		self.n_words = n_words

		self.encoder = nn.Embedding(n_words, d_model)
		self._backbone = DSS(d_model=d_model, n_layers=n_layer, d_ffn=4 * d_model)
		self._read_out = nn.Linear(d_model, n_out)

		print('DSS Normal Training: All parameters are tunable')

		# Sums all parameters (no requires_grad filter), whatever the label says.
		param_total = sum(map(lambda t: t.numel(), self.parameters()))
		print(f'Total trainable parameters: {param_total:,}')

	def forward(self, x):
		"""Return the read-out of the last position.

		Args:
			x: token ids, (batch_size, length).

		Returns:
			(batch_size,) when ``n_out == 1``, otherwise (batch_size, n_out).
			Assumes ``DSS`` returns a (batch, seq, d_model) tensor — TODO confirm.
		"""
		encoded = self._backbone(self.encoder(x))
		scores = self._read_out(encoded)
		return scores[:, -1, 0] if scores.size(-1) == 1 else scores[:, -1]
	





class MambaCLF(nn.Module):
	"""Last-position classifier built on the project's ``MambaModel`` encoder.

	Token ids are embedded, encoded by ``MambaModel``, projected by a linear
	read-out, and the final position's output is returned.

	Args:
		n_words: vocabulary size of the token embedding.
		d_model: width shared by the embedding, encoder and read-out.
		n_layer: number of layers passed to ``MambaModel`` as ``n_layers``.
		n_out: output size of the linear read-out.
		dropout: accepted for signature compatibility; not used by this class.
	"""

	def __init__(self, n_words, d_model, n_layer, n_out =1, dropout=0.0):
		super().__init__()
		self.model_type = 'Mamba'
		self.name = f"Mamba_model={d_model}_layer={n_layer}"
		self.d_model = d_model
		self.n_words = n_words

		self.encoder = nn.Embedding(n_words, d_model)
		self._backbone = MambaModel(d_model=d_model, n_layers=n_layer)
		self._read_out = nn.Linear(d_model, n_out)

		print('Mamba Normal Training: All parameters are tunable')

		# Sums all parameters (no requires_grad filter), whatever the label says.
		param_total = 0
		for tensor in self.parameters():
			param_total += tensor.numel()
		print(f'Total trainable parameters: {param_total:,}')

	def forward(self, x):
		"""Return the read-out of the last position.

		Args:
			x: token ids, (batch_size, length).

		Returns:
			(batch_size,) when ``n_out == 1``, otherwise (batch_size, n_out).
			Assumes ``MambaModel`` returns a (batch, seq, d_model) tensor — TODO confirm.
		"""
		encoded = self._backbone(self.encoder(x))
		scores = self._read_out(encoded)
		return scores[:, -1, 0] if scores.size(-1) == 1 else scores[:, -1]
	






class RetnetCLF(nn.Module):
	"""Last-position classifier built on the project's ``Retnet`` encoder.

	Token ids are embedded, scaled by ``sqrt(d_model)``, optionally passed
	through ``AbsolutePositionalEncoding``, encoded by ``Retnet`` (feed-forward
	width ``4 * d_model``), and the linear read-out of the final position is
	returned.

	Args:
		n_words: vocabulary size of the token embedding.
		d_model: width shared by the embedding, encoder and read-out.
		n_layer: number of layers passed to ``Retnet`` as ``n_layers``.
		n_head: number of heads passed to ``Retnet`` as ``num_heads``.
		n_out: output size of the linear read-out.
		use_decay: forwarded unchanged to ``Retnet``.
		use_gate: forwarded unchanged to ``Retnet``.
		pos_encode: whether to apply ``AbsolutePositionalEncoding``.
	"""

	def __init__(self, n_words, d_model, n_layer, n_head, n_out=1, use_decay = True, use_gate=True, pos_encode = True):
		super().__init__()
		self.model_type = 'Retnet'
		self.name = f"retnet_model={d_model}_layer={n_layer}_head={n_head}"
		self.pos_encode = pos_encode

		if pos_encode:
			print('Using Positional Encoding')
			self.pos_encoder = AbsolutePositionalEncoding(d_model)

		self.d_model = d_model
		self.n_words = n_words

		self.encoder = nn.Embedding(n_words, d_model)
		self._backbone = Retnet(
			d_model=d_model,
			n_layers=n_layer,
			num_heads=n_head,
			d_ffn=4 * d_model,
			use_decay=use_decay,
			use_gate=use_gate,
		)
		self._read_out = nn.Linear(d_model, n_out)

		print('Retnet Normal Training')

		# Counts every parameter (no requires_grad filter), despite the label.
		n_params = 0
		for weight in self.parameters():
			n_params += weight.numel()
		print(f'Total trainable parameters: {n_params:,}')

	def forward(self, x):
		"""Return the read-out of the last position.

		Args:
			x: token ids, (batch_size, length).

		Returns:
			(batch_size,) when ``n_out == 1``, otherwise (batch_size, n_out).
		"""
		embeds = self.encoder(x) * math.sqrt(self.d_model)
		if self.pos_encode:
			embeds = self.pos_encoder(embeds)

		out = self._read_out(self._backbone(embeds))
		return out[:, -1, 0] if out.size(-1) == 1 else out[:, -1]

	




