import argparse
from experiment import *
from main_utils import *
from utils import *


def _build_parser():
    """Create the command-line parser used by this inference script.

    Options (as consumed by the ``__main__`` section of this file):
        --config_completer   stem of the completer config, read from
                             ``configs/<stem>.jsonc``; also names the run
                             directory that holds its checkpoints.
        --epoch_completer    stem of the completer checkpoint file
                             (``<epoch>.pt``) to load.
        --config_propagator  same as ``--config_completer`` for the propagator.
        --epoch_propagator   same as ``--epoch_completer`` for the propagator.
        --device             value exported as ``CUDA_VISIBLE_DEVICES``.
        --seed               value handed to ``set_seed`` (optional, 2023).
    """
    cli = argparse.ArgumentParser(description='Inverse problem of PDE')

    # Options that must always be supplied, registered in --help order.
    mandatory = (
        ('--config_completer', str),
        ('--epoch_completer', int),
        ('--config_propagator', str),
        ('--epoch_propagator', int),
        ('--device', str),
    )
    for flag, caster in mandatory:
        # argparse already defaults to None, so no explicit default is needed.
        cli.add_argument(flag, type=caster, required=True)

    # The only optional setting.
    cli.add_argument('--seed', type=int, default=2023)
    return cli


parser = _build_parser()

# Parsed at import time, exactly as before: the entry point below reads the
# module-level `args`.
args = parser.parse_args()


if __name__ == "__main__":
    # Entry point: load a trained completer and a trained propagator, run the
    # combined inference on the test split and print the three metrics that
    # `infer` returns.

    # Export the GPU restriction first. CUDA reads CUDA_VISIBLE_DEVICES when
    # it is initialised, so it must be set before any helper that might touch
    # the GPU (set_seed is defined elsewhere; presumably it seeds CUDA too).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    set_seed(args.seed)

    torch.distributed.init_process_group("nccl")
    try:
        # NOTE(review): get_rank() is the *global* rank. It is used as the
        # CUDA device index below, which only matches the local device index
        # for single-node launches — confirm before running multi-node.
        local_rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device = local_rank
        torch.cuda.set_device(device)

        config_completer_file = "configs/" + args.config_completer + ".jsonc"
        config_completer = Configuration(config_completer_file)

        config_propagator_file = "configs/" + args.config_propagator + ".jsonc"
        config_propagator = Configuration(config_propagator_file)

        # get_data_model returns 16 values; only the ones needed here are
        # named. From the completer configuration: the test dataloader, the
        # transformer, the completer's masker and the completer model.
        (_, _, _,
         _, _, _,
         _, _, test_dataloader,
         transformer, masker_completer, _,
         completer, _, _, _) = get_data_model(config_completer, device)

        # From the propagator configuration: its masker, the poser and the
        # propagator model.
        (_, _, _,
         _, _, _,
         _, _, _,
         _, masker_propagator, poser,
         propagator, _, _, _) = get_data_model(config_propagator, device)

        # Checkpoints are mapped straight onto this process's GPU.
        map_location = "cuda:{}".format(device)

        checkpoint_completer_dir = "../runs/" + args.config_completer + "/checkpoint/"
        completer.load_state_dict(
            torch.load(checkpoint_completer_dir + str(args.epoch_completer) + ".pt",
                       map_location=map_location))

        checkpoint_propagator_dir = "../runs/" + args.config_propagator + "/checkpoint/"
        propagator.load_state_dict(
            torch.load(checkpoint_propagator_dir + str(args.epoch_propagator) + ".pt",
                       map_location=map_location))

        infer_metric, infer_side1_metric, infer_side2_metric = infer(
            test_dataloader, transformer, masker_completer, masker_propagator, poser,
            completer, config_completer.model.name,
            propagator, config_propagator.model.name,
            local_rank, world_size)

        # Convert the returned metrics to plain Python numbers for printing.
        infer_metric = infer_metric.item()
        infer_side1_metric = infer_side1_metric.item()
        infer_side2_metric = infer_side2_metric.item()

        print("Test Metric: {}\tTest Side1 Metric: {}\tTest Side2 Metric: {}"
              .format(infer_metric, infer_side1_metric, infer_side2_metric))
    finally:
        # Always tear the process group down, even when loading or inference
        # fails, so the NCCL resources are released on every exit path.
        torch.distributed.destroy_process_group()
