import os
import numpy as np
from tqdm import tqdm
import sys
import contextlib
import torch
from torch.utils.data import Dataset, DataLoader
from utils.config import argparser
from utils.load_model import load_ckpt
from models.embedder import Embedder, BaseRN50, RoIEmbedder, RoIPosEmbedder
from dataset.video_align_dataset import VideoAlignmentDownstreamDataset
from dataset.video_align_dataset_bbox import VideoAlignmentBboxDownstreamDataset
from dataset.h2o_dataset import H2OVideoDownstreamDataset
from evaluation.kendalls_tau import kendalls_tau
from evaluation.frame_retrieval import frame_retrieval
from evaluation.event_completion import progression, progression_semi
from evaluation.classification import classification, classification_crossview, classification_fewshot


class Tee:
    """Minimal file-like object that duplicates writes to a file and the console.

    ``stdout`` holds whatever ``sys.stdout`` was when the Tee was created, so a
    caller that installs this object as ``sys.stdout`` can restore the original.
    """

    def __init__(self, file):
        self.file = file
        self.stdout = sys.stdout

    def write(self, data):
        # Log file first, then the console.
        for stream in (self.file, self.stdout):
            stream.write(data)

    def flush(self):
        for stream in (self.file, self.stdout):
            stream.flush()

@contextlib.contextmanager
def redirect_stdout_to_file(file_path):
    """Mirror everything printed inside the ``with`` body to ``file_path`` and the console.

    The file is opened in "w" mode, so previous content is truncated. On exit,
    whether normal or via an exception, ``sys.stdout`` is restored first and
    then the file is closed.
    """
    with open(file_path, "w") as log_file, contextlib.redirect_stdout(Tee(log_file)):
        yield


def prepare_data_loader(args, mode, batch_size=1024, num_workers=8, bbox=False):
    """Build the downstream dataset for ``mode`` and wrap it in an ordered DataLoader.

    Args:
        args: Parsed command-line arguments, forwarded to the dataset constructor.
        mode: Split name forwarded to the dataset constructor (e.g. 'train').
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes.
        bbox: If True, use the bounding-box variant of the dataset.

    Returns:
        ``(data_loader, dataset)``. The loader neither shuffles nor drops the last
        partial batch, so iteration order follows the dataset order.
    """
    # NOTE: for the bbox case, H2OVideoDownstreamDataset(args, mode) can be swapped in here.
    dataset_cls = VideoAlignmentBboxDownstreamDataset if bbox else VideoAlignmentDownstreamDataset
    dataset = dataset_cls(args, mode)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False)
    print(f'Data loader len {len(loader)}')
    return loader, dataset


def extract_embedding(mode, data_loader, model, save_path, device, label_all=False, label_only=False, imagenet=False, object_box=False):
    """Run ``model`` over every batch of ``data_loader`` and save embeddings and labels as ``.npy``.

    Files are written under ``<save_path>/eval``:
      * ``{mode}_embeds.npy``: concatenated embeddings (skipped when ``label_only``);
      * ``{mode}_labels_all.npy`` if ``label_all`` else ``{mode}_labels_new.npy``:
        concatenated frame labels.

    Args:
        mode: Prefix for the saved file names (e.g. 'train' / 'val').
        data_loader: Iterable yielding ``(frame, frame_label, video_path)``, or
            ``(frame, frame_label, video_path, bbox)`` when ``object_box`` is set.
        model: Embedding network, called as ``model(frame)`` or ``model(frame, bbox)``.
        save_path: Root output directory; an ``eval`` subdirectory is created inside it.
        device: Torch device the inputs are moved to.
        label_all: Selects the label file name (see above).
        label_only: If True, skip the forward pass and save only the labels.
        imagenet: If True, keep only index 1 along axis 1 of ``frame`` before embedding
            (ImageNet-pretrained features use no context frames).
        object_box: If True, the batch carries bounding boxes that are passed to the model.
    """
    embeds_list = []
    labels_list = []
    for batch in tqdm(data_loader):
        if object_box:
            frame, frame_label, _video_path, bbox = batch
        else:
            frame, frame_label, _video_path = batch
        if not label_only:
            if imagenet:  # ImageNet-pretrained features: no context frames
                frame = frame[:, 1, ...]
            # Flatten the whole batch into one sequence and move channels forward.
            # Assumes channels-last frames, e.g. (1, 64, 168, 168, 3) -> (1, 64, 3, 168, 168).
            frame = frame.reshape(1, -1, *frame.shape[-3:])
            frame = frame.permute(0, 1, 4, 2, 3).float().to(device)
            with torch.no_grad():
                if object_box:
                    embeds = model(frame, bbox.unsqueeze(0).to(device))
                else:
                    embeds = model(frame)
            embeds = embeds.cpu().numpy()
            # Keep a 2-D (num_frames, dim) array. A bare ``squeeze()`` collapses a
            # single-frame batch (e.g. the last batch when len(dataset) % batch_size == 1)
            # to 1-D, which makes the np.concatenate below fail.
            embeds_list.append(embeds.reshape(-1, embeds.shape[-1]))
        labels_list.append(frame_label.numpy())

    eval_dir = os.path.join(save_path, 'eval')
    os.makedirs(eval_dir, exist_ok=True)
    label_name = 'labels_all.npy' if label_all else 'labels_new.npy'
    if not label_only:
        np.save(f'{eval_dir}/{mode}_embeds.npy', np.concatenate(embeds_list, axis=0))
    np.save(f'{eval_dir}/{mode}_{label_name}', np.concatenate(labels_list, axis=0))


