import hydra
import hydra.experimental
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaTokenizer,T5Tokenizer
from omegaconf import DictConfig,OmegaConf
import os
from configure_dataloader import DataLoaderHandler

from model import Paraphraser, RADARModel
from radar_discriminator import RADARDiscriminator
from radar_trainerloop import RADARTrainerLoop 
import random
import logging


def set_seed(_hashed_seed):
    """Seed every random number generator this script relies on.

    The same value is fed, in order, to Python's ``random`` module, NumPy's
    global RNG, PyTorch's default generator, the current CUDA device's
    generator, and the generators of all CUDA devices.

    Note: this does not touch cuDNN determinism settings, so GPU kernels
    may still introduce run-to-run variation.

    Args:
        _hashed_seed: Integer seed applied to every generator listed above.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        # torch.manual_seed already covers CUDA; the explicit calls below are
        # redundant but kept so the seeding sequence is unchanged.
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(_hashed_seed)

@hydra.main(config_path="vicuna_7b_config", config_name="config")
def main(config=None):
    """Hydra entry point: seed RNGs, build data, model and tokenizers, then train.

    Args:
        config: Composed Hydra configuration, injected by the ``hydra.main``
            decorator from ``vicuna_7b_config/config``.
    """
    # Force synchronous CUDA kernel launches so errors are reported at the
    # failing call. NOTE(review): this is normally a debugging aid and slows
    # GPU execution; confirm it should stay on for real training runs.
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
    set_seed(config.training.random_seed)

    # Construction order is preserved after seeding in case either constructor
    # draws from the RNGs.
    data_handler = DataLoaderHandler(config)
    radar_model = RADARModel(config)

    # Load the RoBERTa tokenizer for the discriminator and the T5 tokenizer
    # for the paraphraser.
    model_cfg = config.model
    disc_tokenizer = RobertaTokenizer.from_pretrained(model_cfg.discriminator_name_or_path)
    para_tokenizer = T5Tokenizer.from_pretrained(model_cfg.paraphraser_name_or_path)

    # The trainer receives the dataloader factory methods, not built loaders.
    trainer_loop = RADARTrainerLoop(
        config=config,
        discriminator_tokenizer=disc_tokenizer,
        paraphraser_tokenizer=para_tokenizer,
        model=radar_model,
        train_dataloader_fn=data_handler.train_dataloader,
        valid_dataloader_fn=data_handler.valid_dataloader,
    )
    trainer_loop.train()

if __name__ == "__main__":
    # The hydra.main decorator composes the config and passes it in as
    # `config`, so main() is called without arguments here.
    main()