import os
import sys
import datetime

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import nn

import wandb

from config import get_config
from DataLoader import get_loaders
from DataLoader import CLASSES
from architecture import get_model

from train_utils import eval_model
from train_utils import binary_metrics

# Process-wide setup executed at import time.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# cuDNN is switched off for the whole process and torch's RNG gets a fixed seed.
torch.backends.cudnn.enabled = False
torch.manual_seed(1312)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device: ", device)

# endregion

# Metric variant names; not referenced in the visible part of this script.
METRIC_TYPE = [
    "exact",
    "contains",
    "clip",
]

if __name__ == "__main__":
    # Script entry point: run eval_model on the test loader and, when W&B
    # logging is enabled on the master process, upload a table of sampled
    # predictions together with their ground-truth descriptions.
    cnf = get_config(sys.argv)

    ROOT_FOLDER = os.path.join(cnf.wandb.log_dir, 'checkpoints')
    EXP_FOLDER = os.path.join(ROOT_FOLDER, cnf.exp_name)
    MODELS_FOLDER = os.path.join(EXP_FOLDER, 'models')
    PREDS_FOLDER = os.path.join(EXP_FOLDER, 'preds')
    # vars() returns the live attribute dict of cnf, so attributes assigned
    # below (local_rank, exp_name, ...) are visible through cnf_dict as well.
    cnf_dict = vars(cnf)
    # region ddp
    if cnf.DDP:
        cnf.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(cnf.local_rank)
        # NOTE(review): LOCAL_RANK is per node, so on multi-node launches
        # every node has its own "master" -- confirm this is intended.
        cnf.is_master = cnf.local_rank == 0
        # torch.cuda.device(...) is a context manager, not a device object;
        # torch.device is what .to(...) and friends accept.
        cnf.device = torch.device("cuda", cnf.local_rank)
        cnf.world_size = int(os.environ['WORLD_SIZE'])
        os.environ['NCCL_BLOCKING_WAIT'] = '0'
        dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=7200))
        # One placeholder per rank; not used in the visible part of this script.
        df_lst = ["" for _ in range(cnf.world_size)]
    else:
        os.environ['WORLD_SIZE'] = "1"
        df_lst = [""]
        cnf.local_rank = 0
        cnf.is_master = True
    # endregion

    # region dir set_up
    if cnf.is_master:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(MODELS_FOLDER, exist_ok=True)
        os.makedirs(PREDS_FOLDER, exist_ok=True)
    # endregion

    _, test_loader = get_loaders(cnf)
    model, processor = get_model(cnf.model)
    # Renamed only after the folders above were derived from the original name.
    cnf.exp_name = 'HumanEvalSamples_' + cnf.exp_name

    # Only the master process talks to W&B, and only when logging is enabled.
    log_to_wandb = cnf.wandb.log and cnf.is_master
    if log_to_wandb:
        wand_run = wandb.init(project='ffVQA', notes='', config=cnf_dict, name=cnf.exp_name)
        sample_table = wandb.Table(
            columns=['sample', 'ground_truth', 'prompts', 'rationale'])
    if cnf.DDP:
        #dist.barrier()
        model = model.to(device)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(
            model,
            device_ids=[cnf.local_rank],
            output_device=cnf.local_rank
        )
    synonym = 'manipulated'
    followup_template = cnf.prompts.q2.format(synonym)
    prompt_template = cnf.prompts.q1.template.format(synonym)
    vqa_template = None
    test_table = eval_model(
        loader=test_loader,
        model=model,
        prompt=prompt_template,
        followup=followup_template,
        vqa=vqa_template,
        processor=processor,
        max_iter=cnf.training.iterations,
        num_classes=len(CLASSES),
        master=cnf.is_master
    )

    # The post-processing below exists solely to fill the W&B table, so it is
    # skipped when nothing is logged. Previously it ran unconditionally and
    # hit a NameError on sample_table whenever the table had not been created.
    if log_to_wandb:
        # Expand the per-row ground-truth vector into one column per class.
        # test_table is presumably a pandas DataFrame -- it is used via
        # .sample()/.iterrows()/.to_numpy() below.
        test_table[CLASSES] = test_table['ground_truth'].tolist()
        finegrained_gt = test_table[CLASSES].to_numpy()
        # Keep only rows whose first class column is not set
        # (presumably the "unmanipulated" class -- TODO confirm).
        msk = finegrained_gt[:, 0] != 1
        test_table = test_table[msk]
        # DataFrame.sample raises ValueError when n exceeds the row count,
        # so cap the sample size at what is actually available.
        n_samples = min(100, len(test_table))
        sample = test_table.sample(n=n_samples, random_state=1312)
        for _, sam in sample.iterrows():
            gt = 'The areas that are ' + synonym + ' are: ' + ' '.join(
                name for name in CLASSES if sam[name] == 1
            )
            row = [wandb.Image(sam["video_id"]), gt, followup_template, sam["rationale"]]
            sample_table.add_data(*row)

        wand_run.log({"sample": sample_table})

    if cnf.DDP:
        dist.destroy_process_group()