def main():
    """Build loaders and model, optionally extract embeddings, then run the selected evaluations.

    Evaluations are selected by substring checks on ``args.eval_task``:
      '0' Kendall's tau (train and val), '1' classification, '2' few-shot classification
      (detailed), '3' frame retrieval, '4' progression, '5' cross-view (ego/exo)
      classification, '6' few-shot classification, '7' classification with modified
      embeddings, '8' progression with modified embeddings, '9' cross-view frame retrieval.
    If 'all' is in ``args.eval_task``, a '|'-separated summary row is printed.
    Everything printed during the extraction/evaluation phase is also written to
    ``<save_path>/log.txt``.
    """
    # Fall back to CPU instead of crashing immediately on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = argparser.parse_args()
    batch_size = 128

    object_bbox = 'bbox' in args.task
    imagenet = 'imagenet' in args.task
    loader_train, dataset_train = prepare_data_loader(args, 'train', batch_size=batch_size,
                                                      num_workers=args.num_workers, bbox=object_bbox)
    loader_val, dataset_val = prepare_data_loader(args, args.eval_mode, batch_size=batch_size,
                                                  num_workers=args.num_workers, bbox=object_bbox)

    if object_bbox:
        model = RoIPosEmbedder(args).to(device)
    elif imagenet:
        model = BaseRN50().to(device)
    else:
        model = Embedder(args).to(device)
    model.eval()

    ckpt_ext = '.ckpt'
    if args.ckpt:
        if args.ckpt.endswith(ckpt_ext):
            load_ckpt(model, args.ckpt)
            # Remove only the trailing extension. ``str.strip('.ckpt')`` strips any of the
            # characters '.', 'c', 'k', 'p', 't' from BOTH ends, e.g. 'logs/last.ckpt' ->
            # 'logs/las' and './ckpt/model.ckpt' -> '/ckpt/model' (an absolute path).
            save_path = args.ckpt[:-len(ckpt_ext)]
        else:
            save_path = args.ckpt
    else:
        save_path = f'./tmp/imagenet/{args.dataset}' if imagenet else f'./tmp/{args.dataset}'

    os.makedirs(save_path, exist_ok=True)
    log_file_path = f"{save_path}/log.txt"

    with redirect_stdout_to_file(log_file_path):
        if args.extract_embedding:
            extract_embedding('train', loader_train, model, save_path, device, args.label_all, args.label_only,
                              imagenet, object_bbox)
            # The evaluation split is always saved under the 'val' prefix, whatever args.eval_mode is.
            extract_embedding('val', loader_val, model, save_path, device, args.label_all, args.label_only,
                              imagenet, object_bbox)

        # Metrics default to 0 so the 'all' summary row can be printed even for tasks that were not run.
        fs_val_f1_1 = fs_val_f1_2 = val_f1 = 0
        ego2exo_val_f1 = exo2ego_val_f1 = 0
        map_5 = map_10 = map_15 = 0
        val_score = val_tau = 0

        if '0' in args.eval_task:
            kendalls_tau(save_path, dataset_train.video_len_list, dataset_train.video_paths1, 'train', False)
            val_tau = kendalls_tau(save_path, dataset_val.video_len_list, dataset_val.video_paths1, 'val', False)

        if '1' in args.eval_task:
            _, _, val_f1 = classification(save_path, label_all=args.label_all, cls=True, few_shot=False)

        if '2' in args.eval_task:
            fs_val_f1_1, fs_val_f1_2 = classification(save_path, label_all=args.label_all, cls=False, few_shot=True,
                                                      detailed=True)

        if '3' in args.eval_task:
            map_5, map_10, map_15 = frame_retrieval(save_path, dataset_val.video_len_list, dataset_val.video_paths1)

        if '4' in args.eval_task:
            _, val_score = progression(save_path, dataset_train.video_len_list, dataset_val.video_len_list)

        if '5' in args.eval_task:
            ego2exo_val_f1, exo2ego_val_f1 = classification_crossview(save_path, dataset_train.video_ego_id,
                                                                      dataset_val.video_ego_id)

        if '6' in args.eval_task:
            classification_fewshot(save_path, dataset_train.video_len_list)

        if '7' in args.eval_task:
            _, _, val_f1 = classification(save_path, label_all=args.label_all, cls=True, few_shot=False,
                                          modify_embeddings=True,
                                          train_video_len_list=dataset_train.video_len_list,
                                          val_video_len_list=dataset_val.video_len_list)

        if '8' in args.eval_task:
            _, val_score = progression(save_path, dataset_train.video_len_list, dataset_val.video_len_list,
                                       modify_embeddings=True)

        if '9' in args.eval_task:
            frame_retrieval(save_path, dataset_val.video_len_list, dataset_val.video_paths1, cross_view=True)

        if 'all' in args.eval_task:
            results = '|'.join([f"{fs_val_f1_1:.2f}", f"{fs_val_f1_2:.2f}",
                                f"{val_f1 * 100:.2f}", f"{ego2exo_val_f1 * 100:.2f}",
                                f"{exo2ego_val_f1 * 100:.2f}", f"{map_5 * 100:.2f}",
                                f"{map_10 * 100:.2f}", f"{map_15 * 100:.2f}",
                                f"{val_score:.4f}", f"{val_tau:.4f}", f"{args.ckpt}"])
            print(f'|{results}|')


if __name__ == "__main__":
    # Script entry point: run the extraction + evaluation pipeline.
    main()




