import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPTJForCausalLM,
    GPTNeoXForCausalLM,
    LlamaForCausalLM,
    MistralForCausalLM,
    AutoModelForSequenceClassification
)

def get_hyperparameters():
    """Return the names of the run hyperparameters.

    A new list is built on every call, so callers may mutate the result freely.
    """
    return [
        "shuffle",
        "seed",
        "batch_size",
        "test_split",
        "attack_type",
        "iters",
        "step_size",
        "control_prompt",
        "early_stopping",
        "il_gen",
        "generate_interval",
        "num_tokens_printed",
        "verbose",
        "model_name",
        "dataset_name",
        "temperature",
    ]


def create_one_hot_and_embeddings(tokens, embed_weights, model):
    """Return ``(one_hot, embeddings)`` for ``tokens``, or ``(None, None)`` if ``tokens`` is None.

    ``one_hot`` comes from ``create_one_hot(model, tokens)``. ``embeddings`` is
    ``one_hot @ embed_weights`` taken via ``.data``, so it is detached from the
    autograd graph.
    """
    if tokens is None:
        return None, None
    one_hot_tokens = create_one_hot(model, tokens)
    projected = one_hot_tokens @ embed_weights
    return one_hot_tokens, projected.data
    

def create_one_hot(model, tokens):
    """One-hot encode ``tokens`` over the model's embedding vocabulary.

    The vocabulary size is the first dimension of ``get_embedding_matrix(model)``.
    The result is allocated on ``model.device`` with the embedding matrix's dtype,
    with shape ``(tokens.shape[0], tokens.shape[1], vocab)``.
    Returns None when ``tokens`` is None. The embedding matrix is still looked up
    first, so an unsupported model raises either way.
    """
    weights = get_embedding_matrix(model)
    if tokens is None:
        return None

    batch_size = tokens.shape[0]
    seq_len = tokens.shape[1]
    vocab_size = weights.shape[0]

    encoded = torch.zeros(
        batch_size, seq_len, vocab_size, device=model.device, dtype=weights.dtype
    )
    # Write a 1 at each token id along the vocabulary axis.
    return encoded.scatter_(2, tokens.unsqueeze(2), 1)


def init_attack_embeddings(model, tokenizer, control_prompt, device, repeat=0):
    """Build trainable input embeddings for the adversarial ``control_prompt``.

    Args:
        model: model whose input-embedding matrix (``get_embedding_matrix``) is used.
        tokenizer: tokenizer used to encode ``control_prompt``.
        control_prompt: initial text of the attack.
        device: device on which the token ids are created.
        repeat: if > 0, the single encoded prompt is tiled into ``repeat`` rows.

    Returns:
        Embedding tensor of shape (batch, n_tokens, hidden) with
        ``requires_grad=True``. Batch is 1 when ``repeat <= 0``.
    """
    embed_weights = get_embedding_matrix(model)
    input_ids = tokenizer(control_prompt)["input_ids"]

    # Drop a leading BOS token (as prepended by Llama-style tokenizers), but only if it
    # is actually there. Unconditionally slicing [1:] removed the first real prompt
    # token for tokenizers that add no BOS.
    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    if input_ids and bos_token_id is not None and input_ids[0] == bos_token_id:
        input_ids = input_ids[1:]

    # Explicit long dtype so an empty id list still yields a valid index tensor.
    attack_tokens = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
    if repeat > 0:
        attack_tokens = attack_tokens.repeat(repeat, 1)

    _, embeddings_attack = create_one_hot_and_embeddings(
        attack_tokens, embed_weights, model
    )
    # The embeddings are a detached tensor (obtained via .data), so they can be
    # turned into a trainable leaf.
    embeddings_attack.requires_grad = True
    return embeddings_attack


def get_perplexity_from_tokens(model, tokens):
    """Return the per-sample perplexity of a batch of ``tokens`` under ``model``.

    Token id 0 is treated as padding. It is excluded from attention and passed to
    ``get_perplexity_from_logits``, which drops predictions made from padded
    positions. ``labels`` is passed to the model, but only ``outputs.logits`` is
    used.
    """
    non_padding = tokens != 0
    outputs = model(input_ids=tokens, labels=tokens, attention_mask=non_padding)
    return get_perplexity_from_logits(model, outputs.logits, tokens)


