import os
import random
import argparse
import yaml
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms

from datasets.imagenet import ImageNet
from datasets import build_dataset
from datasets.utils import build_data_loader
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from utils import *
from torch.autograd import Variable
import numpy as np
import json

from sklearn.covariance import LedoitWolf, OAS, GraphicalLassoCV, GraphicalLasso
from tqdm import tqdm


# Module-level tokenizer instance and the training-time augmentation pipeline:
# random resized crop (224px, scale 0.5-1, bicubic) and a 50% horizontal flip,
# followed by tensor conversion and per-channel normalisation with the fixed
# mean/std constants used by OpenAI CLIP's preprocessing.
# NOTE(review): neither name appears to be referenced by the code in this file.
_tokenizer = _Tokenizer()
train_tranform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            size=224,
            scale=(0.5, 1),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


def get_arguments():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with one attribute, ``config``: the path of the
        YAML settings file (``configs/ood/imagenet.yaml`` when omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        default='configs/ood/imagenet.yaml',
        help='settings of Tip-Adapter in yaml format',
    )
    return parser.parse_args()


def pp_estimate_eigen(logits, labels):
    """Estimate a class-prior vector from a soft, label-conditioned confusion matrix.

    Row ``k`` of the matrix is the mean softmax distribution of the samples
    whose label is ``k`` (an all-zero row when no sample has label ``k``).
    The matrix is transposed and handed to ``power_iteration``, whose result
    is returned as the prior estimate.

    Args:
        logits: 2-D tensor of class scores, shape (N, K).
        labels: length-N integer tensor of labels in ``[0, K)``; in ``run``
            these are the argmax pseudo-labels of ``logits``.

    Returns:
        ``(pp, logits)`` where ``pp`` is a length-K CPU tensor in the default
        float dtype and ``logits`` is the input, returned unchanged.
    """
    probs = F.softmax(logits, dim=-1)
    # Accumulate in at least float32: per-class sums can reach N, which would
    # lose precision or overflow in half precision.
    probs = probs.to(torch.promote_types(probs.dtype, torch.float32))
    _, K = probs.size()

    # One matmul against a one-hot membership matrix replaces a Python loop
    # that rescanned all N samples once per class.
    onehot = F.one_hot(labels.to(probs.device).long(), num_classes=K).to(probs.dtype)  # (N, K)
    class_sums = onehot.T @ probs                  # (K, K); row k sums probs of label-k samples
    counts = onehot.sum(dim=0)                     # (K,) samples per label
    # clamp(min=1) leaves rows of unseen labels at exactly zero, as before.
    P = class_sums / counts.clamp(min=1).unsqueeze(1)

    # Keep the historical CPU / default-dtype result: callers combine pp with
    # CPU tensors (e.g. `pp - pp_orig` in run).
    P = P.t().to(device='cpu', dtype=torch.get_default_dtype())
    pp = power_iteration(P)
    return pp, logits

def compute_confidence(logits, T):
    """Return the average top-1 softmax probability of ``logits / T``.

    Args:
        logits: tensor of class scores; softmax is taken over the last dim.
        T: temperature the logits are divided by before the softmax.

    Returns:
        Python float: the mean over samples of each sample's largest
        class probability.
    """
    top_prob = F.softmax(logits / T, dim=-1).max(dim=-1).values
    return top_prob.mean().item()

def search_T(logits, tao, zero_conf, rtol=0.1, factor=1.1, max_iter=1000):
    """Search for a temperature whose mean confidence on ``logits`` matches ``zero_conf``.

    Starting from ``tao``, the current temperature's confidence is measured
    with ``compute_confidence``. While it is at or below the target, the
    temperature is divided by ``factor``; while it is above, it is multiplied
    by ``factor``. The search ends as soon as the relative gap
    ``|zero_conf - conf| / zero_conf`` is below ``rtol``.

    The search is an explicit bounded loop. An unreachable target (e.g.
    constant logits) or steps that keep jumping over the tolerance band
    therefore end with a warning instead of a RecursionError.

    Args:
        logits: class scores to calibrate.
        tao: initial temperature.
        zero_conf: target mean confidence (must be non-zero).
        rtol: relative tolerance on the confidence gap.
        factor: multiplicative step applied to the temperature.
        max_iter: maximum number of temperatures to evaluate.

    Returns:
        The first temperature within ``rtol`` of the target. If none is found
        (or the confidence becomes NaN), a warning is issued and the
        evaluated temperature with the smallest gap is returned.

    Raises:
        ZeroDivisionError: if ``zero_conf`` is 0.
    """
    import warnings

    best_tao, best_err = tao, float('inf')
    for _ in range(max_iter):
        conf = compute_confidence(logits, tao)
        err = abs(zero_conf - conf) / zero_conf
        if err < rtol:
            return tao
        if err < best_err:
            best_tao, best_err = tao, err

        if zero_conf >= conf:
            tao /= factor   # confidence too low -> sharpen
        elif zero_conf < conf:
            tao *= factor   # confidence too high -> soften
        else:
            # conf is NaN: neither comparison holds, so further steps are futile.
            break

    warnings.warn(
        f"search_T did not reach rtol={rtol} within {max_iter} steps; "
        f"using T={best_tao:.4g} (relative gap {best_err:.3g})"
    )
    return best_tao

