from imports import *
from utils import load_json
from torch.utils.data import DataLoader, Dataset

from tango import step
import time
import tango
from tango.common import FromParams
import termplotlib as tpl
from weights_composer import re_get_single_component, get_ov, remove_components
import sys
from exp_steps import (
    load_dataset,
    load_model,
    DataParams,
    ModelParams,
    get_token_idx
)
from laundry_list_exp import calc_inhib_score
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name
from my_plotly import *
import plotly.graph_objects as go
from fancy_einsum import einsum
from tango.common import det_hash
from dataclasses import dataclass


def attn_result_hook(
        hook_vals: Float[torch.Tensor, "batch pos head_index d_model"],
        hook: HookPoint,
        head_idx: int,
        pos_idxs: list, #list of ints
        new_result_vecs: Float[torch.Tensor, 'batch d_model']
    ) -> Float[torch.Tensor, "batch pos head_index d_model"]:
    """Overwrite one head's result vector at one position per batch row.

    For each batch row ``b`` the slot ``hook_vals[b, pos_idxs[b], head_idx]``
    is assigned from ``new_result_vecs``. The previous contents of that slot
    are replaced, not added to.

    Args:
        hook_vals: activation handed to the hook; modified in place.
        hook: the hook point that fired (unused here).
        head_idx: index of the head whose output is replaced.
        pos_idxs: one position index per batch row.
        new_result_vecs: replacement values; anything that broadcasts against
            the selected ``(batch, d_model)`` slice is accepted by the
            assignment.

    Returns:
        The same ``hook_vals`` tensor object after the in-place write.
    """
    # Pair every batch row with its own target position (advanced indexing).
    batch_rows = range(hook_vals.shape[0])
    hook_vals[batch_rows, pos_idxs, head_idx] = new_result_vecs
    return hook_vals