def get_perplexity_from_logits(model, logits, target_tokens):
    """Compute the per-sample perplexity of ``target_tokens`` under next-token ``logits``.

    Position ``t`` of ``logits`` is scored against token ``t + 1`` of
    ``target_tokens``. Token id 0 is treated as padding. Predictions made *from* a
    padded position are excluded from each sample's mean negative log-likelihood,
    in both the value and the gradient.

    Args:
        model: unused; kept so existing callers keep working.
        logits: (batch, seq, vocab) tensor of unnormalised scores.
        target_tokens: (batch, seq) integer token ids aligned with ``logits``.

    Returns:
        (batch,) tensor of ``exp(mean NLL)`` per sample. A sample with no unmasked
        prediction yields NaN (0 / 0).
    """
    # The mask is aligned with the *input* positions, before the targets are shifted.
    mask = (target_tokens != 0)[:, :-1]
    targets = target_tokens[:, 1:]
    logits = logits[:, :-1, :]

    vocab_size = logits.shape[-1]
    # Class-index targets give the same value as a one-hot soft target, without
    # allocating a (batch, seq, vocab) tensor and without 0 * -inf = NaN when a
    # logit is -inf.
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size), targets.reshape(-1), reduction="none"
    ).reshape(mask.shape)

    # masked_fill (rather than writing through .data) also blocks gradients from
    # masked positions.
    loss = loss.masked_fill(~mask, 0.0)
    loss_sample = loss.sum(dim=1) / mask.sum(dim=1)
    return torch.exp(loss_sample)


def print_result_dict(result_dict):
    """Print a compact one-line summary of ``result_dict``.

    Every entry except the bulky keys in ``exclude_from_print`` is rendered as
    ``key: value`` and the entries are joined with `` | ``. A list value is
    summarised by its first element; an empty list is printed as-is. If
    ``generated_text`` is present and not ``None`` (for a list, its first element),
    it is appended inside a banner.
    """
    exclude_from_print = ["affirmative_response", "embeddings_attack", "generated_text", "input_tokens", "target_tokens", "intermediate_layer_generation"]
    str_list = []
    for key, value in result_dict.items():
        if key in exclude_from_print:
            continue
        # Guard against empty lists, which previously raised IndexError.
        if isinstance(value, list) and value:
            value = value[0]
        str_list.append(f"{key}: {value}")
    final_str = " | ".join(str_list)

    # Add generated text if it exists and is not trivial. `is not None` avoids the
    # ambiguous element-wise comparison `!= None` performs on array-like values.
    generated = result_dict.get("generated_text")
    if isinstance(generated, list):
        generated = generated[0] if generated else None
    if generated is not None:
        generate_text = f" | generated_text: {generated}"
        final_str += f"\n==========generated_text============\n{generate_text}\n============================================\n"
    print(final_str)


def load_model_and_tokenizer(model_path, tokenizer_path=None, device="cuda:0", **kwargs):
    """Load a model in eval mode on ``device`` together with its tokenizer.

    The model class and load options are chosen from substrings of ``model_path``:

    - ``"unbiased-toxic-roberta"``: ``AutoModelForSequenceClassification``, float16.
    - ``"paper_models"``: ``AutoModelForCausalLM``, bfloat16, flash attention 2.
    - otherwise: ``AutoModelForCausalLM``, float16.

    Args:
        model_path: local path or hub id of the model.
        tokenizer_path: path or hub id for the tokenizer; defaults to ``model_path``.
        device: device the model is moved to.
        **kwargs: extra ``from_pretrained`` arguments, forwarded for every model type.
            Previously the ``paper_models`` branch silently dropped them.

    Returns:
        ``(model, tokenizer)``.
    """
    if "unbiased-toxic-roberta" in model_path:
        model_cls = AutoModelForSequenceClassification
        load_options = {"torch_dtype": torch.float16}
    elif "paper_models" in model_path:
        model_cls = AutoModelForCausalLM
        # NOTE(review): newer transformers releases prefer
        # attn_implementation="flash_attention_2" over this flag.
        load_options = {"use_flash_attention_2": True, "torch_dtype": torch.bfloat16}
    else:
        model_cls = AutoModelForCausalLM
        load_options = {"torch_dtype": torch.float16}

    model = (
        model_cls.from_pretrained(
            model_path, trust_remote_code=True, **load_options, **kwargs
        )
        .to(device)
        .eval()
    )

    tokenizer_path = model_path if tokenizer_path is None else tokenizer_path
    tokenizer = load_tokenizer(tokenizer_path)

    return model, tokenizer


