'''
Generate the model prediction for a single batch (the first one yielded by the data loader).
The ground truth, the prediction, the model inputs, and the sampled masks are saved together in a .pt file.
'''
import os
from unittest import result
import yaml
from collections import OrderedDict
from argparse import ArgumentParser
from argparse import Namespace 
import numpy as np
import torch
from torch.utils.data import DataLoader

from models.fno import fno_pretrain

from tqdm import tqdm
from pdb import set_trace as bp

from utils.data_utils import get_data_loader
from utils.loss_utils import LossMSE

import torch.distributed as dist
from utils.masking_generator import MaskingGenerator

# Module-level mask sampler shared by get_pred, which calls it once per sample
# and reshapes each result to (64, 64).
# NOTE(review): the arguments are presumably the mask grid size and the masking
# ratio -- confirm against utils.masking_generator.MaskingGenerator.
masking_generator = MaskingGenerator((64, 64), 0.1)


def _load_checkpoint(model, ckpt_path, device):
    """Restore model weights from the ``'model_state'`` entry of a checkpoint.

    A strict ``load_state_dict`` is attempted first. If it raises
    ``RuntimeError``, a leading ``"module."`` prefix is stripped from the
    checkpoint keys and only the entries whose name and size match the
    model's own state dict are copied; every other checkpoint key is
    reported through ``warnings.warn``.

    Args:
        model: Module whose parameters are overwritten in place.
        ckpt_path: Path handed to ``torch.load``.
        device: Passed as ``map_location`` so a checkpoint saved on a
            different device can still be loaded on this one.

    Raises:
        KeyError: If the checkpoint has no ``'model_state'`` entry.
    """
    import warnings

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    saved_state = checkpoint['model_state']
    try:
        model.load_state_dict(saved_state)
    except RuntimeError:
        # Strip only a *leading* "module." (e.g. left by a parallel wrapper);
        # keys that merely contain "module" elsewhere must stay intact.
        prefix = "module."
        renamed_state = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, val)
            for key, val in saved_state.items()
        )
        state = model.state_dict()
        # 1. keep only entries the model knows about and whose sizes agree
        pretrained_dict = {
            k: v for k, v in renamed_state.items()
            if k in state and state[k].size() == v.size()
        }
        # 2. overwrite entries in the existing state dict
        state.update(pretrained_dict)
        # 3. load the merged state dict
        model.load_state_dict(state)
        unload_keys = [k for k in renamed_state if k not in pretrained_dict]
        if unload_keys:
            warnings.warn("Warning: unload keys during restoring checkpoint: %s" % (str(unload_keys)))


@torch.no_grad()
def get_pred(args: Namespace) -> None:
    """Predict on the first batch of the data loader and save the tensors.

    Reads the YAML config at ``args.config``, builds the data loader and the
    ``fno_pretrain`` model, optionally restores ``args.ckpt_path``, runs the
    model on the first batch with freshly sampled masks, prints the batch
    length and the MSE data loss, and writes a dict with the keys ``'truth'``,
    ``'pred'``, ``'src'`` and ``'mask'`` to ``save_dir/fno-prediction.pt``.

    Args:
        args: Parsed command-line options. Uses ``config`` (YAML path),
            ``ckpt_path`` (checkpoint path; falsy skips restoring) and
            ``num_demos`` (when non-zero, ``"test"`` in the data path is
            replaced by ``"train"`` and the value is used as batch size).

    Raises:
        StopIteration: If the data loader yields no batch.
    """
    # NOTE(review): prefer yaml.safe_load if the config may come from an
    # untrusted source.
    with open(args.config, 'r') as stream:
        config = yaml.load(stream, yaml.FullLoader)

    save_dir = 'save_dir'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'fno-prediction.pt')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_param = Namespace(**config['data'])
    if not hasattr(data_param, 'n_demos'):
        data_param.n_demos = 0
    data_param.local_valid_batch_size = config['test'].get('batchsize', 1)
    if args.num_demos:  # neither None nor 0
        data_param.datapath = data_param.datapath.replace("test", "train")
        data_param.local_valid_batch_size = args.num_demos
    dataloader, _dataset, _sampler = get_data_loader(
        data_param,
        data_param.datapath,
        dist.is_initialized(),
        train=False,
        pack=data_param.pack_data,
    )

    model_param = Namespace(**config['model'])
    model_param.n_demos = data_param.n_demos
    model = fno_pretrain(model_param).to(device)
    if args.ckpt_path:
        _load_checkpoint(model, args.ckpt_path, device)

    loss_param = Namespace(**config['test'])
    loss_param.device = device
    mseloss = LossMSE(loss_param, None)

    model.eval()

    # Fetch the first batch exactly once; it is the only one evaluated.
    batch = next(iter(dataloader))
    print(len(batch))
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    # One independently sampled (64, 64) mask per sample in the batch.
    masks = torch.from_numpy(
        np.stack(
            [masking_generator().reshape(64, 64) for _ in range(inputs.shape[0])],
            axis=0,
        )
    ).to(device)

    u = model(inputs, masks)
    data_loss = mseloss.data(inputs, u, targets)
    print(data_loss.item())

    torch.save({
        'truth': targets.cpu(),
        'pred': u.cpu(),
        'src': inputs.cpu(),
        'mask': masks.cpu(),
    }, save_path)


if __name__ == "__main__":
    # Script entry point: switch on cuDNN benchmark mode, parse the
    # command-line options listed below, and run a single prediction pass.
    torch.backends.cudnn.benchmark = True

    parser = ArgumentParser()
    option_specs = (
        ('--config', {'type': str, 'default': 'config/inference_helmholtz.yaml'}),
        ('--ckpt_path', {'type': str, 'default': './ckpt.tar'}),
        ('--num_demos', {'type': int, 'default': None}),
        ('--tqdm', {'action': 'store_true', 'default': False, 'help': 'Turn on the tqdm'}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    get_pred(args)
