import os

import numpy as np

import torch

from src.models.eval import get_logits_and_labels, compute_accuracy
from src.models.modeling import ImageClassifier, wiseft_merge, gla_inference
from src.args import parse_arguments
import os


def find_sub_indices(labels, indices):
    """Return the positions in ``labels`` whose value occurs in ``indices``.

    Args:
        labels: Tensor of label values to search through.
        indices: Iterable of label values to look for.

    Returns:
        Tensor holding the first-dimension coordinates of every matching
        entry of ``labels`` (empty when nothing matches).
    """
    # Start with nothing selected, then switch on each wanted value in turn.
    selected = torch.zeros_like(labels, dtype=torch.bool)
    for wanted in indices:
        hits = labels == wanted
        selected |= hits

    # First component of the coordinate tuple = positions along dim 0.
    positions = torch.nonzero(selected, as_tuple=True)
    return positions[0]

def breakdown_accuracy(logits, labels, q):
    """Compute overall accuracy plus accuracy on head / medium / tail classes.

    Class ids are ranked by their value in ``q`` (largest first) and split
    into three groups: the first third is "head", the second third is
    "medium", and everything left over is "tail" (so the tail group absorbs
    the remainder when the class count is not divisible by three).

    Args:
        logits: Model outputs, indexed along dim 0 by sample position.
        labels: Tensor with the class id of each sample; ids are matched
            against positions in ``q``.
        q: Tensor with one score per class, expected to be 1-D (presumably
            the estimated class prior -- confirm against the caller).

    Returns:
        Tuple ``(top1_acc, head_acc, med_acc, tail_acc)``, each being whatever
        ``compute_accuracy`` returns for all samples and for the samples of
        the respective class group.
    """
    # Rank class ids from the largest q value to the smallest.
    class_order = torch.argsort(q, descending=True)

    # Size the groups from the actual number of classes instead of assuming
    # a 1000-class dataset; for 1000 classes this is the same 333/333/334 split.
    group_size = class_order.shape[0] // 3

    head_indices = class_order[:group_size]
    med_indices = class_order[group_size:2 * group_size]
    tail_indices = class_order[2 * group_size:]

    # Sample positions belonging to each class group.
    head_loc = find_sub_indices(labels, head_indices)
    med_loc = find_sub_indices(labels, med_indices)
    tail_loc = find_sub_indices(labels, tail_indices)

    top1_acc = compute_accuracy(logits, labels)
    head_acc = compute_accuracy(logits[head_loc], labels[head_loc])
    med_acc = compute_accuracy(logits[med_loc], labels[med_loc])
    tail_acc = compute_accuracy(logits[tail_loc], labels[tail_loc])

    return top1_acc, head_acc, med_acc, tail_acc


def gla(args):
    """Evaluate zero-shot, fine-tuned, WiSE-FT and GLA predictions and print them.

    ``args.load`` must unpack into exactly three checkpoint paths, in order:
    the zero-shot model, the fine-tuned model, and a file whose ``'q'`` entry
    holds the estimated q tensor. ``args`` is also forwarded unchanged to
    ``get_logits_and_labels``.

    For each of the four methods, one line is printed with the top-1, head,
    medium and tail accuracies (as percentages). Nothing is returned.
    """
    zs_path, ft_path, q_path = args.load

    # Load both classifiers and the estimated q.
    zeroshot = ImageClassifier.load(zs_path)
    zeroshot.process_images = True
    finetuned = ImageClassifier.load(ft_path)
    q = torch.load(q_path)['q'].detach().data

    # Logits of the two base models; the labels used for scoring are the
    # ones returned by the last (WiSE-FT) evaluation below.
    zs_logits, _ = get_logits_and_labels(zeroshot, args)
    ft_logits, _ = get_logits_and_labels(finetuned, args)

    # GLA combines the two sets of logits using q.
    gla_logits = gla_inference(zs_logits, ft_logits, q=q)

    # Build the WiSE-FT model, then drop the base models before evaluating it.
    wiseft_model = wiseft_merge(zeroshot, finetuned)
    del zeroshot, finetuned
    wiseft_logits, labels = get_logits_and_labels(wiseft_model, args)

    # One report line per method, in a fixed order.
    reports = (
        (" * zeroshot performance ", zs_logits),
        (" * finetune performance ", ft_logits),
        (" * wise-ft  performance ", wiseft_logits),
        (" * gla      performance ", gla_logits),
    )
    for banner, method_logits in reports:
        top1_acc, head_acc, med_acc, tail_acc = breakdown_accuracy(
            method_logits, labels, q)
        print(banner +
              f" top1 acc {100*top1_acc:.2f} %"
              f" head acc {100*head_acc:.2f} %"
              f" medium acc {100*med_acc:.2f} %"
              f" tail acc {100*tail_acc:.2f} %")


# Script entry point: parse the command-line arguments and run the evaluation.
# gla() unpacks args.load into three checkpoint paths
# (zero-shot model, fine-tuned model, estimated q).
if __name__ == '__main__':
    args = parse_arguments()
    gla(args)
