from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn

import clip
from collections import defaultdict
from datasets.imagenet import imagenet_classes


def image_opt(feat, init_classifier, plabel, label, lr=10, iter=2000, tau_i=0.04, alpha=0.6):
    """Refine a linear classifier on features using (partially hardened) pseudo labels.

    Rows of ``plabel`` whose maximum exceeds ``alpha`` are replaced by a
    one-hot vector on their arg-max; the remaining rows are kept soft. The
    classifier is then updated for ``iter`` steps with the gradient
    ``feat.T @ softmax(feat @ classifier / tau_i) - feat.T @ plabel`` and
    re-normalised column-wise after every step.

    Args:
        feat: 2-D tensor of shape (ins, dim).
        init_classifier: Starting classifier; must be multipliable as
            ``feat @ init_classifier`` (i.e. (dim, classes)). Not modified.
        plabel: Pseudo-label tensor with one row per row of ``feat``.
            Not modified (a private copy is hardened instead).
        label: Unused by this function; kept for interface compatibility.
        lr: Initial step size; halved whenever the gradient norm increases
            relative to the previous step.
        iter: Number of update steps.
        tau_i: Softmax temperature applied to the logits.
        alpha: Confidence threshold above which a pseudo label is hardened.

    Returns:
        The optimised classifier, with each column L2-normalised
        (provided at least one step was run).
    """
    ins = feat.shape[0]

    # Harden confident pseudo labels on a copy so the caller's tensor
    # is not silently overwritten.
    plabel = plabel.clone()
    val, idx = torch.max(plabel, dim=1)
    mask = val > alpha
    plabel[mask, :] = 0
    plabel[mask, idx[mask]] = 1

    # Constant part of the gradient: feature mass assigned to each class.
    base = feat.T @ plabel
    classifier = init_classifier.clone()
    pre_norm = float('inf')
    for _ in range(iter):
        prob = F.softmax(feat @ classifier / tau_i, dim=1)
        grad = feat.T @ prob - base
        temp = torch.norm(grad)
        # A growing gradient norm is treated as overshooting: shrink the step.
        if temp > pre_norm:
            lr /= 2.
        pre_norm = temp
        classifier -= (lr / (ins * tau_i)) * grad
        classifier = F.normalize(classifier, dim=0)
    return classifier


def sinkhorn(M, tau_t=0.01, gamma=0, iter=20):
    """Turn a score matrix into row-normalised pseudo labels via Sinkhorn-style scaling.

    Starts from ``softmax(M / tau_t)`` over dim 1 and alternates column and
    row normalisation for ``iter`` rounds.

    Args:
        M: 2-D score matrix (rows x columns).
        tau_t: Temperature of the initial softmax.
        gamma: If > 0, the column targets are the initial column masses
            raised to ``gamma`` and renormalised to sum to one; otherwise
            every column gets the uniform target ``1 / columns``.
        iter: Number of column/row normalisation rounds.

    Returns:
        Tensor of the same shape as ``M`` whose rows each sum to 1.
    """
    n_rows, n_cols = M.shape
    use_prior = gamma > 0

    plan = F.softmax(M / tau_t, dim=1) / n_rows
    if use_prior:
        # Column targets derived from the initial column masses.
        col_target = plan.sum(dim=0, keepdim=True) ** gamma
        col_target = col_target / col_target.sum()

    for _ in range(iter):
        # Column step: each column carries its target mass (1/n_cols or q_j).
        plan = plan / plan.sum(dim=0, keepdim=True)
        if use_prior:
            plan = plan * col_target
        else:
            plan = plan / n_cols
        # Row step: each row carries mass 1/n_rows.
        plan = plan / plan.sum(dim=1, keepdim=True)
        plan = plan / n_rows

    # Rescale so every row sums to 1 and can be used as a pseudo label.
    return plan * n_rows