def power_iteration(A, tol=1e-3, max_iter=100):
    """Approximate the dominant eigenvector of ``A`` as a non-negative, unit-L1 vector.

    Starts from the uniform vector and repeatedly applies ``x <- A @ x``. Each
    step clips negatives to zero and rescales to unit L1 norm. Iteration stops
    once successive iterates differ by less than ``tol`` (L1) or after
    ``max_iter`` steps. In this file ``pp_estimate_eigen`` passes a
    non-negative matrix whose columns are averaged probability vectors
    (or zero).

    Args:
        A: square (n, n) tensor.
        tol: L1 convergence threshold between successive iterates.
        max_iter: maximum number of iterations.

    Returns:
        Length-n tensor with ``A``'s dtype and device, non-negative and summing
        to 1. If ``A @ x`` ever has no positive mass (e.g. ``A`` is all zeros),
        iteration stops and the last valid iterate is returned instead of the
        NaNs a 0/0 normalisation would produce.
    """
    n = A.shape[0]
    x = torch.full((n,), 1 / n, dtype=A.dtype, device=A.device)

    for _ in range(max_iter):
        # Clip before normalising so the iterate stays a proper distribution
        # even if A contains negative entries.
        x_next = torch.clamp(torch.mv(A, x), min=0)
        mass = torch.norm(x_next, p=1)
        if mass <= 0:
            # Nothing to normalise; dividing would yield NaN everywhere.
            break
        x_next = x_next / mass

        if torch.norm(x_next - x, p=1) < tol:
            # Converged: return the newest iterate rather than the previous one.
            return x_next

        x = x_next

    return x

def run(cfg, clip_weights, clip_model, llm_fea):
    """Evaluate zero-shot CLIP, LDA, their calibrated ensemble, and prior debiasing.

    Reads the module-level ``test_features`` / ``test_labels`` (set by
    ``main``) and prints each stage's accuracy as computed by ``cls_acc``.

    Args:
        cfg: parsed config (not used here; kept for interface compatibility).
        clip_weights: text-classifier weights used as
            ``test_features @ clip_weights``; ``shape[0]`` is the feature dim
            and ``shape[-1]`` the number of classes.
        clip_model: CLIP model (not used here; kept for interface compatibility).
        llm_fea: per-class description embeddings, presumably shaped
            (num_classes, num_descriptions, fea_dim) — inferred from
            ``mean(dim=1)`` / ``reshape(-1, fea_dim)``; TODO confirm. It is
            not modified.

    Returns:
        The accuracy after the final debiasing iteration (as returned by
        ``cls_acc``).
    """
    with torch.no_grad():
        cls_num = clip_weights.shape[-1]
        fea_dim = clip_weights.shape[0]
        device = test_features.device

        # Class prototypes: mean description embedding per class, L2-normalised.
        llm_weights_ave = llm_fea.mean(dim=1)
        llm_weights_ave = llm_weights_ave / llm_weights_ave.norm(dim=-1, keepdim=True)

        clip_logits = 100. * test_features @ clip_weights
        llm_ave_logits = 100. * test_features @ llm_weights_ave.T
        clip_acc = cls_acc(clip_logits, test_labels)

        # Individual description embeddings. Normalised out of place: an
        # in-place `/=` on this reshape view would overwrite the caller's llm_fea.
        llm_weights = llm_fea.reshape(-1, fea_dim)
        llm_weights = llm_weights / llm_weights.norm(dim=-1, keepdim=True)

        print('clip_zero:', clip_acc)

        # LDA with a shared covariance: the diagonal of the sample covariance
        # of (test features + description embeddings), shrunk towards identity.
        lamda = 0.9
        used_samples = torch.cat((test_features, llm_weights), dim=0)
        eye = torch.eye(fea_dim, device=device)
        cov_inv = torch.linalg.inv((1 - lamda) * used_samples.T.cov() * eye + lamda * eye)
        # Row-wise quadratic forms v^T S^-1 v, computed directly instead of
        # materialising an (N, N) / (K, K) Gram matrix just to read its diagonal.
        x_part = ((test_features @ cov_inv) * test_features).sum(dim=-1) / 2
        mu_part = ((llm_weights_ave @ cov_inv) * llm_weights_ave).sum(dim=-1) / 2
        crs_part = llm_weights_ave @ cov_inv @ test_features.T          # (K, N)
        lda_logits = (crs_part - mu_part[:, None] - x_part[None, :]).T  # (N, K)
        print('lda logits', cls_acc(lda_logits, test_labels))

        # Step 1: temperature-calibrate the LDA logits so their mean confidence
        # matches the prototype logits, then ensemble.
        llm_confidence = compute_confidence(llm_ave_logits, 1)
        opt_T = search_T(lda_logits, 1, llm_confidence)
        ensemble_logits = llm_ave_logits + lda_logits / opt_T
        print('ensemble_acc', cls_acc(ensemble_logits, test_labels))

        # Step 2: debiasing. Each round estimates class priors from samples
        # whose top-2 logit margin exceeds epsilon (using their argmax as
        # pseudo-labels) and subtracts log-priors from all logits.
        epsilon = 1 / cls_num
        pp_orig = torch.zeros(cls_num)  # CPU, matching pp_estimate_eigen's output
        acc = None
        for n in range(11):
            top2 = ensemble_logits.topk(2, dim=-1).values
            margin = torch.abs(top2[:, 0] - top2[:, 1])
            confident_logits = ensemble_logits[margin > epsilon]
            pp, _ = pp_estimate_eigen(confident_logits, confident_logits.argmax(dim=-1))
            ensemble_logits = ensemble_logits - torch.log(pp + 1e-12).to(ensemble_logits.device)
            acc = cls_acc(ensemble_logits, test_labels)
            print('debias acc', n, acc, torch.norm(pp - pp_orig, p=1))
            pp_orig = pp

    return acc

