import argparse
from experiment import *
from main_utils import *
from utils import *


# Command-line interface for evaluating a saved checkpoint on the test split.
# NOTE: arguments are parsed at import time into the module-level `args`.
parser = argparse.ArgumentParser(description='Inverse problem of PDE')

parser.add_argument('--config',
                    type=str,
                    required=True,
                    help='config name: reads configs/<CONFIG>.jsonc and loads '
                         'checkpoints from ../runs/<CONFIG>/checkpoint/')

parser.add_argument('--device',
                    type=str,
                    required=True,
                    help='value exported as CUDA_VISIBLE_DEVICES, e.g. "0" or "0,1"')

parser.add_argument('--epoch',
                    type=int,
                    required=True,
                    help='epoch number of the checkpoint file <EPOCH>.pt to evaluate')

parser.add_argument('--seed',
                    type=int,
                    default=2023,
                    help='random seed passed to set_seed (default: %(default)s)')

args = parser.parse_args()


if __name__ == "__main__":
    # Evaluate one saved checkpoint on the test split with one process per GPU
    # (NCCL backend), then print the three test metrics.

    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
    # initialised, so export it before set_seed (which may touch the GPU).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    set_seed(args.seed)

    torch.distributed.init_process_group("nccl")
    # Global rank across all processes; passed to the test routines unchanged.
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    # The CUDA device index must be the node-local rank: in multi-node runs the
    # global rank is out of range on every node but the first. Launchers such
    # as torchrun export LOCAL_RANK; fall back to the global rank (identical on
    # a single node).
    device = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(device)

    try:
        config_file = os.path.join("configs", args.config + ".jsonc")
        config = Configuration(config_file)

        # get_data_model returns a 16-tuple; only the test loader, the
        # transformer/masker/poser helpers and the model are used here.
        _, _, _, \
        _, _, _, \
        _, _, test_dataloader, \
        transformer, masker, poser, \
        model, _, _, _ \
        = get_data_model(config, device)

        checkpoint_dir = os.path.join("..", "runs", args.config, "checkpoint")
        checkpoint_path = os.path.join(checkpoint_dir, f"{args.epoch}.pt")
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=f"cuda:{device}")
        model.load_state_dict(state_dict)

        if config.role == "propagator":
            tester = test_propagator
        elif config.role == "completer":
            tester = test_completer
        else:
            raise NotImplementedError("Invalid Role!")

        test_metric, test_side1_metric, test_side2_metric = tester(
            test_dataloader, transformer, masker, poser, model,
            config.model.name, rank, world_size)

        # Metrics are converted with .item(); presumably single-element tensors.
        test_metric = test_metric.item()
        test_side1_metric = test_side1_metric.item()
        test_side2_metric = test_side2_metric.item()

        print("Test Metric: {}\tTest Side1 Metric: {}\tTest Side2 Metric: {}"
              .format(test_metric, test_side1_metric, test_side2_metric))
    finally:
        # Tear down the process group even when loading or evaluation fails.
        torch.distributed.destroy_process_group()
