import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from torch.utils.data import DataLoader
import argparse
import os
import einops
import math

from dataset import TextDataset, AudioImageDataset, ImageNet_AudioSet_Dataset, SycDataset
from misc.logger import LOGGER, TB_LOGGER, AverageMeter, RunningMeter, add_log_to_file
from metric import Retrieval_metrics, MiniRetrieval_metrics, MRR
from util import AvgMeter, get_lr
from ProjModel import Model
from loss import get_CLIP_loss, get_item_L2_loss

def setup_seed(seed):
    """Seed every RNG used by training and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and the
    current CUDA device) with the same ``seed``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Prefer deterministic cuDNN algorithms so reruns with the same seed match.
    torch.backends.cudnn.deterministic = True

def build_dataset(text_dataset_list, text_support, TS_temperature):
    """Create the text training set and the Flickr / AVE evaluation sets.

    Returns ``(text_dataset, flickr_dataset, AVE_dataset)``, built in that order.
    ``text_support`` and ``TS_temperature`` are accepted for call-site
    compatibility but are not used by this function.
    """
    text_dataset = TextDataset(text_dataset_list)
    flickr_dataset, AVE_dataset = (AudioImageDataset(name) for name in ('Flickr', 'AVE'))
    return text_dataset, flickr_dataset, AVE_dataset


def validate(model, val_dataloader, save_path=None):
    """Project a validation set and score audio<->image retrieval.

    Each batch is indexed as ``batch[0]`` = audio input, ``batch[1]`` = image
    input; both are moved to CUDA and passed as ``model(image, audio)``.
    The audio->image similarity matrix is ``audio @ image^T`` (rows = audio,
    columns = image); its transpose is the image->audio matrix.

    Args:
        model: projection model; switched to eval mode here.
        val_dataloader: iterable of batches as described above.
        save_path: if truthy, the matrices are saved to ``<save_path>_a2i.pt``
            and ``<save_path>_i2a.pt``.

    Returns:
        ``(mean of the two directions' 'mrr', a2i_metrics, i2a_metrics)``.

    NOTE: gradients are not disabled here; the callers in this file wrap the
    call in ``torch.no_grad()``.
    """
    model.eval()

    projected = {'image': [], 'audio': []}
    for batch in val_dataloader:
        audio_in = batch[0].cuda()
        image_in = batch[1].cuda()
        image_out, audio_out = model(image_in, audio_in)
        projected['image'].append(image_out)
        projected['audio'].append(audio_out)

    all_image = torch.cat(projected['image'], dim=0)
    all_audio = torch.cat(projected['audio'], dim=0)
    audio_to_image = torch.einsum('nb,tb->nt', all_audio, all_image)
    image_to_audio = audio_to_image.T

    if save_path:
        for suffix, sim in (('_a2i.pt', audio_to_image), ('_i2a.pt', image_to_audio)):
            torch.save(sim, save_path + suffix)

    a2i_metrics = Retrieval_metrics(audio_to_image)
    i2a_metrics = Retrieval_metrics(image_to_audio)
    mean_mrr = (a2i_metrics['mrr'] + i2a_metrics['mrr']) / 2
    return mean_mrr, a2i_metrics, i2a_metrics

def get_uniform_ball_noise(input_shape, radius=0.1, device="cuda"):
    """Sample noise uniformly from a ball of the given radius in the last dim.

    Every vector along the last dimension of ``input_shape`` is drawn
    independently and uniformly from the ``d``-dimensional ball of radius
    ``radius`` (``d = input_shape[-1]``). Works for any rank >= 1; the original
    2-D-only ``.T`` formulation is generalised by broadcasting.

    Args:
        input_shape: shape of the noise tensor (tuple or ``torch.Size``).
        radius: ball radius.
        device: device to allocate on. Defaults to ``"cuda"``, as before; pass
            the target tensor's device to avoid a CUDA dependency.

    Returns:
        Tensor of shape ``input_shape`` on ``device``.
    """
    input_shape = tuple(input_shape)
    dim = input_shape[-1]
    # A normalised isotropic Gaussian sample is uniform on the unit sphere.
    direction = torch.nn.functional.normalize(torch.randn(input_shape, device=device), dim=-1)
    # Radius ~ U(0,1)^(1/d) makes the points uniform in volume, not just in angle.
    magnitude = torch.rand(input_shape[:-1] + (1,), device=device) ** (1.0 / dim)
    return direction * magnitude * radius

def noise_injection(x, variance=0.001, modality_offset=None, uniform_noise=False):
    """Perturb embeddings with noise, optionally shift them, then L2-normalise.

    Args:
        x: embedding tensor; noise and normalisation act on the last dimension.
        variance: noise variance (std = sqrt(variance)); 0 disables the noise.
        modality_offset: optional value added after the noise; anything that
            broadcasts against ``x`` (e.g. a float or a last-dim vector).
        uniform_noise: if True, add noise drawn uniformly from a ball of radius
            std (via ``get_uniform_ball_noise``); otherwise add isotropic
            Gaussian noise with that std.

    Returns:
        ``x`` itself when there is neither noise nor offset to apply; otherwise
        the perturbed tensor normalised along the last dimension.
    """
    if variance == 0.0 and modality_offset is None:
        # Nothing to apply: keep the identity fast path (no renormalisation).
        return x
    if variance != 0.0:
        std = math.sqrt(variance)
        if uniform_noise:
            # Move the helper's output onto x's device/dtype, wherever it was made.
            noise = get_uniform_ball_noise(x.shape, radius=std).to(device=x.device, dtype=x.dtype)
        else:
            # randn_like samples directly on x's device with x's dtype (no host
            # allocation + copy per call).
            # TODO: by some conventions multivariate noise should be divided by sqrt(dim).
            noise = torch.randn_like(x) * std
        x = x + noise
    # The offset is applied even when the noise is disabled (variance == 0);
    # previously the early return silently dropped it.
    if modality_offset is not None:
        x = x + modality_offset
    return torch.nn.functional.normalize(x, dim=-1)

# Define the training function
def train_one_epoch(model, train_loader, optimizer, lr_scheduler, epoch, cfg):
    """Run one training epoch and return the ``AvgMeter`` of per-batch losses.

    Each batch is ``(CLIP_embs, CLAP_embs, image_embs, audio_embs)``. The text
    pair (CLIP, CLAP) and the image/audio pair are moved to CUDA, optionally
    noise-injected (``cfg.Text_noise`` / ``cfg.AV_noise``) and projected with
    ``model``. The loss is the mean of the enabled contrastive terms
    (``cfg.TT_CLIP_loss``, ``cfg.AV_CLIP_loss``) plus, when
    ``cfg.modality_align`` is set, ``cfg.align_loss_factor`` times the mean of
    the CLIP<->image and CLAP<->audio item-wise L2 alignment losses.

    The LR scheduler (if any) is stepped once at the end of the epoch.

    Raises:
        ValueError: if ``cfg.align_loss`` is not supported, or if no loss term
            is enabled (there would be nothing to back-propagate).
    """
    loss_meter = AvgMeter()
    # Log about ten times per epoch; max(1, ...) keeps `i % show_step` valid
    # for loaders with fewer than ten batches.
    show_step = max(1, len(train_loader) // 10)

    for i, (CLIP_embs, CLAP_embs, image_embs, audio_embs) in enumerate(train_loader):

        CLIP_embs = CLIP_embs.cuda()
        CLAP_embs = CLAP_embs.cuda()
        if cfg.Text_noise:
            CLIP_embs = noise_injection(CLIP_embs, variance=cfg.variance, modality_offset=cfg.modality_offset)
            CLAP_embs = noise_injection(CLAP_embs, variance=cfg.variance, modality_offset=cfg.modality_offset)
        CLIP_embs, CLAP_embs = model(CLIP_embs, CLAP_embs)

        image_embs = image_embs.cuda()
        audio_embs = audio_embs.cuda()
        if cfg.AV_noise:
            image_embs = noise_injection(image_embs, variance=cfg.variance, modality_offset=cfg.modality_offset)
            audio_embs = noise_injection(audio_embs, variance=cfg.variance, modality_offset=cfg.modality_offset)
        image_embs, audio_embs = model(image_embs, audio_embs)

        loss_dict = {}
        align_loss = torch.tensor(0)

        # Contrastive terms are averaged over however many are enabled.
        contra_terms = []
        if cfg.TT_CLIP_loss:
            TT_CLIP_loss = get_CLIP_loss(CLIP_embs, CLAP_embs, temperature=cfg.temperature)
            loss_dict['TT_contra_loss'] = TT_CLIP_loss.item()
            contra_terms.append(TT_CLIP_loss)
        if cfg.AV_CLIP_loss:
            AV_CLIP_loss = get_CLIP_loss(image_embs, audio_embs, temperature=cfg.temperature)
            loss_dict['AV_contra_loss'] = AV_CLIP_loss.item()
            contra_terms.append(AV_CLIP_loss)
        if contra_terms:
            loss = sum(contra_terms) / len(contra_terms)
            # Log a plain float, not a graph-attached tensor.
            loss_dict['contra_loss'] = loss.item()
        else:
            loss = 0
            loss_dict['contra_loss'] = 0

        if cfg.modality_align:
            if cfg.align_loss != 'item_L2':
                # Fail loudly instead of hitting a NameError on the logging below.
                raise ValueError(f'Unsupported align_loss: {cfg.align_loss!r}')
            CLIP_align_loss = get_item_L2_loss(CLIP_embs, image_embs)
            CLAP_align_loss = get_item_L2_loss(CLAP_embs, audio_embs)
            align_loss = (CLIP_align_loss + CLAP_align_loss) / 2
            loss = cfg.align_loss_factor * align_loss + loss

            loss_dict['CLIP_align_loss'] = CLIP_align_loss.item()
            loss_dict['CLAP_align_loss'] = CLAP_align_loss.item()
            loss_dict['align_loss'] = align_loss.item()

        if not torch.is_tensor(loss):
            raise ValueError('No loss term is enabled: set TT_CLIP_loss, AV_CLIP_loss or modality_align.')

        loss_dict['train_loss'] = loss.item()
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        TB_LOGGER.log_scalar_dict({f'train_loss/{k}': v for k, v in loss_dict.items()})
        TB_LOGGER.step()

        if i % show_step == 0:
            LOGGER.info(f'epoch:{epoch+1}/{cfg.epoch}, batch:{i}/{len(train_loader)}, Loss:{loss.item():.4f}, align_loss:{align_loss.item():.4f}')

        # Detach so the meter does not keep every iteration's autograd history alive.
        loss_meter.update(loss.detach())

    if lr_scheduler:
        lr_scheduler.step()
        cur_lr = lr_scheduler.get_last_lr()
        print(cur_lr)
        LOGGER.info(f'epoch:{epoch+1}/{cfg.epoch}, learning rate:{cur_lr}')
        TB_LOGGER.log_scalar_dict({f'train_loss/learning_rate': cur_lr[-1]})
    else:
        TB_LOGGER.log_scalar_dict({f'train_loss/learning_rate': cfg.lr})
    return loss_meter

def load_evaluate(model, model_path, AVE_val_loader, flickr_val_loader):
    """Load ``<model_path>/best.pt`` into ``model`` and evaluate it on AVE and Flickr.

    The similarity matrices are saved under ``<model_path>/AVE_*.pt`` and
    ``<model_path>/Flickr_*.pt`` (see ``validate``). Per-direction metrics are
    written to ``LOGGER``. Always returns 0.

    NOTE(review): ``torch.load`` unpickles the checkpoint; only load trusted files.
    """
    checkpoint_file = os.path.join(model_path, 'best.pt')
    model.load_state_dict(torch.load(checkpoint_file))
    model.eval()
    with torch.no_grad():
        mrr_AVE, a2i_AVE, i2a_AVE = validate(model, AVE_val_loader, save_path=os.path.join(model_path, 'AVE'))
        mrr_flickr, a2i_flickr, i2a_flickr = validate(model, flickr_val_loader, save_path=os.path.join(model_path, 'Flickr'))

        # NOTE(review): this summary is built but neither logged nor returned.
        val_dict = {
            'ave_AVE_mrr': mrr_AVE,
            'ave_flickr_mrr': mrr_flickr,
            'cAVE_a2i_mrr': a2i_AVE['mrr'],
            'cAVE_i2a_mrr': i2a_AVE['mrr'],
            'flickr_a2i_mrr': a2i_flickr['mrr'],
            'flickr_i2a_mrr': i2a_flickr['mrr'],
        }

        report = (
            ('AVE audio2image', a2i_AVE),
            ('AVE image2audio', i2a_AVE),
            ('Flickr audio2image', a2i_flickr),
            ('Flickr image2audio', i2a_flickr),
        )
        for label, metrics in report:
            LOGGER.info(f'{label}: {metrics}')
    return 0

def main(cfg):
    """Build data and model, then either evaluate ``best.pt`` or train.

    With ``cfg.evaluate`` set, the checkpoint in ``cfg.save_path`` is loaded and
    scored via ``load_evaluate``. Otherwise the model is trained for
    ``cfg.epoch`` epochs with AdamW + cosine annealing. After every epoch it is
    validated on AVE and Flickr, and whenever the sum of the two averaged MRRs
    improves, ``best.pt`` is written to ``cfg.save_path``.
    """
    setup_seed(cfg.seed)
    train_dataset, flickr_val_dataset, AVE_val_dataset = build_dataset(cfg.train_datasets, cfg.Text_Support, cfg.TS_temperature)
    if cfg.modality_align and cfg.AVS_dataset:
        print("corresponding items aligning!")
        imagenet_audioset_dataset = TextDataset(cfg.AVS_dataset)
    else:
        imagenet_audioset_dataset = ImageNet_AudioSet_Dataset(len(train_dataset))
    print('imagenet_audioset_embedding_num', len(imagenet_audioset_dataset))

    syc_dataset = SycDataset(train_dataset, imagenet_audioset_dataset)
    syc_loader = DataLoader(syc_dataset, batch_size=cfg.batch_size, shuffle=True)
    flickr_val_loader = DataLoader(flickr_val_dataset, batch_size=cfg.batch_size, shuffle=False)
    AVE_val_loader = DataLoader(AVE_val_dataset, batch_size=cfg.batch_size, shuffle=False)

    log_dir = os.path.join(cfg.save_path, 'logs')
    # Ensure save_path/logs exists before the text log is opened and
    # checkpoints are written into save_path.
    os.makedirs(log_dir, exist_ok=True)
    TB_LOGGER.create(log_dir)
    add_log_to_file(os.path.join(log_dir, 'log.txt'))

    # Create the model
    model = Model(cfg)
    model.cuda()

    print(model)
    print(cfg)

    if not cfg.Text_Support:
        LOGGER.info(f'Train texts:{len(train_dataset)}, AVE:{len(AVE_val_dataset)}, Flickr:{len(flickr_val_dataset)}')
    else:
        LOGGER.info(f'Train texts:{len(train_dataset)}, TS{cfg.TS_temperature}-Flickr:{len(flickr_val_dataset)}')

    if cfg.evaluate:
        print('evaluating model!!!')
        load_evaluate(model, model_path=cfg.save_path, AVE_val_loader=AVE_val_loader,
                      flickr_val_loader=flickr_val_loader)
        return

    # Parameters whose name contains "domain" use lr scaled by discr_lr_factor.
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if "domain" in n and p.requires_grad],
         "lr": cfg.discr_lr_factor * cfg.lr},
        {"params": [p for n, p in model.named_parameters() if "domain" not in n and p.requires_grad],
         "lr": cfg.lr},
    ]

    optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.epoch)

    best_val_dict = {}
    best_AVE_a2i_metrics = {}
    best_AVE_i2a_metrics = {}
    best_flickr_a2i_metrics = {}
    best_flickr_i2a_metrics = {}
    best_score = 0  # best (AVE averaged MRR + Flickr averaged MRR) so far
    best_epoch = 0

    for epoch in range(cfg.epoch):
        LOGGER.info(f'Epoch: {epoch+1}/{cfg.epoch}')
        model.train()
        train_one_epoch(model, syc_loader, optimizer, lr_scheduler, epoch, cfg)
        with torch.no_grad():
            AVE_ave_mrr, AVE_a2i_metrics, AVE_i2a_metrics = validate(model, AVE_val_loader)
            flickr_ave_mrr, flickr_a2i_metrics, flickr_i2a_metrics = validate(model, flickr_val_loader)

        val_dict = {
            'ave_AVE_mrr': AVE_ave_mrr,
            'ave_flickr_mrr': flickr_ave_mrr,
            'cAVE_a2i_mrr': AVE_a2i_metrics['mrr'],
            'cAVE_i2a_mrr': AVE_i2a_metrics['mrr'],
            'flickr_a2i_mrr': flickr_a2i_metrics['mrr'],
            'flickr_i2a_mrr': flickr_i2a_metrics['mrr'],
        }
        TB_LOGGER.log_scalar_dict({f'valid/{k}': v for k, v in val_dict.items()})

        # Update the running best BEFORE logging it, so the "Best ..." lines
        # include this epoch (previously they lagged by one epoch and the final
        # best was never reported).
        score = flickr_ave_mrr + AVE_ave_mrr
        if score > best_score:
            best_score = score
            best_epoch = epoch + 1
            best_val_dict = val_dict
            best_AVE_a2i_metrics = AVE_a2i_metrics
            best_AVE_i2a_metrics = AVE_i2a_metrics
            best_flickr_a2i_metrics = flickr_a2i_metrics
            best_flickr_i2a_metrics = flickr_i2a_metrics
            torch.save(model.state_dict(), os.path.join(cfg.save_path, 'best.pt'))
            print("Saved Best Model!")

        LOGGER.info(cfg)
        LOGGER.info(f'Best Epoch: {best_epoch}')
        LOGGER.info(f'Best metrics: {best_val_dict}')
        LOGGER.info(f'Best AVE audio2image: {best_AVE_a2i_metrics}')
        LOGGER.info(f'Best AVE image2audio: {best_AVE_i2a_metrics}')
        LOGGER.info(f'Best Flickr audio2image: {best_flickr_a2i_metrics}')
        LOGGER.info(f'Best Flickr image2audio: {best_flickr_i2a_metrics}')

        LOGGER.info(f'epoch:{epoch + 1}/{cfg.epoch}, ave_AVE_mrr: {val_dict["ave_AVE_mrr"]}, ave_flickr_mrr: {val_dict["ave_flickr_mrr"]}')
        LOGGER.info(f'AVE audio2image: {AVE_a2i_metrics}')
        LOGGER.info(f'AVE image2audio: {AVE_i2a_metrics}')
        LOGGER.info(f'Flickr audio2image: {flickr_a2i_metrics}')
        LOGGER.info(f'Flickr image2audio: {flickr_i2a_metrics}')

def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``no`` to bool.

    argparse's ``type=bool`` would treat any non-empty string (even "False")
    as True, so boolean options need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def load_parser():
    """Build the command-line parser for training / evaluation.

    Every option declares a ``type``, so values given on the command line are
    converted. Without it, argparse passes raw strings through: e.g.
    ``--batch_size 256`` would reach the DataLoader as ``'256'`` and
    ``--evaluate False`` would be truthy. List options take one or more
    values. Defaults for the pre-existing options are unchanged.

    ``Text_Support``, ``TS_temperature``, ``discr_lr_factor`` and
    ``modality_offset`` are read by ``main`` / ``train_one_epoch`` and so must
    exist on the parsed namespace. They are declared with neutral defaults:
    no text support, no domain-specific LR scaling, and no offset.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=44)
    parser.add_argument("--train_datasets", nargs='+',
                        default=['COCO', 'CC1M', 'AudioCap', 'Clotho', 'MSRVTT', 'MAD'])
    parser.add_argument("--batch_size", type=int, default=10240,
                        help="batch size for the training and validation DataLoaders")
    parser.add_argument("--epoch", type=int, default=36,
                        help="number of training epochs (also the cosine-annealing period)")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--resMLP", type=_str2bool, default=False)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--temperature", type=float, default=0.01,
                        help="temperature of the contrastive (CLIP) losses")

    parser.add_argument("--res_expansion", type=int, default=2)
    parser.add_argument("--mlp_num", type=int, default=1)

    parser.add_argument("--TT_CLIP_loss", type=_str2bool, default=True)
    parser.add_argument("--AV_CLIP_loss", type=_str2bool, default=True)

    parser.add_argument("--modality_align", type=_str2bool, default=True)
    parser.add_argument("--AVS_dataset", nargs='+', default=['AVS100'])
    parser.add_argument("--align_loss", default='item_L2')
    parser.add_argument("--align_loss_factor", type=float, default=0.1)
    parser.add_argument("--Text_noise", type=_str2bool, default=True)
    parser.add_argument("--AV_noise", type=_str2bool, default=True)
    parser.add_argument("--variance", type=float, default=0.004)
    parser.add_argument("--modality_offset", type=float, default=None,
                        help="optional scalar added to embeddings during noise injection")

    parser.add_argument("--discr_lr_factor", type=float, default=1.0,
                        help="LR multiplier for parameters whose name contains 'domain'")
    parser.add_argument("--Text_Support", type=_str2bool, default=False)
    parser.add_argument("--TS_temperature", type=float, default=None)

    parser.add_argument("--evaluate", type=_str2bool, default=False)
    parser.add_argument('--save_path', default=os.path.join('./output', 'C-MCR'))

    return parser

if __name__ == '__main__':
    # Script entry point: parse CLI options and hand the namespace to main().
    main(load_parser().parse_args())