def load_tokenizer(tokenizer_path):
    """Load a slow (``use_fast=False``) tokenizer and apply model-specific fixes.

    Paths containing ``"paper_models"`` are redirected to the
    ``"NousResearch/Llama-2-7b-chat-hf"`` tokenizer. After loading, exactly one of
    the substring rules below is expected to fire. Rules are matched
    case-sensitively against the (possibly redirected) path, and the final
    "pad_token missing" fallback also counts as a rule.

    Args:
        tokenizer_path: local path or hub id passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.

    Raises:
        ValueError: if no rule fired, or if more than one rule fired.
    """
    if "paper_models" in tokenizer_path:
        tokenizer_path = "NousResearch/Llama-2-7b-chat-hf"
    
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path, trust_remote_code=True, use_fast=False
    )
    # Number of configuration rules applied; must end up exactly 1.
    set_tokeninzer = 0
    
    if "oasst-sft-6-llama-30b" in tokenizer_path:
        print("Using oasst-sft-6-llama-30b tokenizer, setting bos_token_id to 1 and unk_token_id to 0.")
        tokenizer.bos_token_id = 1
        tokenizer.unk_token_id = 0
        set_tokeninzer += 1
    if "guanaco" in tokenizer_path:
        print("Using guanaco tokenizer, setting eos_token_id to 2 and unk_token_id to 0.")
        tokenizer.eos_token_id = 2
        tokenizer.unk_token_id = 0
        set_tokeninzer += 1
    # Matches the redirected paper_models path (and any other NousResearch path).
    # The "llama-2" rule below is case-sensitive, so "Llama-2" here does not also fire it.
    if "NousResearch" in tokenizer_path:
        print("Using llama-2 tokenizer for paper_models, setting padding side to left and pad_token to unk_token.")
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.padding_side = "left"
        set_tokeninzer += 1
    if "llama-2" in tokenizer_path:
        print("Using llama-2 tokenizer, setting padding side to left and pad_token to unk_token.")
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.padding_side = "left"
        set_tokeninzer += 1
    if "falcon" in tokenizer_path:
        print("Using falcon tokenizer, setting padding side to left.")
        tokenizer.padding_side = "left"
        set_tokeninzer += 1
    if "unbiased-toxic-roberta" in tokenizer_path:
        print("Using unbiased-toxic-roberta tokenizer, setting padding side to left and pad_token to 0.")
        tokenizer.padding_side = "left"
        # NOTE(review): pad_token is assigned the int 0, not a token string. Whether
        # this is accepted, and whether the `not tokenizer.pad_token` check below then
        # treats it as set, may depend on the transformers version. Confirm.
        tokenizer.pad_token = 0
        set_tokeninzer += 1
    # Fallback when no pad token is configured.
    # NOTE(review): this also runs after a named rule that leaves pad_token unset
    # (e.g. the falcon rule only sets padding_side). That pushes the counter to 2 and
    # raises below. Confirm this is intended.
    if not tokenizer.pad_token:
        print("Setting pad token to eos token, no specfic logic defined for this model. This might be wrong. Check padding_side and unk_token_id.")
        tokenizer.pad_token = tokenizer.eos_token
        set_tokeninzer += 1
    
    if set_tokeninzer == 0:
        raise ValueError("No tokenizer logic was set. Check logic.")
    if set_tokeninzer > 1:
        raise ValueError("Tokenizer was set more than once. Check logic.")
    
    return tokenizer


def get_embedding_matrix(model):
    """Return the input token-embedding weight matrix of ``model``.

    Explicit branches (from llm-attacks) cover GPT-J, GPT-2, Llama, GPT-NeoX and
    Mistral. Any other model is resolved through ``model.get_input_embeddings()``
    when that method exists and returns a module with a ``weight``.

    Raises:
        ValueError: if no embedding matrix can be located for ``type(model)``.
    """
    # from llm-attacks
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        return model.transformer.wte.weight
    if isinstance(model, LlamaForCausalLM):
        return model.model.embed_tokens.weight
    if isinstance(model, GPTNeoXForCausalLM):
        return model.base_model.embed_in.weight
    if isinstance(model, MistralForCausalLM):
        return model.model.embed_tokens.weight

    # Generic fallback for other architectures (e.g. Falcon) instead of failing
    # outright. Hugging Face base classes raise NotImplementedError when no
    # embedding layer is wired up.
    get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if callable(get_input_embeddings):
        try:
            embedding_layer = get_input_embeddings()
        except NotImplementedError:
            embedding_layer = None
        weight = getattr(embedding_layer, "weight", None)
        if weight is not None:
            return weight
    raise ValueError(f"Unknown model type: {type(model)}")


