import torch
from utils import *
from model import *
from main_utils import *
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


def train_propagator(train_dataloader,
        val_dataloader,
        test_dataloader,
        transformer,
        masker,
        poser,
        model,
        model_name,
        loss_fn,
        optimizer,
        scheduler,
        local_rank,
        world_size,
        grad_clip,
        epoch,
        log_print_interval_epoch,
        model_save_interval_epoch,
        log_dir,
        checkpoint_dir):
    """Train a propagator model under torch.distributed and log on rank 0.

    For every batch, the ``data_preprocess_propagator_*`` helper selected by
    ``model_name`` returns per-step lists ``x``, ``y``, ``ob``.
    The loss is averaged over all ``len(ob)`` steps, gradients are clipped to
    ``grad_clip``, and both the optimizer and the scheduler are stepped once
    per batch. After each epoch, the validation loss and three test metrics
    are computed on every rank; those helpers issue collective
    ``all_reduce`` calls, so all ranks must reach them. Rank 0 alone writes
    TensorBoard scalars, keeps an in-memory history, prints progress, saves
    checkpoints, and finally dumps the history via ``logger``.

    Args:
        train_dataloader: loader whose ``sampler`` supports ``set_epoch``
            (presumably a DistributedSampler).
        val_dataloader, test_dataloader: evaluation loaders.
        transformer, masker, poser: data helpers forwarded to evaluation.
        model: the propagator network.
        model_name: must contain "GNOT", "LNO" or "DeepONet"; if it also
            contains "time", the step index is passed as an extra input.
        loss_fn, optimizer, scheduler: training components.
        local_rank: this process's rank, also used as the device.
        world_size: number of participating ranks.
        grad_clip: max gradient norm for ``clip_grad_norm_``.
        epoch: number of epochs to run.
        log_print_interval_epoch: log/print every this many epochs.
        model_save_interval_epoch: checkpoint every this many epochs.
        log_dir: TensorBoard / history output directory.
        checkpoint_dir: checkpoint output directory.

    Raises:
        NotImplementedError: if ``model_name`` matches no known propagator.
    """
    device = local_rank

    if local_rank == 0:
        print("Number of Propagator Parameters: {}".format(get_num_params(model)))
        print("Start Training...")
        writer = SummaryWriter(log_dir)
        checker = Checkpoint(checkpoint_dir, model, device)
        epoch_history = []
        lr_history = []
        train_loss_history = []
        val_loss_history = []
        test_metric_history = []
        test_side1_metric_history = []
        test_side2_metric_history = []

    for i in range(epoch):
        torch.distributed.barrier()
        train_dataloader.sampler.set_epoch(i)
        train_loss = 0

        for _, data in enumerate(tqdm(train_dataloader)):
            x, y, _ = data
            if "GNOT" in model_name:
                x, y, ob = data_preprocess_propagator_GNOT(masker, x, y, device)
            elif "LNO" in model_name:
                x, y, ob = data_preprocess_propagator_LNO(masker, x, y, device)
            elif "DeepONet" in model_name:
                x, y, ob = data_preprocess_propagator_DeepONet(masker, x, y, device)
            else:
                raise NotImplementedError("Invalid Propagator Name !")

            loss = torch.zeros((1)).to(device)
            optimizer.zero_grad()
            model.train()

            # Accumulate the loss over every step produced by the preprocessor.
            for j in range(0, len(ob)):
                if "time" in model_name:
                    # "time" variants receive the step index as a (batch, 1) input.
                    t = torch.ones((x[j].shape[0], 1)) * j
                    t = t.to(device)
                    res = model(x[j], ob[j], t)
                else:
                    res = model(x[j], ob[j])
                loss = loss + loss_fn(res, y[j])

            loss = loss / len(ob)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()  # stepped per batch, not per epoch

            # all_reduce sums across ranks; dividing by world_size gives the mean.
            train_batch_loss = torch.tensor(loss.item()).to(device)
            torch.distributed.all_reduce(train_batch_loss)
            train_loss = train_loss + train_batch_loss / world_size

        train_loss = train_loss / len(train_dataloader)
        train_loss = train_loss.item()
        # Every rank must run evaluation: it contains collective all_reduce calls.
        val_loss = val_propagator(val_dataloader, transformer, masker, model, model_name, loss_fn, local_rank, world_size)
        val_loss = val_loss.item()
        test_metric, test_side1_metric, test_side2_metric = test_propagator(test_dataloader, transformer, masker, poser, model, model_name, local_rank, world_size)
        test_metric = test_metric.item()
        test_side1_metric = test_side1_metric.item()
        test_side2_metric = test_side2_metric.item()

        if local_rank == 0:
            step = i + 1
            if step % log_print_interval_epoch == 0:
                lr = optimizer.param_groups[0]['lr']
                # Every scalar needs the epoch as its global_step. Without one,
                # all epochs' values land at the same x position in TensorBoard.
                writer.add_scalar("Learning Rate", lr, step)
                writer.add_scalar("Train Loss", train_loss, step)
                writer.add_scalar("Val Loss", val_loss, step)
                writer.add_scalar("Test Metric", test_metric, step)
                writer.add_scalar("Test Side1 Metric", test_side1_metric, step)
                writer.add_scalar("Test Side2 Metric", test_side2_metric, step)

                epoch_history.append(step)
                lr_history.append(lr)
                train_loss_history.append(train_loss)
                val_loss_history.append(val_loss)
                test_metric_history.append(test_metric)
                test_side1_metric_history.append(test_side1_metric)
                test_side2_metric_history.append(test_side2_metric)

                print("Epoch: {}\tLearning Rate :{}\tTrain Loss: {}\tVal Loss: {}\tTest Metric: {}\tTest Side1 Metric: {}\tTest Side2 Metric: {}"\
                    .format(step, lr, train_loss, val_loss, test_metric, test_side1_metric, test_side2_metric))

            if step % model_save_interval_epoch == 0:
                checker.save(step)

    if local_rank == 0:
        writer.close()
        logger(log_dir,
               [epoch_history, lr_history, train_loss_history, val_loss_history, test_metric_history, test_side1_metric_history, test_side2_metric_history],
               ["Epoch", "LR", "Train_Loss", "Val_Loss", "Test_Metric", "Test_Side1_Metric", "Test_Side2_Metric"])
        print("Finish Training !")