def softmax_entropy(x):
    """Return the entropy of ``softmax(x)`` taken over dim 1 (dim 1 is summed out)."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    weighted = probs * log_probs
    return -weighted.sum(1)

def cls_acc(output, target, topk=1):
    """Compute top-k classification accuracy as a percentage.

    Args:
        output: 2-D score tensor with one row per sample and one column
            per class.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: A sample counts as correct if its target is among the
            ``topk`` highest-scoring classes.

    Returns:
        Accuracy in percent as a Python float.
    """
    # (topk, batch): row r holds every sample's r-th best class index.
    pred = output.topk(topk, 1, True, True)[1].t()

    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a Python scalar directly; converting a 1-element NumPy
    # array with float() is deprecated and needs a NumPy round-trip.
    num_correct = correct[:topk].reshape(-1).float().sum().item()
    return 100 * num_correct / target.shape[0]


def clip_classifier(classnames, template, clip_model):
    """Build text-side classifier weights from prompt templates.

    For every class name (underscores replaced by spaces) each entry of
    ``template`` is formatted with the name, tokenised, encoded with
    ``clip_model.encode_text`` and L2-normalised along the last dim.

    Args:
        classnames: Iterable of class-name strings.
        template: Iterable of format strings taking the class name.
        clip_model: Object exposing ``encode_text``.

    Returns:
        Tuple ``(clip_weights, clip_weights_tol)``:
        ``clip_weights`` stacks, along dim 1, the re-normalised mean prompt
        embedding of each class; ``clip_weights_tol`` stacks, along dim 1,
        the per-prompt normalised embeddings of each class. Both are moved
        to CUDA.
        NOTE(review): presumably (dim, classes) and (prompts, classes, dim)
        — depends on the shape ``encode_text`` returns; confirm.
    """
    ensemble_embeds = []
    per_prompt_embeds = []

    with torch.no_grad():
        for raw_name in tqdm(classnames):
            name = raw_name.replace('_', ' ')
            prompts = [t.format(name) for t in template]
            tokens = clip.tokenize(prompts).cuda()

            # One unit-norm embedding per prompt.
            embeds = clip_model.encode_text(tokens)
            embeds = embeds / embeds.norm(dim=-1, keepdim=True)
            per_prompt_embeds.append(embeds)

            # Prompt ensemble: average over prompts, then re-normalise.
            pooled = embeds.mean(dim=0)
            pooled = pooled / pooled.norm()
            ensemble_embeds.append(pooled)

        clip_weights = torch.stack(ensemble_embeds, dim=1).cuda()
        clip_weights_tol = torch.stack(per_prompt_embeds, dim=1).cuda()

    return clip_weights, clip_weights_tol

def pre_load_features(cfg, split, clip_model, loader, norm=True):
    """Encode every batch of ``loader`` and return all features and labels.

    Args:
        cfg: Unused here; kept for interface compatibility.
        split: Unused here; kept for interface compatibility.
        clip_model: Object exposing ``encode_image``.
        loader: Iterable yielding ``(images, target)`` batches.
        norm: If true, L2-normalise each feature along the last dim.

    Returns:
        Tuple ``(features, labels)``: per-batch results concatenated along
        dim 0. Both were moved to CUDA batch by batch.
    """
    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for images, target in tqdm(loader):
            images = images.cuda()
            target = target.cuda()
            encoded = clip_model.encode_image(images)
            if norm:
                encoded = encoded / encoded.norm(dim=-1, keepdim=True)
            feature_chunks.append(encoded)
            label_chunks.append(target)

    return torch.cat(feature_chunks), torch.cat(label_chunks)

def build_cache_model(cfg, clip_model, train_loader_cache):
    """Build cache keys (image features) and values (one-hot labels).

    The loader is traversed ``cfg['augment_epoch']`` times; the image
    features of all passes are averaged element-wise, L2-normalised along
    the last dim and transposed. Labels are collected during the first
    pass only.
    NOTE(review): averaging across passes assumes the loader yields samples
    in the same order every pass — confirm the loader is not shuffled.

    Args:
        cfg: Mapping providing ``'augment_epoch'``.
        clip_model: Object exposing ``encode_image``.
        train_loader_cache: Iterable yielding ``(images, target)`` batches.

    Returns:
        Tuple ``(cache_keys, cache_values)``: the transposed normalised
        mean features and the half-precision one-hot labels.
    """
    num_epochs = cfg['augment_epoch']
    per_epoch_keys = []
    label_chunks = []

    with torch.no_grad():
        # Each pass over the loader is one augmentation view of the cache set.
        for epoch in range(num_epochs):
            print('Augment Epoch: {:} / {:}'.format(epoch, num_epochs))
            first_pass = epoch == 0
            epoch_features = []

            for images, target in tqdm(train_loader_cache):
                epoch_features.append(clip_model.encode_image(images.cuda()))
                if first_pass:
                    label_chunks.append(target.cuda())

            per_epoch_keys.append(torch.cat(epoch_features, dim=0))

    # Average over passes, normalise, then put the feature dim first.
    cache_keys = torch.stack(per_epoch_keys, dim=0).mean(dim=0)
    cache_keys = cache_keys / cache_keys.norm(dim=-1, keepdim=True)
    cache_keys = cache_keys.permute(1, 0)

    all_labels = torch.cat(label_chunks, dim=0)
    cache_values = F.one_hot(all_labels).half()

    return cache_keys, cache_values

def gpt_llm_fea(classnames, gpt_prompts, clip_model, template):
    """Encode per-class text descriptions plus template prompts.

    For every class the first ``num_des`` descriptions from ``gpt_prompts``
    (``num_des`` being the smallest description count over all classes) are
    followed by every ``template`` entry formatted with the class name. The
    combined list is tokenised (with truncation), encoded with
    ``clip_model.encode_text`` and L2-normalised along the last dim.

    Args:
        classnames: Iterable of class-name strings; underscores are
            replaced by spaces before looking them up in ``gpt_prompts``.
        gpt_prompts: Mapping from (space-separated) class name to a
            sequence of description strings.
        clip_model: Object exposing ``encode_text``.
        template: Iterable of format strings taking the class name.

    Returns:
        Mapping from space-separated class name to its normalised text
        embeddings (descriptions first, then template prompts).
    """
    with torch.no_grad():
        prompt_list = [len(gpt_prompts[name.replace('_', ' ')]) for name in classnames]
        # Every class is truncated to the shortest description list so all
        # classes contribute the same number of descriptions.
        num_des = min(prompt_list)
        if max(prompt_list) == num_des:
            print('text descriptions for each class is : {}'.format(num_des))
        else:
            print('text descriptions are not equal, and minimal is :{}'.format(num_des))

        # No default factory is set, so this behaves like a plain dict; the
        # type is kept so callers see the same return type as before.
        prompt_fea = defaultdict()

        for classname in tqdm(classnames):
            classname = classname.replace('_', ' ')

            # Descriptions first, then the hand-written template prompts.
            text = list(gpt_prompts[classname][:num_des])
            text.extend(t.format(classname) for t in template)

            # truncate=True: descriptions may exceed the tokenizer's context length.
            texts = clip.tokenize(text, truncate=True).cuda()
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)

            prompt_fea[classname] = class_embeddings
        return prompt_fea


def search_hp(cfg, cache_keys, cache_values, features, labels, clip_weights, adapter=None):
    """Grid-search the cache sharpness ``beta`` and cache weight ``alpha``.

    For each pair the logits
    ``100 * features @ clip_weights + alpha * exp(-(beta - beta * affinity)) @ cache_values``
    are scored with ``cls_acc`` and the best-scoring pair is kept.
    ``affinity`` is ``adapter(features)`` when an adapter is given, else
    ``features @ cache_keys``.

    Args:
        cfg: Mapping providing ``'search_hp'`` (search is run only if it
            equals True), ``'search_scale'`` (upper bounds for beta, alpha)
            and ``'search_step'`` (number of grid points for beta, alpha).
        cache_keys: Cache key matrix used when no adapter is given.
        cache_values: Cache value matrix.
        features: Image features to evaluate on.
        labels: Ground-truth labels for ``features``.
        clip_weights: Text classifier weights.
        adapter: Optional callable mapping features to affinities.

    Returns:
        Tuple ``(best_beta, best_alpha)``. ``(0, 0)`` is returned when the
        search is disabled or no setting reaches an accuracy above 0.
    """
    # Initialised up front so a disabled search returns a defined value
    # instead of failing on unbound locals.
    best_acc = 0
    best_beta, best_alpha = 0, 0

    if cfg['search_hp'] == True:  # noqa: E712 - keep the original equality semantics
        beta_scale, alpha_scale = cfg['search_scale'][0], cfg['search_scale'][1]
        beta_steps, alpha_steps = cfg['search_step'][0], cfg['search_step'][1]

        # Evenly spaced grids starting at 0.1; the upper bound is excluded.
        beta_list = [i * (beta_scale - 0.1) / beta_steps + 0.1 for i in range(beta_steps)]
        alpha_list = [i * (alpha_scale - 0.1) / alpha_steps + 0.1 for i in range(alpha_steps)]

        # Neither term depends on beta or alpha, so compute them once.
        if adapter is not None:
            affinity = adapter(features)
        else:
            affinity = features @ cache_keys
        clip_logits = 100. * features @ clip_weights

        for beta in beta_list:
            # Depends on beta only; shared by every alpha of the inner loop.
            cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
            for alpha in alpha_list:
                tip_logits = clip_logits + cache_logits * alpha
                acc = cls_acc(tip_logits, labels)

                if acc > best_acc:
                    print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc))
                    best_acc = acc
                    best_beta = beta
                    best_alpha = alpha

        print("\nAfter searching, the best accuarcy: {:.2f}.\n".format(best_acc))

    return best_beta, best_alpha

def logits_selection(logits, top_ratio=0.1):
    """Average the logits of the most confident (lowest-entropy) rows.

    Args:
        logits: 2-D logits tensor, one row per view/sample.
        top_ratio: Fraction of rows to keep, ranked by ascending softmax
            entropy. At least one row is always kept, so batches with fewer
            than ``1 / top_ratio`` rows no longer give an empty selection
            (whose mean is NaN).

    Returns:
        Tuple ``(logits, selected_idx)``: the mean of the selected rows with
        a leading dim of size 1, and the indices of the selected rows.
    """
    batch_entropy = softmax_entropy(logits)
    num_keep = max(1, int(batch_entropy.size()[0] * top_ratio))
    # Ascending sort: the lowest-entropy (most confident) rows come first.
    selected_idx = torch.argsort(batch_entropy, descending=False)[:num_keep]
    output = logits[selected_idx]
    logits = output.mean(0).unsqueeze(0)
    return logits, selected_idx