def get_attention_mask(model, input_tokens, target_tokens, embeddings_attack):
    """Build the boolean attention mask for ``[input | attack | target]``.

    Attack positions are always attended. Input and target positions are attended
    wherever the token id is non-zero (0 is treated as padding). ``input_tokens`` may
    be None, since not every dataset provides it; the mask then covers only
    ``[attack | target]``.
    """
    batch_size = target_tokens.shape[0]
    attack_len = embeddings_attack.shape[1]

    segments = []
    if input_tokens is not None:
        segments.append(input_tokens != 0)
    segments.append(torch.ones((batch_size, attack_len), dtype=bool, device=model.device))
    segments.append(target_tokens != 0)
    return torch.cat(segments, dim=1)


def num_affirmative_response(logits_pred, target_tokens, return_sample_wise=False):
    """Check which samples reproduce ``target_tokens`` exactly under greedy decoding.

    A sample succeeds if, for at least one entry of ``logits_pred``, the argmax over
    dim 2 matches ``target_tokens`` at every position. Positions where the target id
    is 0 (padding) always count as matches.

    Args:
        logits_pred: iterable of logit tensors, presumably one per intermediate layer.
            Each is assumed to be (batch, seq, vocab) aligned with ``target_tokens``
            (TODO confirm).
        target_tokens: (batch, seq) integer tensor of target ids.
        return_sample_wise: if True, return the per-sample boolean result.

    Returns:
        A (batch,) bool tensor if ``return_sample_wise``, otherwise the 0-dim tensor
        counting successful samples.
    """
    success_sample_wise = torch.zeros(
        target_tokens.shape[0], dtype=torch.bool, device=target_tokens.device
    )
    # Padding targets are treated as correctly predicted.
    is_padding = target_tokens == 0
    for layer_logits in logits_pred:
        tokens_pred = layer_logits.argmax(2)
        token_ok = (tokens_pred == target_tokens) | is_padding
        # Logical OR: a success at any layer counts for the sample.
        success_sample_wise |= token_ok.all(1)

    if return_sample_wise:
        return success_sample_wise
    return success_sample_wise.sum()
    
    
# TODO remove this for submission!
def base_path():
    """Return the root directory that model paths are resolved against.

    No root is configured here (``path`` is a ``None`` placeholder), so this
    currently always raises.

    Raises:
        ValueError: while ``path`` is still ``None``.
    """
    path = None  # placeholder: fill in the local checkpoint root
    if path is not None:
        return path
    raise ValueError("Path is not set. Please set path.")


def get_model_path(model_name):
    """Map ``model_name`` to a checkpoint path under ``base_path()``.

    - Names containing "llama" (case-insensitive) map to ``"llama-2/llama/"``.
    - "tofu_grad_diff_forget01" and "tofu_grad_ascent_forget01" (case-insensitive
      exact match) map to their checkpoint under ``"paper_models/"``.
    - Anything else maps to ``base_path()`` itself.

    Raises:
        ValueError: propagated from ``base_path()`` while no root is configured.
    """
    name = model_name.lower()
    tofu_checkpoints = {
        "tofu_grad_diff_forget01": "paper_models/final_ft_noLORA_5_epochs_inst_lr2e-05_llama2-7b_full/checkpoint-625/1GPU_grad_diff_2e-05_forget01/checkpoint-4/",
        "tofu_grad_ascent_forget01": "paper_models/final_ft_noLORA_5_epochs_inst_lr2e-05_llama2-7b_full/checkpoint-625/1GPU_grad_ascent_2e-05_forget01/checkpoint-4/",
    }
    if "llama" in name:
        return base_path() + "llama-2/llama/"
    if name in tofu_checkpoints:
        return base_path() + tofu_checkpoints[name]
    return base_path()