def train_completer(train_dataloader,
        val_dataloader,
        test_dataloader,
        transformer,
        masker,
        poser,
        model,
        model_name,
        loss_fn,
        optimizer,
        scheduler,
        local_rank,
        world_size,
        grad_clip,
        epoch,
        log_print_interval_epoch,
        model_save_interval_epoch,
        log_dir,
        checkpoint_dir):
    """Train a completer model under torch.distributed and log on rank 0.

    For every batch, the ``data_preprocess_completer_*`` helper selected by
    ``model_name`` produces ``(x, y, ob)``. The model is trained on a
    single forward pass ``model(x, ob)`` against ``y``. Gradients are
    clipped to ``grad_clip``, and both the optimizer and the scheduler are
    stepped once per batch. After each epoch, the validation loss and three
    test metrics are computed on every rank; those helpers issue collective
    ``all_reduce`` calls, so all ranks must reach them. Rank 0 alone writes
    TensorBoard scalars, keeps an in-memory history, prints progress, saves
    checkpoints, and finally dumps the history via ``logger``.

    Args:
        train_dataloader: loader whose ``sampler`` supports ``set_epoch``
            (presumably a DistributedSampler).
        val_dataloader, test_dataloader: evaluation loaders.
        transformer, masker, poser: data helpers forwarded to evaluation.
        model: the completer network.
        model_name: must contain "DeepONet", "GNOT" or "LNO".
        loss_fn, optimizer, scheduler: training components.
        local_rank: this process's rank, also used as the device.
        world_size: number of participating ranks.
        grad_clip: max gradient norm for ``clip_grad_norm_``.
        epoch: number of epochs to run.
        log_print_interval_epoch: log/print every this many epochs.
        model_save_interval_epoch: checkpoint every this many epochs.
        log_dir: TensorBoard / history output directory.
        checkpoint_dir: checkpoint output directory.

    Raises:
        NotImplementedError: if ``model_name`` matches no known completer.
    """
    device = local_rank

    if local_rank == 0:
        print("Number of Completer Parameters: {}".format(get_num_params(model)))
        print("Start Training...")
        writer = SummaryWriter(log_dir)
        checker = Checkpoint(checkpoint_dir, model, device)
        epoch_history = []
        lr_history = []
        train_loss_history = []
        val_loss_history = []
        test_metric_history = []
        test_side1_metric_history = []
        test_side2_metric_history = []

    for i in range(epoch):
        torch.distributed.barrier()
        train_dataloader.sampler.set_epoch(i)
        train_loss = 0

        for _, data in enumerate(tqdm(train_dataloader)):
            x, y, _ = data
            if "DeepONet" in model_name:
                x, y, ob = data_preprocess_completer_DeepONet(masker, x, y, device)
            elif "GNOT" in model_name:
                x, y, ob = data_preprocess_completer_GNOT(masker, x, y, device)
            elif "LNO" in model_name:
                x, y, ob = data_preprocess_completer_LNO(masker, x, y, device)
            else:
                raise NotImplementedError("Invalid Completer Name")

            optimizer.zero_grad()
            model.train()
            res = model(x, ob)
            loss = loss_fn(res, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()  # stepped per batch, not per epoch

            # all_reduce sums across ranks; dividing by world_size gives the mean.
            train_batch_loss = torch.tensor(loss.item()).to(device)
            torch.distributed.all_reduce(train_batch_loss)
            train_loss = train_loss + train_batch_loss / world_size

        train_loss = train_loss / len(train_dataloader)
        train_loss = train_loss.item()
        # Every rank must run evaluation: it contains collective all_reduce calls.
        val_loss = val_completer(val_dataloader, transformer, masker, model, model_name, loss_fn, local_rank, world_size)
        val_loss = val_loss.item()
        test_metric, test_side1_metric, test_side2_metric = test_completer(test_dataloader, transformer, masker, poser, model, model_name, local_rank, world_size)
        test_metric = test_metric.item()
        test_side1_metric = test_side1_metric.item()
        test_side2_metric = test_side2_metric.item()

        if local_rank == 0:
            step = i + 1
            if step % log_print_interval_epoch == 0:
                lr = optimizer.param_groups[0]['lr']
                # Every scalar needs the epoch as its global_step. Without one,
                # all epochs' values land at the same x position in TensorBoard.
                writer.add_scalar("Learning Rate", lr, step)
                writer.add_scalar("Train Loss", train_loss, step)
                writer.add_scalar("Val Loss", val_loss, step)
                writer.add_scalar("Test Metric", test_metric, step)
                writer.add_scalar("Test Side1 Metric", test_side1_metric, step)
                writer.add_scalar("Test Side2 Metric", test_side2_metric, step)

                epoch_history.append(step)
                lr_history.append(lr)
                train_loss_history.append(train_loss)
                val_loss_history.append(val_loss)
                test_metric_history.append(test_metric)
                test_side1_metric_history.append(test_side1_metric)
                test_side2_metric_history.append(test_side2_metric)

                print("Epoch: {}\tLearning Rate :{}\tTrain Loss: {}\tVal Loss: {}\tTest Metric: {}\tTest Side1 Metric: {}\tTest Side2 Metric: {}"\
                    .format(step, lr, train_loss, val_loss, test_metric, test_side1_metric, test_side2_metric))

            if step % model_save_interval_epoch == 0:
                checker.save(step)

    if local_rank == 0:
        writer.close()
        logger(log_dir,
               [epoch_history, lr_history, train_loss_history, val_loss_history, test_metric_history, test_side1_metric_history, test_side2_metric_history],
               ["Epoch", "LR", "Train_Loss", "Val_Loss", "Test_Metric", "Test_Side1_Metric", "Test_Side2_Metric"])
        print("Finish Training !")


def val_propagator(val_dataloader,
        transformer,
        masker,
        model,
        model_name,
        loss_fn,
        local_rank,
        world_size):
    """Compute the propagator validation loss, averaged over batches and ranks.

    Each batch's loss is the mean of ``loss_fn`` over all ``len(ob)`` steps.
    It is all-reduced across ranks and divided by ``world_size``. The
    per-batch values are then averaged over ``len(val_dataloader)``. Every
    rank must call this, because it issues a collective ``all_reduce`` per
    batch. ``transformer`` is accepted for signature symmetry but is unused.

    Raises:
        NotImplementedError: if ``model_name`` matches no known propagator.
    """
    device = local_rank
    with torch.no_grad():
        running = 0

        for batch in val_dataloader:
            x, y, _ = batch
            if "GNOT" in model_name:
                preprocess = data_preprocess_propagator_GNOT
            elif "LNO" in model_name:
                preprocess = data_preprocess_propagator_LNO
            elif "DeepONet" in model_name:
                preprocess = data_preprocess_propagator_DeepONet
            else:
                raise NotImplementedError("Invalid Propagator Name !")
            x, y, ob = preprocess(masker, x, y, device)

            batch_loss = torch.zeros((1)).to(device)
            model.eval()
            uses_time = "time" in model_name
            for step in range(len(ob)):
                if uses_time:
                    # "time" variants receive the step index as a (batch, 1) input.
                    t = (torch.ones((x[step].shape[0], 1)) * step).to(device)
                    prediction = model(x[step], ob[step], t)
                else:
                    prediction = model(x[step], ob[step])
                batch_loss = batch_loss + loss_fn(prediction, y[step])
            batch_loss = batch_loss / len(ob)

            # all_reduce sums over ranks; dividing by world_size yields the mean.
            reduced = torch.tensor(batch_loss.item()).to(device)
            torch.distributed.all_reduce(reduced)
            running = running + reduced / world_size

        running /= len(val_dataloader)

    return running


def val_completer(val_dataloader,
        transformer,
        masker,
        model,
        model_name,
        loss_fn,
        local_rank,
        world_size):
    """Compute the completer validation loss, averaged over batches and ranks.

    Each batch's loss is ``loss_fn(model(x, ob), y)``. It is all-reduced
    across ranks and divided by ``world_size``. The per-batch values are
    then averaged over ``len(val_dataloader)``. Every rank must call this,
    because it issues a collective ``all_reduce`` per batch.
    ``transformer`` is accepted for signature symmetry but is unused.

    Raises:
        NotImplementedError: if ``model_name`` matches no known completer.
    """
    with torch.no_grad():
        device = local_rank
        val_loss = 0

        for _, data in enumerate(val_dataloader):
            x, y, _ = data
            if "DeepONet" in model_name:
                x, y, ob = data_preprocess_completer_DeepONet(masker, x, y, device)
            elif "GNOT" in model_name:
                x, y, ob = data_preprocess_completer_GNOT(masker, x, y, device)
            elif "LNO" in model_name:
                x, y, ob = data_preprocess_completer_LNO(masker, x, y, device)
            else:
                raise NotImplementedError("Invalid Completer Name")

            model.eval()
            res = model(x, ob)
            loss = loss_fn(res, y)
            # all_reduce sums over ranks; dividing by world_size yields the mean.
            val_batch_loss = torch.tensor(loss.item()).to(device)
            torch.distributed.all_reduce(val_batch_loss)
            val_loss = val_loss + val_batch_loss / world_size

        val_loss /= len(val_dataloader)

    return val_loss


def test_propagator(test_dataloader,
        transformer,
        masker,
        poser,
        model,
        model_name,
        local_rank,
        world_size):

    with torch.no_grad():
        device = local_rank
        test_metric = 0
        test_side1_metric = 0
        test_side2_metric = 0
        
        for _, data in enumerate(test_dataloader):
            x, y, _ = data
            if "GNOT" in model_name:
                x, y, ob = data_preprocess_propagator_GNOT(masker, x, y, device)
            elif "LNO" in model_name:
                x, y, ob = data_preprocess_propagator_LNO(masker, x, y, device)
            elif "DeepONet" in model_name:
                x, y, ob = data_preprocess_propagator_DeepONet(masker, x, y, device)
            else:
                raise NotImplementedError("Invalid Propagator Name !")
            
            model.eval()
            if "time" in model_name:
                t = torch.zeros((x[0].shape[0], 1))
                t = t.to(device)
                res = model(x[0], ob[0], t)
            else:
                res = model(x[0], ob[0])
            
            for i in range(1, len(ob)):
                if "time" in model_name:
                    t = torch.ones((x[i].shape[0], 1)) * i
                    t = t.to(device)
                    res = model(x[i], ob[i], t)
                else:
                    res = model(x[i], ob[i])

            mask, _, _ = masker.get()
            res = torch.reshape(res, (res.shape[0], *tuple(mask.shape[1:]), res.shape[-1]))
            y = torch.reshape(y[-1], (y[-1].shape[0], *tuple(mask.shape[1:]), y[-1].shape[-1]))
            
            res = transformer.apply_y(res, inverse=True)
            y = transformer.apply_y(y, inverse=True)

            pos = poser.get()
            pos_idx = list(pos.to_sparse().indices().transpose(0,1).numpy())
            res = res.cpu().numpy()
            res = torch.tensor(np.array([res[(slice(None), *tuple(idx), slice(None))] for idx in pos_idx])).float().transpose(0, 1).to(device)
            y = y.cpu().numpy()
            y = torch.tensor(np.array([y[(slice(None), *tuple(idx), slice(None))] for idx in pos_idx])).float().transpose(0, 1).to(device)

            p = 2
            metric = RelLpLoss(p)
            test_batch_metric = torch.tensor(metric(res, y).item()).to(device)
            torch.distributed.all_reduce(test_batch_metric)
            test_metric = test_metric + test_batch_metric / world_size
            
            p = 2
            side1_metric = MpELoss(p)
            test_batch_side1_metric = torch.tensor(side1_metric(res, y).item()).to(device)
            torch.distributed.all_reduce(test_batch_side1_metric)
            test_side1_metric = test_side1_metric + test_batch_side1_metric / world_size
            
            p = 1  
            side2_metric = RelMpELoss(p)
            test_batch_side2_metric = torch.tensor(side2_metric(res, y).item()).to(device)
            torch.distributed.all_reduce(test_batch_side2_metric)
            test_side2_metric = test_side2_metric + test_batch_side2_metric / world_size
        
        test_metric /= len(test_dataloader)
        test_side1_metric /= len(test_dataloader)
        test_side2_metric /= len(test_dataloader)

    return test_metric, test_side1_metric, test_side2_metric


def test_completer(test_dataloader,
        transformer,
        masker,
        poser,
        model,
        model_name,
        local_rank,
        world_size):

    with torch.no_grad():
        device = local_rank
        test_metric = 0
        test_side1_metric = 0
        test_side2_metric = 0
        
        for _, data in enumerate(test_dataloader):
            x, y, _ = data
            if "DeepONet" in model_name:
                x, y, ob = data_preprocess_completer_DeepONet(masker, x, y, device)
            elif "GNOT" in model_name:
                x, y, ob = data_preprocess_completer_GNOT(masker, x, y, device)
            elif "LNO" in model_name:
                x, y, ob = data_preprocess_completer_LNO(masker, x, y, device)
            else:
                raise NotImplementedError("Invalid Completer Name")
            
            model.eval()
            res = model(x, ob)
            mask, vertex0, vertex1 = masker.get()
            res = torch.reshape(res, (res.shape[0], *tuple(vertex1 - vertex0), res.shape[-1]))
            y = torch.reshape(y, (y.shape[0], *tuple(vertex1 - vertex0), y.shape[-1]))
            
            res = transformer.apply_y(res, inverse=True)
            y = transformer.apply_y(y, inverse=True)

            pos = poser.get()
            pos_idx = list(pos.to_sparse().indices().transpose(0,1).numpy())
            res = res.cpu().numpy()
            res = torch.tensor(np.array([res[(slice(None), *tuple(idx), slice(None))] for idx in pos_idx])).float().transpose(0, 1).to(device)
            y = y.cpu().numpy()
            y = torch.tensor(np.array([y[(slice(None), *tuple(idx), slice(None))] for idx in pos_idx])).float().transpose(0, 1).to(device)

            p = 2
            metric = RelLpLoss(p)
            test_batch_metric = torch.tensor(metric(res, y).item()).to(device)
            torch.distributed.all_reduce(test_batch_metric)
            test_metric = test_metric + test_batch_metric / world_size
            
            p = 2
            side1_metric = MpELoss(p)
            test_batch_side1_metric = torch.tensor(side1_metric(res, y).item()).to(device)
            torch.distributed.all_reduce(test_batch_side1_metric)
            test_side1_metric = test_side1_metric + test_batch_side1_metric / world_size
            
            p = 1  
            side2_metric = RelMpELoss(p)
            test_batch_side2_metric = torch.tensor(side2_metric(res, y).item()).to(device)
            torch.distributed.all_reduce(test_batch_side2_metric)
            test_side2_metric = test_side2_metric + test_batch_side2_metric / world_size
        
        test_metric /= len(test_dataloader)
        test_side1_metric /= len(test_dataloader)
        test_side2_metric /= len(test_dataloader)

    return test_metric, test_side1_metric, test_side2_metric


def infer(infer_dataloader,
        transformer,
        masker_completer,
        masker_propagator,
        poser,
        completer,
        completer_name,
        propagator,
        propagator_name,
        local_rank,
        world_size,
        save_dir="../pics/res_npy"):
    """Run the completer -> propagator pipeline and score it at probe positions.

    For each batch:
      1. The completer fills in the initial field from ``source_x/source_y``.
         Its output is concatenated with its ``x`` to form the first
         propagator observation.
      2. The propagator is rolled out autoregressively. From step 1 onward,
         each observation is ``cat(x[i-1], previous prediction)``.
      3. The final prediction and ``y[-1]`` are reshaped to the mask grid
         and inverse-transformed. The first sample of the batch is saved to
         ``save_dir`` as ``res{no}.npy``; when ``world_size > 1`` the name
         becomes ``res{no}_rank{local_rank}.npy``, so ranks no longer
         overwrite each other's files.
      4. ``RelLpLoss(2)``, ``MpELoss(2)`` and ``RelMpELoss(1)`` are computed
         at the non-zero entries of ``poser.get()``. Each is all-reduced
         across ranks, divided by ``world_size``, and averaged over
         ``len(infer_dataloader)``.

    Every rank must call this, because it issues collective ``all_reduce``
    calls.

    Args:
        save_dir: directory for the saved ``.npy`` predictions (created if
            missing). Defaults to the previously hard-coded path.

    Returns:
        tuple: (infer_metric, infer_side1_metric, infer_side2_metric).

    Raises:
        NotImplementedError: if either model name is not recognised.
    """
    import os

    with torch.no_grad():
        device = local_rank
        infer_metric = 0
        infer_side1_metric = 0
        infer_side2_metric = 0

        for no, data in enumerate(infer_dataloader):
            source_x, source_y, _ = data
            # Clone so the completer preprocessing cannot alter the inputs the
            # propagator preprocessing reuses below.
            if "DeepONet" in completer_name:
                x, _, ob = data_preprocess_completer_DeepONet(masker_completer, source_x.clone(), source_y.clone(), device)
            elif "GNOT" in completer_name:
                x, _, ob = data_preprocess_completer_GNOT(masker_completer, source_x.clone(), source_y.clone(), device)
            elif "LNO" in completer_name:
                x, _, ob = data_preprocess_completer_LNO(masker_completer, source_x.clone(), source_y.clone(), device)
            else:
                raise NotImplementedError("Invalid Completer Name !")

            completer.eval()
            res = completer(x, ob)
            # Flatten all middle (spatial) axes to one: (batch, points, channels).
            res = torch.reshape(res, (res.shape[0], math.prod(tuple(res.shape[1:-1])), res.shape[-1]))
            x = torch.reshape(x, (x.shape[0], math.prod(tuple(x.shape[1:-1])), x.shape[-1]))

            # Completer output (with its coordinates) becomes the step-0 observation.
            res = torch.cat((x, res), dim=-1)

            if "GNOT" in propagator_name:
                x, y, ob = data_preprocess_propagator_GNOT(masker_propagator, source_x, source_y, device)
            elif "LNO" in propagator_name:
                x, y, ob = data_preprocess_propagator_LNO(masker_propagator, source_x, source_y, device)
            elif "DeepONet" in propagator_name:
                x, y, ob = data_preprocess_propagator_DeepONet(masker_propagator, source_x, source_y, device)
            else:
                raise NotImplementedError("Invalid Propagator Name !")

            propagator.eval()
            if "time" in propagator_name:
                t = torch.zeros((x[0].shape[0], 1))
                t = t.to(device)
                res = propagator(x[0], res, t)
            else:
                res = propagator(x[0], res)

            # Autoregressive rollout: feed each prediction back as the next observation.
            for i in range(1, len(ob)):
                ob[i] = torch.cat((x[i-1], res), dim=-1)
                if "time" in propagator_name:
                    t = torch.ones((x[i].shape[0], 1)) * i
                    t = t.to(device)
                    res = propagator(x[i], ob[i], t)
                else:
                    res = propagator(x[i], ob[i])

            mask, _, _ = masker_propagator.get()
            res = torch.reshape(res, (res.shape[0], *tuple(mask.shape[1:]), res.shape[-1]))
            y = torch.reshape(y[-1], (y[-1].shape[0], *tuple(mask.shape[1:]), y[-1].shape[-1]))

            res = transformer.apply_y(res, inverse=True)
            y = transformer.apply_y(y, inverse=True)

            # `no` is a per-rank batch index, so tag the file with the rank under
            # DDP; otherwise every rank writes to the same res{no}.npy.
            os.makedirs(save_dir, exist_ok=True)
            stem = "res{}".format(no) if world_size == 1 else "res{}_rank{}".format(no, local_rank)
            np.save(os.path.join(save_dir, stem), res[0].cpu().numpy())

            pos = poser.get()
            pos_idx = list(pos.to_sparse().indices().transpose(0, 1).numpy())
            res = res.cpu().numpy()
            res = torch.tensor(np.array([res[(slice(None), *tuple(idx), slice(None))] for idx in pos_idx])).float().transpose(0, 1).to(device)
            y = y.cpu().numpy()
            y = torch.tensor(np.array([y[(slice(None), *tuple(idx), slice(None))] for idx in pos_idx])).float().transpose(0, 1).to(device)

            metric = RelLpLoss(2)
            infer_batch_metric = torch.tensor(metric(res, y).item()).to(device)
            torch.distributed.all_reduce(infer_batch_metric)
            infer_metric = infer_metric + infer_batch_metric / world_size

            side1_metric = MpELoss(2)
            infer_batch_side1_metric = torch.tensor(side1_metric(res, y).item()).to(device)
            torch.distributed.all_reduce(infer_batch_side1_metric)
            infer_side1_metric = infer_side1_metric + infer_batch_side1_metric / world_size

            side2_metric = RelMpELoss(1)
            infer_batch_side2_metric = torch.tensor(side2_metric(res, y).item()).to(device)
            torch.distributed.all_reduce(infer_batch_side2_metric)
            infer_side2_metric = infer_side2_metric + infer_batch_side2_metric / world_size

        infer_metric /= len(infer_dataloader)
        infer_side1_metric /= len(infer_dataloader)
        infer_side2_metric /= len(infer_dataloader)

    return infer_metric, infer_side1_metric, infer_side2_metric