def _load_prompt_features(prompt_path, dataset, clip_model):
    """Read a JSON prompt file and encode it with ``gpt_llm_fea`` for ``dataset``'s classes.

    ``clip_model.float()`` is passed as before; ``nn.Module.float()`` casts
    the model in place, so later feature extraction uses the float32 model.
    The returned mapping is presumably keyed by class name (``main`` looks it
    up with ``'_'`` replaced by ``' '``) — TODO confirm against gpt_llm_fea.
    """
    with open(prompt_path) as f:
        prompts = json.load(f)
    return gpt_llm_fea(dataset.classnames, prompts, clip_model.float(), dataset.template)


def main():
    """Load config, model, dataset and prompt features, then call ``run``.

    Sets the module-level ``test_features`` / ``test_labels`` (read by ``run``)
    and ``val_features`` / ``val_labels``.

    Raises:
        FileNotFoundError: if the ``--config`` path does not exist.
    """
    global test_features, test_labels
    global val_features, val_labels

    # Load config file (explicit check: `assert` is stripped under `python -O`).
    args = get_arguments()
    if not os.path.exists(args.config):
        raise FileNotFoundError(f"Config file not found: {args.config}")

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # use trusted config files (consider yaml.safe_load).
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    print("\nRunning configs.")
    print(cfg, "\n")

    # CLIP, frozen.
    clip_model, preprocess = clip.load(cfg['backbone'])
    clip_model.eval()
    for p in clip_model.parameters():
        p.requires_grad = False

    print("Preparing dataset.")
    is_imagenet = cfg['dataset'] == "imagenet"
    if is_imagenet:
        dataset = ImageNet(cfg, cfg['root_path'], cfg['shots'], preprocess)
    else:
        dataset = build_dataset(cfg, cfg['dataset'], cfg['root_path'], cfg['shots'])

    # Per-class description features from both prompt sources, concatenated.
    gpt_fea_dct = _load_prompt_features(cfg['gpt3_prompt_file'], dataset, clip_model)
    cafo_fea_dct = _load_prompt_features(cfg['prompt_cafo'], dataset, clip_model)
    merged_dict = {key: torch.cat([gpt_fea_dct[key], cafo_fea_dct[key]]) for key in gpt_fea_dct}
    del gpt_fea_dct, cafo_fea_dct

    if is_imagenet:
        test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=64, num_workers=8, shuffle=False)
        test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader)
        # No separate validation features are extracted here; alias the test set.
        val_features, val_labels = test_features, test_labels
    else:
        print('test')
        test_loader = build_data_loader(data_source=dataset.test, batch_size=256, is_train=False, tfm=preprocess, shuffle=False)
        val_loader = build_data_loader(data_source=dataset.val, batch_size=256, is_train=False, tfm=preprocess, shuffle=False)
        test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader)
        val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader)

    clip_weights, _ = clip_classifier(dataset.classnames, dataset.template, clip_model.float())
    llm_fea = torch.stack([merged_dict[name.replace('_', ' ')] for name in dataset.classnames])

    run(cfg, clip_weights, clip_model, llm_fea)


if __name__ == '__main__':
    main()
