import argparse
from experiment import *
from dataset import *
from model import *
from utils import *
import shutil


# Command-line interface. Arguments are parsed at import time and the
# resulting namespace is exposed as the module-level name ``arg``.
parser = argparse.ArgumentParser(description="Neural Operator Backbone: LNO")
# Config name; the entry point loads "configs/<config>.jsonc".
parser.add_argument("--config", required=True, default=None, type=str)
# Written verbatim into CUDA_VISIBLE_DEVICES by the entry point.
parser.add_argument("--device", required=True, default=None, type=str)
# Seed handed to set_seed().
parser.add_argument("--seed", default=0, type=int)
# Suffix appended to the config name to form the experiment directory name.
parser.add_argument("--exp", required=True, default=None, type=str)
arg = parser.parse_args()


if __name__ == "__main__":
    # Entry point for one distributed (NCCL) worker process. The process group
    # is initialised from the environment, as set up by a launcher such as
    # torchrun -- NOTE(review): confirm the launch command used.

    # Restrict the visible GPUs before anything can initialise CUDA: once a
    # CUDA context exists, later changes to CUDA_VISIBLE_DEVICES are ignored.
    os.environ['CUDA_VISIBLE_DEVICES'] = arg.device
    set_seed(arg.seed)

    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()  # global rank across the whole job
    world_size = torch.distributed.get_world_size()
    # GPU index on *this* node. torchrun exports LOCAL_RANK; the global rank
    # is only a valid device ordinal when the whole job runs on one node, so
    # it is used purely as a fallback (identical to before in that case).
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = local_rank
    torch.cuda.set_device(device)

    try:
        config_file = "configs/" + arg.config + ".jsonc"
        config = Configuration(config_file)

        # Config names containing "_time" select the time-dependent pipeline
        # (both for data/model construction and for the training loop).
        model_attr = {"time": "_time" in arg.config}

        (train_dataloader, val_dataloader, _,
         transformer, model, loss, optimizer, scheduler) = get_model_data(
            config, model_attr, device)

        # Experiment directory is "../experiment/<config><exp>/".
        arg.exp = arg.config + arg.exp
        log_dir = "../experiment/" + arg.exp + "/log/"
        checkpoint_dir = "../experiment/" + arg.exp + "/checkpoint/"
        src_dir = "../experiment/" + arg.exp + "/src/"
        # NOTE(review): runs on every rank; confirm save_para tolerates
        # concurrent calls (or gate it on rank 0 like the copy below).
        save_para(arg, config)

        # Snapshot the regular files of the working directory into src_dir.
        # Only global rank 0 does this: with every rank participating, the
        # exists()/makedirs() pair races (FileExistsError) and ranks overwrite
        # the same destination files concurrently.
        if rank == 0:
            os.makedirs(src_dir, exist_ok=True)
            for obj in os.listdir("."):
                if os.path.isfile(obj):
                    shutil.copy(obj, src_dir + obj)

        # Set after get_model_data, exactly as before, so that call still
        # only sees the "time" key.
        model_attr["single"] = "_single" in config.model.name

        # Both loops take the same arguments; only the implementation differs.
        # NOTE(review): the node-local rank is passed (same value as before on
        # a single node); confirm train/train_time do not need the global rank
        # for rank-0-only logging/checkpointing on multi-node runs.
        train_fn = train_time if model_attr["time"] else train
        train_fn(
            train_dataloader,
            val_dataloader,
            transformer,
            model,
            model_attr,
            loss,
            optimizer,
            scheduler,
            local_rank,
            world_size,
            config.train.grad_clip,
            config.train.epoch,
            config.train.log_print_interval_epoch,
            config.train.model_save_interval_epoch,
            log_dir,
            checkpoint_dir
            )
    finally:
        # Always tear down NCCL, even if setup or training raised.
        torch.distributed.destroy_process_group()