@step(cacheable=True, deterministic=True, version='002')
def add_scaled_vec_inhib_scores(
    model: ModelParams,
    dataset: DataParams,
    inhib_layer: int,
    inhib_head: int,
    comp_idx: int,
    mover_layer: int,
    mover_head: int,
    scales:list,
    distractor_idx
    ) -> np.array:
    """Collect ``calc_inhib_score`` for every example at every scale.

    When ``comp_idx`` is not None, the output of head
    ``(inhib_layer, inhib_head)`` at the last token of each prompt is
    replaced by ``scale * v[:, comp_idx]``, where ``v`` is the third factor
    returned by ``get_ov(model, inhib_layer, inhib_head).svd()``. When
    ``comp_idx`` is None the model is run without any intervention.

    Args:
        model: wrapper whose ``.model`` attribute is the hooked model.
        dataset: wrapper whose ``.dataset`` attribute yields batches; each
            batch is indexed by key and must contain ``'text'`` and
            ``'distractors'``.
        inhib_layer: layer of the head whose output is overwritten.
        inhib_head: index of that head within the layer.
        comp_idx: column of ``v`` to inject, or None for no intervention.
        mover_layer: forwarded to ``calc_inhib_score``.
        mover_head: forwarded to ``calc_inhib_score``.
        scales: multipliers applied to the injected vector; the whole
            dataset is traversed once per scale.
        distractor_idx: forwarded to ``calc_inhib_score`` as ``distract_idx``.

    Returns:
        1-D array of scores, ordered by scale, then batch, then example.
    """
    model = model.model
    model.set_use_attn_result(True)
    dataset = dataset.dataset

    ov = get_ov(model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    vec = v[:, comp_idx]

    def get_prompt(prompts, idx):
        # Pull example `idx` out of a collated batch. 'distractors' is stored
        # as one sequence per distractor slot, so it is indexed per slot.
        distractors = prompts['distractors']
        return {
            key: [d[idx] for d in distractors] if key == 'distractors' else prompts[key][idx]
            for key in prompts
        }

    inhib_scores = []
    for scale in scales:
        # Depends only on the scale, so build it once per scale rather than
        # once per batch. Unused when no intervention is requested.
        values_to_add = vec * scale if comp_idx is not None else None

        for batch in track(dataset):
            text = batch['text']
            model.reset_hooks()

            if comp_idx is not None:
                # Index of the final token of every prompt (BOS included).
                pos_idxs = [
                    len(model.to_str_tokens(prompt, prepend_bos=True)) - 1
                    for prompt in text
                ]
                hook_fn = partial(
                    attn_result_hook,
                    pos_idxs=pos_idxs,
                    head_idx=inhib_head,
                    new_result_vecs=values_to_add,
                )
                model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
                try:
                    _, cache = model.run_with_cache(text, prepend_bos=True)
                finally:
                    # Never leave the intervention attached, even on failure.
                    model.reset_hooks()
            else:
                # A clean pass is only needed when nothing is injected; the
                # intervened pass above would otherwise overwrite its cache.
                _, cache = model.run_with_cache(text, prepend_bos=True)

            for batch_idx in range(len(text)):
                cur_prompt = get_prompt(batch, batch_idx)
                inhib_score = calc_inhib_score(
                    model,
                    cur_prompt,
                    cache.apply_slice_to_batch_dim(batch_idx),
                    mover_layer,
                    mover_head,
                    distract_idx=distractor_idx,
                )
                inhib_scores.append(inhib_score.item())

    return np.array(inhib_scores)



@step(cacheable=True, deterministic=True, version='001')
def single_example_multiscale_attn_scores(
    model: ModelParams,
    prompt: dict,
    inhib_layers: list,
    inhib_heads: list,
    comp_idxs: list,
    mover_layer: int,
    mover_head: int,
    scales:list,
    ) -> list:
    """Measure mover-head attention on one prompt under several interventions.

    For each ``(layer, head, comp_idx, scale)`` taken in lockstep from the
    list arguments, the head's output at the last token is replaced by
    ``scale * v[:, comp_idx]`` (``v`` from ``get_ov(...).svd()``). All
    interventions are active during a single forward pass, after which the
    attention of head ``(mover_layer, mover_head)`` from the last token is
    read.

    Args:
        model: wrapper whose ``.model`` attribute is the hooked model.
        prompt: mapping with ``'text'``, ``'label'`` and ``'distractors'``.
        inhib_layers: layers of the heads to overwrite.
        inhib_heads: head indices, parallel to ``inhib_layers``.
        comp_idxs: column of ``v`` to use for each head.
        mover_layer: layer of the head whose attention is reported.
        mover_head: index of that head.
        scales: one multiplier per intervened head. The lists are combined
            with ``zip``, so extra entries in a longer list are ignored.

    Returns:
        ``(label_attn, distractors_attn)`` as two numpy arrays: attention on
        the ``' ' + label`` token (one element) and on each
        ``' ' + distractor`` token. NOTE(review): the ``-> list`` annotation
        does not match this tuple; it is left as-is so the step signature is
        unchanged.
    """
    model = model.model
    model.set_use_attn_result(True)

    vecs = []
    for inhib_layer, inhib_head, comp_idx in zip(inhib_layers, inhib_heads, comp_idxs):
        u, s, v = get_ov(model, inhib_layer, inhib_head).svd()
        vecs.append(v[:, comp_idx])

    text = prompt['text']
    tokenized_text = model.to_str_tokens(text, prepend_bos=True)
    last_tok = len(tokenized_text) - 1
    pos_idxs = [last_tok]

    model.reset_hooks()
    for vec, scale, inhib_layer, inhib_head in zip(vecs, scales, inhib_layers, inhib_heads):
        hook_fn = partial(
            attn_result_hook,
            pos_idxs=pos_idxs,
            head_idx=inhib_head,
            new_result_vecs=vec * scale,
        )
        model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
    try:
        # Only this intervened pass is needed: its cache is the one read below.
        _, cache = model.run_with_cache(text, prepend_bos=True)
    finally:
        # Detach the interventions even if the forward pass fails.
        model.reset_hooks()

    attn_pat = cache['pattern', mover_layer, 'attn'][0]

    label_idx = get_token_idx(tokenized_text, ' ' + prompt['label'])
    label_attn = [attn_pat[mover_head, last_tok, label_idx].item()]

    distractors_attn = []
    for dist in prompt['distractors']:
        dist_token_idx = get_token_idx(tokenized_text, ' ' + dist)
        distractors_attn.append(attn_pat[mover_head, last_tok, dist_token_idx].item())

    return np.array(label_attn), np.array(distractors_attn)


@dataclass
class ObjAttns:
    """Per-example record built by ``multiscale_exp``."""

    # attention from the last token to each object token, in `objs` order
    obj_attns: list
    # position within `objs` of the queried object
    query_idx: int
    # the example's object strings
    objs: list
    pred_idxs: list #ranks of the objs in the final predictions

    def to_dict(self):
        """Return the fields as a new plain dict.

        The mapping is a shallow copy, so adding, removing or rebinding keys
        in the result does not alter this record. The list values themselves
        are still shared with the instance.
        """
        return dict(self.__dict__)


def idxs_of_objs(model, ex, logits):
    """Return the rank of each object's token in the sorted logits.

    Each entry of ``ex['objects']`` is converted to a token id via
    ``model.to_single_token(' ' + obj)``. Rank 0 is the highest logit.

    Args:
        model: object providing ``to_single_token``.
        ex: mapping with an ``'objects'`` sequence of strings.
        logits: tensor of scores indexed by token id.

    Returns:
        List of ints, one rank per object, in the order of ``ex['objects']``.
    """
    token_ids = [model.to_single_token(' ' + name) for name in ex['objects']]
    # Token ids ordered from highest to lowest logit.
    ranking = torch.argsort(logits, descending=True)
    ranks = []
    for token_id in token_ids:
        # Position of this token id within the ranking is its rank.
        matches = torch.where(ranking == token_id)[0]
        ranks.append(matches.item())
    return ranks


@step(cacheable=True, deterministic=True, version='013')
def multiscale_exp(
    model: ModelParams,
    dataset: DataParams,#prompt: dict, #assumes that all dataset examples have the same number of objects
    inhib_layers: list,
    inhib_heads: list,
    comp_idxs: list,
    mover_layer: int,
    mover_head: int,
    scales:list,
    ) -> list:
    """Run the dataset with several head outputs overwritten at once.

    For each ``(layer, head, comp_idx, scale)`` taken in lockstep from the
    list arguments, the head's output at the last token of every prompt is
    replaced by ``scale * v[:, comp_idx]`` (``v`` from
    ``get_ov(...).svd()``). With all interventions active, each batch is run
    once and, per example, the attention of head
    ``(mover_layer, mover_head)`` from the last token to every object token
    is recorded together with the rank of every object in the last-token
    logits.

    Args:
        model: wrapper whose ``.model`` attribute is the hooked model.
        dataset: wrapper whose ``.dataset`` attribute yields batches with
            ``'text'``, ``'objects'`` and ``'query_idx'`` entries.
        inhib_layers: layers of the heads to overwrite.
        inhib_heads: head indices, parallel to ``inhib_layers``.
        comp_idxs: column of ``v`` to use for each head.
        mover_layer: layer of the head whose attention is recorded.
        mover_head: index of that head.
        scales: one multiplier per intervened head. The lists are combined
            with ``zip``, so extra entries in a longer list are ignored.

    Returns:
        List of ``ObjAttns``, one per example, in dataset order.
    """
    model = model.model
    model.set_use_attn_result(True)

    # (layer, head, replacement vector) per intervention. None of these
    # depend on the batch, so they are computed once up front.
    interventions = []
    for inhib_layer, inhib_head, comp_idx, scale in zip(inhib_layers, inhib_heads, comp_idxs, scales):
        u, s, v = get_ov(model, inhib_layer, inhib_head).svd()
        interventions.append((inhib_layer, inhib_head, v[:, comp_idx] * scale))

    # The mover head's attention pattern is the only activation read from the
    # cache, so restrict caching to it instead of storing every activation.
    pattern_name = get_act_name('pattern', mover_layer, 'attn')

    def get_prompt(prompts, idx):
        # Pull example `idx` out of a collated batch. 'objects' is stored as
        # one sequence per object slot, so it is indexed per slot.
        objs = prompts['objects']
        return {
            key: [d[idx] for d in objs] if key == 'objects' else prompts[key][idx]
            for key in prompts
        }

    all_outputs = []
    for batch in dataset.dataset:
        text = batch['text']
        model.reset_hooks()

        # Tokenize each prompt once; reused for the hook positions and for
        # locating object tokens afterwards.
        prompt_tokens = [model.to_str_tokens(prompt, prepend_bos=True) for prompt in text]
        pos_idxs = [len(tokens) - 1 for tokens in prompt_tokens]

        for inhib_layer, inhib_head, values_to_add in interventions:
            hook_fn = partial(
                attn_result_hook,
                pos_idxs=pos_idxs,
                head_idx=inhib_head,
                new_result_vecs=values_to_add,
            )
            model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
        try:
            logits, cache = model.run_with_cache(text, prepend_bos=True, names_filter=pattern_name)
        finally:
            # Detach the interventions even if the forward pass fails.
            model.reset_hooks()

        batch_patterns = cache['pattern', mover_layer, 'attn']

        for batch_idx, str_tokens in enumerate(prompt_tokens):
            cur_prompt = get_prompt(batch, batch_idx)
            query_idx = cur_prompt['query_idx'].item()
            last_tok = pos_idxs[batch_idx]
            attn_pat = batch_patterns[batch_idx]

            objects = cur_prompt['objects']
            obj_attns = []
            for obj in objects:
                obj_token_idx = get_token_idx(str_tokens, ' ' + obj)
                obj_attns.append(attn_pat[mover_head, last_tok, obj_token_idx].item())

            pred_idxs = idxs_of_objs(model, cur_prompt, logits[batch_idx, last_tok])
            all_outputs.append(ObjAttns(obj_attns, query_idx, objects, pred_idxs))

    return all_outputs

if __name__ == "__main__":
    # Standard-library helpers needed only by this script entry point.
    import itertools
    import os

    # Command line: MOVER_LAYER MOVER_HEAD NUM_OBJS
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} MOVER_LAYER MOVER_HEAD NUM_OBJS")
    mover_layer, mover_head = int(sys.argv[1]), int(sys.argv[2])
    num_objs_to_test = int(sys.argv[3])

    ws = tango.Workspace.from_url("/oscar/data/epavlick/jmerull1/weights/tango_workspace")

    model_name = 'gpt2-small'
    dataset_path = 'datasets/laundry_6item_reg1.json'

    dataset_params = DataParams(dataset_path, batch_size=2, extra_descriptor='3_ex_pilot')
    model_params = ModelParams(model_name)

    # NOTE(review): `inhib_layer` / `inhib_head` are not assigned anywhere in
    # this file, so unless a star import provides them this lookup raises
    # NameError and the fallback below is always taken. `comp_idx` is only
    # printed. Confirm whether this block is still wanted.
    preset_comp_idxs = {'8.6':2, '7.3':1, '7.9':6, '8.10':1}
    try:
        comp_idx = preset_comp_idxs[f'{inhib_layer}.{inhib_head}']
    except (NameError, KeyError):
        print("No preset comp idx, using 0")
        comp_idx = 0
    print("COMP IDX", comp_idx)

    def accuracy(outputs):
        # Percentage of examples whose queried object is ranked first.
        return 100*sum([a.pred_idxs[a.query_idx] == 0 for a in outputs])/len(outputs)

    def label_attn(outputs):
        # Mean attention on the queried object.
        return np.mean([a.obj_attns[a.query_idx] for a in outputs])

    def all_obj_attns(outputs):
        # Mean attention per object slot across examples.
        return np.mean([a.obj_attns for a in outputs], axis=0)

    def run_multiscale_exp(data_params, do_scale_fourth=False):
        """Sweep every combination of scales over the intervened heads.

        Three heads are swept by default; ``do_scale_fourth`` adds a fourth.
        Returns a list of single-key dicts mapping the comma-joined scales to
        the serialized per-example results.
        """
        scale_size = 100.
        start = -100.#-scale_size
        stop = scale_size
        interv = 10.

        inhib_layers = [7, 8, 8]
        inhib_heads = [9, 6, 10]
        comp_idxs = [6, 2, 1]
        if do_scale_fourth:
            inhib_layers = inhib_layers + [7]
            inhib_heads = inhib_heads + [3]
            comp_idxs = comp_idxs + [1]

        all_outputs = []
        rich.print('Running multiscale exp')
        rich.print(f"Going from {start} to {stop} in intervals of {interv}")
        # The progress bar follows the first scale; the remaining scales are
        # enumerated with the last one varying fastest.
        for scale_first in track(np.arange(start, stop+1, interv)):
            remaining = itertools.product(
                np.arange(start, stop+1, interv), repeat=len(inhib_layers) - 1
            )
            for other_scales in remaining:
                scales = [scale_first, *other_scales]
                all_outputs_scale = multiscale_exp(
                    model=model_params,
                    dataset=data_params,
                    inhib_layers=inhib_layers,
                    inhib_heads=inhib_heads,
                    comp_idxs=comp_idxs,
                    mover_layer=mover_layer,
                    mover_head=mover_head,
                    scales=scales
                ).result(ws)
                rich.print(scales,':\n', label_attn(all_outputs_scale), 'avg. label attn.\n', accuracy(all_outputs_scale), '% Acc.')
                print(all_outputs_scale[0])
                rich.print('mean obj attns', all_obj_attns(all_outputs_scale))
                rich.print('\n~~~~~~~~~~~~~~~~~~~~~~~~~~')
                result_key = ','.join(f'{scale}' for scale in scales)
                all_outputs.append({result_key: [a.to_dict() for a in all_outputs_scale]})
        return all_outputs

    for i in [num_objs_to_test]:
        torch.cuda.empty_cache()
        print(f"{i} items")
        dataset_path = f'datasets/laundry_list_250_{i}objs.json'
        output_path = f'exp_site/results/laundry_list/4d_multiscale_250_{i}objs_{mover_layer}.{mover_head}.json'
        dataset_params = DataParams(dataset_path, batch_size=13)
        # Create the results directory before the sweep so a missing folder
        # cannot discard the results at the very end.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        all_outputs = run_multiscale_exp(dataset_params, do_scale_fourth=False)
        with open(output_path, 'w') as f:
            json.dump(all_outputs, f)


    
    
        