import os
import pathlib
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
import numpy as np
import pandas as pd
import yaml

import einops
from fancy_einsum import einsum


import plotly.io as pio
import plotly.express as px

# import pysvelte
from IPython.display import HTML

import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display


import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)

from functools import partial

from jaxtyping import Float

# Use the integer from the LOCAL_RANK environment variable (default 0) as the
# device when CUDA is available; otherwise fall back to the CPU.
device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else "cpu"


import pandas as pd
import plotly.express as px
# Temporarily disabled
import circuitsvis as cv


# ==============================================================================
# Visualization helpers used for analysis and active experimentation.

def plot_attention_heads(tensor, title="", top_n=0, range_x=(0, 2.5), threshold=0.02):
    """Plot per-head values as a horizontal bar chart and print their thresholded total.

    Args:
        tensor: PyTorch tensor whose first two axes are indexed as
            ``[layer, head]``; each entry is the value attributed to that head
            (labelled "Logit Diff" in the chart).
        title: Title of the figure.
        top_n: If greater than 0, only the ``top_n`` largest values are plotted;
            otherwise every head is plotted in layer-major order.
        range_x: ``(min, max)`` range of the x-axis, or ``None`` to let plotly
            pick the range.
        threshold: Only plotted values strictly greater than this are included
            in the printed total.

    Returns:
        None. The figure is shown and the total is printed.
    """
    # Move to CPU and drop autograd state before converting to numpy.
    values = tensor.cpu().detach().numpy()

    # One label per (layer, head), in the same order as ``values.flatten()``.
    labels = [
        f"Layer {layer}, Head {head}"
        for layer in range(values.shape[0])
        for head in range(values.shape[1])
    ]
    flattened_values = values.flatten()

    # The size guard avoids unpacking an empty ``zip`` when there is nothing to rank.
    if top_n > 0 and flattened_values.size > 0:
        top_indices = flattened_values.argsort()[-top_n:]
        # Sort ascending: the last row of the dataframe is the largest value.
        ranked = sorted(
            zip(flattened_values[top_indices], [labels[i] for i in top_indices])
        )
        flattened_values, labels = zip(*ranked)

    df = pd.DataFrame({"Logit Diff": flattened_values, "Attention Head": labels})

    # Sum only the values above the threshold, as the message states.
    flat_value_array = np.array(flattened_values)
    above_threshold = flat_value_array[flat_value_array > threshold]
    print(
        f"Total logit diff contribution above threshold: {above_threshold.sum():.2f}"
    )

    fig = px.bar(
        df,
        x="Logit Diff",
        y="Attention Head",
        orientation="h",
        range_x=list(range_x) if range_x is not None else None,
        title=title,
    )
    fig.show()


def l_imshow(tensor, renderer=None, **kwargs):
    """Show ``tensor`` as a heatmap using a red-blue colour scale centred on zero.

    Args:
        tensor: Data to plot; converted with ``utils.to_numpy``.
        renderer: Plotly renderer forwarded to ``Figure.show``.
        **kwargs: Extra keyword arguments forwarded to ``px.imshow``. They may
            override the default ``color_continuous_midpoint`` (0.0) and
            ``color_continuous_scale`` ("RdBu").
    """
    # Merge defaults first so caller-supplied values win instead of raising a
    # duplicate-keyword TypeError.
    imshow_kwargs = {
        "color_continuous_midpoint": 0.0,
        "color_continuous_scale": "RdBu",
        **kwargs,
    }
    px.imshow(utils.to_numpy(tensor), **imshow_kwargs).show(renderer)


def l_line(tensor, renderer=None, **kwargs):
    """Show ``tensor`` as a line plot.

    Args:
        tensor: Values for the y-axis; converted with ``utils.to_numpy``.
        renderer: Plotly renderer forwarded to ``Figure.show``.
        **kwargs: Extra keyword arguments forwarded to ``px.line``.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)


def l_scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of ``y`` against ``x``.

    Args:
        x: Values for the x-axis; converted with ``utils.to_numpy``.
        y: Values for the y-axis; converted with ``utils.to_numpy``.
        xaxis: Label substituted for "x".
        yaxis: Label substituted for "y".
        caxis: Label substituted for "color".
        renderer: Plotly renderer forwarded to ``Figure.show``.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)


def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    """Show two series as lines on the same plot.

    Args:
        tensor1: First series; converted with ``utils.to_numpy``.
        tensor2: Second series; converted with ``utils.to_numpy``.
        renderer: Plotly renderer forwarded to ``Figure.show``.
        **kwargs: Extra keyword arguments forwarded to ``px.line``.
    """
    series = [utils.to_numpy(tensor1), utils.to_numpy(tensor2)]
    figure = px.line(y=series, **kwargs)
    figure.show(renderer)


def get_attn_head_patterns(
    model: HookedTransformer, prompt: Union[str, List[int]], attn_heads: List[Tuple[int]]
):
    """Collect the cached attention patterns of selected heads for one prompt.

    Args:
        model: Model used to tokenize and run the prompt.
        prompt: Prompt text (tokenized with ``model.to_tokens``) or tokens
            passed to the model as given.
        attn_heads: Iterable of ``(layer, head)`` pairs.

    Returns:
        A tuple ``(tokens, attention_pattern, head_name_list)``:
        ``tokens`` is ``model.to_str_tokens`` of the prompt,
        ``attention_pattern`` stacks one cached pattern per requested head along
        dim 0, and ``head_name_list`` holds names formatted as ``"L{layer}H{head}"``.
    """
    if isinstance(prompt, str):
        prompt = model.to_tokens(prompt)
    # Only the cache is needed; the model output is discarded.
    _, cache = model.run_with_cache(prompt, remove_batch_dim=True)

    # Single pass over attn_heads so one-shot iterables keep working.
    selected = [
        (cache["pattern", layer, "attn"][head], f"L{layer}H{head}")
        for layer, head in attn_heads
    ]
    attention_pattern = torch.stack([pattern for pattern, _ in selected], dim=0)
    head_name_list = [name for _, name in selected]
    tokens = model.to_str_tokens(prompt)

    return tokens, attention_pattern, head_name_list


def get_attn_pattern(
    model: HookedTransformer, 
    prompt: str, 
    attn_heads: List[Tuple[int]], 
    cache: ActivationCache = None,
    weighted: bool = True,
) -> Tuple[List[str], Float[Tensor, "head dest src"], List[str]]:
    """Return batch-averaged attention patterns for selected heads.

    Args:
        model: Model used to run the prompt (when no cache is given) and to
            produce the string tokens.
        prompt: Prompt passed to ``model.run_with_cache`` and
            ``model.to_str_tokens``.
        attn_heads: Iterable of ``(layer, head)`` pairs.
        cache: Optional precomputed activation cache; when ``None`` the model is
            run on ``prompt`` to build one.
        weighted: When true, each source position's attention is multiplied by
            the batch-mean norm of that head's value vector at that position and
            the rows are renormalized to sum to 1.

    Returns:
        A tuple ``(tokens, attention_pattern, head_name_list)`` where
        ``attention_pattern`` stacks one pattern per head along dim 0 and the
        names are formatted as ``"L{layer}H{head}"``.

    Raises:
        AssertionError: If a pattern's rows do not sum to 1, before or after
            the value-norm reweighting.
    """
    if cache is None:
        _, cache = model.run_with_cache(prompt, return_type=None)
    tokens = model.to_str_tokens(prompt)

    def _rows_sum_to_one(pattern) -> bool:
        # Every destination row should be a probability distribution over sources.
        row_sums = pattern.sum(dim=-1)
        return torch.allclose(row_sums, torch.ones_like(row_sums))

    patterns = []
    head_name_list = []
    for layer, head in attn_heads:
        # Average the head's pattern over the leading (batch) axis.
        pattern = cache["pattern", layer, "attn"][:, head, :, :].mean(
            dim=0, keepdim=False
        )
        assert _rows_sum_to_one(pattern)
        if weighted:
            # Batch-mean norm of the head's value vectors, one per source position.
            value_norms = (
                cache["v", layer][:, :, head, :].norm(dim=-1).mean(dim=0, keepdim=False)
            )
            pattern = einops.einsum(
                pattern, value_norms, "dest src, src -> dest src"
            )
            pattern /= pattern.sum(dim=-1, keepdim=True)
            assert _rows_sum_to_one(pattern)
        patterns.append(pattern)
        head_name_list.append(f"L{layer}H{head}")

    attention_pattern = torch.stack(patterns, dim=0)
    return tokens, attention_pattern, head_name_list



# def plot_attention(
#     model: HookedTransformer, prompt: str, attn_heads: List[Tuple[int]], 
#     cache: ActivationCache = None, weighted: bool = True, 
#     max_value: float = 1.0, min_value: float = 0.0
# ):
#     tokens, attention_pattern, head_name_list = get_attn_pattern(
#         model, prompt, attn_heads, cache, weighted
#     )
#     return cv.attention.attention_heads(
#         tokens=tokens, 
#         attention=attention_pattern, 
#         attention_head_names=head_name_list,
#         max_value=max_value,
#         min_value=min_value,
#     )


def scatter_attention_and_contribution(
    model,
    head,
    prompts,
    io_positions,
    s_positions,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
    end_position=14,
):
    """Scatter a head's attention on each name against its output along that name's direction.

    Two points are produced per prompt, labelled "IO" and "S". The x value is
    the attention weight from query position ``end_position`` to the name's
    position; the y value is the dot product between the head's layer-normed
    output and the matching entry of ``answer_residual_directions``.

    Args:
        model: Model providing ``run_with_cache`` and ``cfg.n_heads``.
        head: ``(layer, head_idx)`` pair identifying the attention head.
        prompts: Batch of prompts; also indexed per prompt.
            NOTE(review): the head output is taken with ``pos_slice=-1`` on the
            batched run, which presumably assumes all prompts tokenize to the
            same length - confirm.
        io_positions: Per-prompt key position of the "IO" name.
        s_positions: Per-prompt key position of the "S" name.
        answer_residual_directions: Per-prompt pairs of direction vectors; entry
            ``[i][0]`` is used for "IO" and ``[i][1]`` for "S".
        return_vals: If true, return the dataframe of points instead of plotting.
        return_fig: If true (and ``return_vals`` is false), return the figure
            instead of showing it.
        end_position: Query position whose attention row is read. Defaults to
            14, the value that was previously hard-coded.

    Returns:
        The dataframe when ``return_vals`` is true, the figure when
        ``return_fig`` is true, otherwise ``None`` after showing the figure.
    """
    x_col = "Attn Prob on Name"
    y_col = "Dot w Name Embed"
    layer, head_idx = head

    # Get the head's output to the residual stream, scaled by the final layer norm.
    _, cache = model.run_with_cache(prompts)
    per_head_residual, _ = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    scaled_residual_stack = cache.apply_ln_to_stack(
        per_head_residual, layer=-1, pos_slice=-1
    )
    # Heads are stacked layer-major, so this is the flat index of (layer, head_idx).
    head_resid = scaled_residual_stack[layer * model.cfg.n_heads + head_idx]

    rows = []
    for i, directions in enumerate(answer_residual_directions):
        # Attention pattern of this head on prompt i alone.
        _, attn, _ = get_attn_head_patterns(model, prompts[i], [head])

        for name_type, direction, key_position in (
            ("IO", directions[0], io_positions[i]),
            ("S", directions[1], s_positions[i]),
        ):
            # Head output projected onto the name's direction.
            dot = einsum("d_model, d_model -> ", head_resid[i], direction)
            # Attention probability from the query position to the name.
            prob = attn[0, end_position, key_position]
            rows.append([prob, dot, name_type, prompts[i]])

    viz_df = pd.DataFrame(rows, columns=[x_col, y_col, "Name Type", "text"])
    fig = px.scatter(
        viz_df,
        x=x_col,
        y=y_col,
        color="Name Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Name Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    fig.show()

def scatter_attention_and_contribution_sentiment(
    model,
    head,
    prompts,
    positions,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
):
    """Scatter a head's attention on a word against its output along a direction.

    One point is produced per prompt. The x value is the attention weight from
    the last query position to ``positions[i]``; the y value is the dot product
    between the head's layer-normed output and
    ``answer_residual_directions[i][0]``. Even-indexed prompts are labelled
    "Positive Sentiment" and odd-indexed prompts "Negative Sentiment".

    Args:
        model: Model providing ``run_with_cache`` and ``cfg.n_heads``.
        head: ``(layer, head_idx)`` pair identifying the attention head.
        prompts: Batch of prompts; also indexed per prompt.
        positions: Per-prompt key position of the word of interest.
        answer_residual_directions: Per-prompt directions; only entry ``[i][0]``
            is used.
        return_vals: If true, return the dataframe of points instead of plotting.
        return_fig: If true (and ``return_vals`` is false), return the figure
            instead of showing it.

    Returns:
        The dataframe when ``return_vals`` is true, the figure when
        ``return_fig`` is true, otherwise ``None`` after showing the figure.
    """
    x_col = "Attn Prob on Word"
    y_col = "Dot w Sentiment Embed"
    layer, head_idx = head

    # Head outputs at the last position, scaled by the final layer norm.
    _, cache = model.run_with_cache(prompts)
    stacked_heads, _ = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    normed_heads = cache.apply_ln_to_stack(stacked_heads, layer=-1, pos_slice=-1)
    # Heads are stacked layer-major, so this is the flat index of (layer, head_idx).
    head_resid = normed_heads[layer * model.cfg.n_heads + head_idx]

    rows = []
    n_prompts = len(answer_residual_directions)
    for i in range(n_prompts):
        # Attention pattern of this head on prompt i alone.
        _, attn, _ = get_attn_head_patterns(model, prompts[i], [head])

        # Head output projected onto the prompt's direction.
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][0]
        )
        # Attention probability from the last position to the word.
        prob = attn[0, -1, positions[i]]

        sentiment = "Negative" if i % 2 else "Positive"
        rows.append([prob, dot, f"{sentiment} Sentiment", prompts[i]])

    viz_df = pd.DataFrame(rows, columns=[x_col, y_col, "Word Type", "text"])
    fig = px.scatter(
        viz_df,
        x=x_col,
        y=y_col,
        color="Word Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Sentiment Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    fig.show()


def scatter_attention_and_contribution_simple(
    model,
    head,
    prompts,
    positions,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
):
    """Scatter a head's attention on a word against its output along a direction.

    One point is produced per prompt, all labelled "Sentiment Prompts". The x
    value is the attention weight from the last query position to
    ``positions[i]``; the y value is the dot product between the head's
    layer-normed output and ``answer_residual_directions[i][0]``.

    Args:
        model: Model providing ``run_with_cache`` and ``cfg.n_heads``.
        head: ``(layer, head_idx)`` pair identifying the attention head.
        prompts: Batch of prompts; also indexed per prompt.
        positions: Per-prompt key position of the word of interest.
        answer_residual_directions: Per-prompt directions; only entry ``[i][0]``
            is used.
        return_vals: If true, return the dataframe of points instead of plotting.
        return_fig: If true (and ``return_vals`` is false), return the figure
            instead of showing it.

    Returns:
        The dataframe when ``return_vals`` is true, the figure when
        ``return_fig`` is true, otherwise ``None`` after showing the figure.
    """
    x_col = "Attn Prob on Word"
    y_col = "Dot w Sentiment Embed"
    layer, head_idx = head

    # Head outputs at the last position, scaled by the final layer norm.
    _, cache = model.run_with_cache(prompts)
    stacked_heads, _ = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    normed_heads = cache.apply_ln_to_stack(stacked_heads, layer=-1, pos_slice=-1)
    # Heads are stacked layer-major, so this is the flat index of (layer, head_idx).
    head_resid = normed_heads[layer * model.cfg.n_heads + head_idx]

    rows = []
    n_prompts = len(answer_residual_directions)
    for i in range(n_prompts):
        # Attention pattern of this head on prompt i alone.
        _, attn, _ = get_attn_head_patterns(model, prompts[i], [head])

        # Head output projected onto the prompt's direction.
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][0]
        )
        # Attention probability from the last position to the word.
        prob = attn[0, -1, positions[i]]

        rows.append([prob, dot, "Sentiment Prompts", prompts[i]])

    viz_df = pd.DataFrame(rows, columns=[x_col, y_col, "Word Type", "text"])
    fig = px.scatter(
        viz_df,
        x=x_col,
        y=y_col,
        color="Word Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Sentiment Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    fig.show()



def scatter_attention_and_contribution_logic(
    model,
    head,
    prompts,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
    end_position=16,
    descriptor_position=4,
    s_position=2,
):
    """Scatter a head's attention on two fixed positions against its output along their directions.

    Two points are produced per prompt, labelled "Descriptor" and "S". The x
    value is the attention weight from query position ``end_position`` to the
    corresponding key position; the y value is the dot product between the
    head's layer-normed output and the matching entry of
    ``answer_residual_directions``.

    Args:
        model: Model providing ``run_with_cache`` and ``cfg.n_heads``.
        head: ``(layer, head_idx)`` pair identifying the attention head.
        prompts: Batch of prompts; also indexed per prompt.
            NOTE(review): the same positions are used for every prompt, which
            presumably assumes all prompts share one token template - confirm.
        answer_residual_directions: Per-prompt pairs of direction vectors; entry
            ``[i][0]`` is used for "Descriptor" and ``[i][1]`` for "S".
        return_vals: If true, return the dataframe of points instead of plotting.
        return_fig: If true (and ``return_vals`` is false), return the figure
            instead of showing it.
        end_position: Query position whose attention row is read. Defaults to
            16, the value that was previously hard-coded.
        descriptor_position: Key position for the "Descriptor" point. Defaults
            to 4, the value that was previously hard-coded.
        s_position: Key position for the "S" point. Defaults to 2, the value
            that was previously hard-coded.

    Returns:
        The dataframe when ``return_vals`` is true, the figure when
        ``return_fig`` is true, otherwise ``None`` after showing the figure.
    """
    x_col = "Attn Prob on Name"
    y_col = "Dot w Name Embed"
    layer, head_idx = head

    # Get the head's output to the residual stream, scaled by the final layer norm.
    _, cache = model.run_with_cache(prompts)
    per_head_residual, _ = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    scaled_residual_stack = cache.apply_ln_to_stack(
        per_head_residual, layer=-1, pos_slice=-1
    )
    # Heads are stacked layer-major, so this is the flat index of (layer, head_idx).
    head_resid = scaled_residual_stack[layer * model.cfg.n_heads + head_idx]

    rows = []
    for i, directions in enumerate(answer_residual_directions):
        # Attention pattern of this head on prompt i alone.
        _, attn, _ = get_attn_head_patterns(model, prompts[i], [head])

        for name_type, direction, key_position in (
            ("Descriptor", directions[0], descriptor_position),
            ("S", directions[1], s_position),
        ):
            # Head output projected onto the direction.
            dot = einsum("d_model, d_model -> ", head_resid[i], direction)
            # Attention probability from the query position to the key position.
            prob = attn[0, end_position, key_position]
            rows.append([prob, dot, name_type, prompts[i]])

    viz_df = pd.DataFrame(rows, columns=[x_col, y_col, "Name Type", "text"])
    fig = px.scatter(
        viz_df,
        x=x_col,
        y=y_col,
        color="Name Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Name Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    fig.show()

# Keyword arguments that the `*_p` plotting helpers forward to
# `fig.update_layout` instead of to the plotly express constructor.
update_layout_set = {
    # titles and legend
    "title_x", "title_y", "legend_title_text", "showlegend",
    # x-axis
    "xaxis_range", "xaxis_title", "xaxis_tickformat", "xaxis_tickmode",
    "xaxis_tickangle", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor",
    "xaxis_visible",
    # y-axis
    "yaxis_range", "yaxis_title", "yaxis_tickformat", "yaxis_tickmode",
    "yaxis_tickangle", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    "yaxis_visible",
    # colours
    "colorbar", "colorscale", "coloraxis",
    # bars
    "bargap", "bargroupgap",
    # misc
    "hovermode", "margin",
}

def imshow_p(tensor, renderer=None, **kwargs):
    """Show ``tensor`` as a heatmap, routing layout options to ``update_layout``.

    Keyword arguments named in ``update_layout_set`` are applied with
    ``fig.update_layout``; all others are forwarded to ``px.imshow``.

    Args:
        tensor: Data to plot; converted with ``utils.to_numpy``.
        renderer: Plotly renderer forwarded to ``Figure.show``.
        **kwargs: Plot options. Special keys:
            ``facet_labels``: replacement texts for the figure's annotations,
            applied in order.
            ``border``: if truthy, draw a black mirrored line around the axes.
            ``margin``: an int is expanded to equal top/bottom/left/right margins.
            ``color_continuous_scale`` defaults to "RdBu" and
            ``color_continuous_midpoint`` defaults to 0.0; both can be overridden.
    """
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    # Use a default rather than an explicit keyword so a caller-supplied
    # midpoint does not raise a duplicate-keyword TypeError.
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)
    if isinstance(kwargs_post.get("margin"), int):
        kwargs_post["margin"] = dict.fromkeys("tblr", kwargs_post["margin"])

    fig = px.imshow(utils.to_numpy(tensor), **kwargs_pre)

    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)

    # A setting like `xaxis_tickangle` only targets the first x-axis, so copy
    # it to every additional x-axis (xaxis2, xaxis3, ...) present in the layout.
    for setting in ["tickangle"]:
        if f"xaxis_{setting}" in kwargs_post:
            i = 2
            while f"xaxis{i}" in fig["layout"]:
                kwargs_post[f"xaxis{i}_{setting}"] = kwargs_post[f"xaxis_{setting}"]
                i += 1

    fig.update_layout(**kwargs_post)
    fig.show(renderer=renderer)

def hist_p(tensor, renderer=None, **kwargs):
    """Show a histogram of ``tensor``, routing layout options to ``update_layout``.

    Keyword arguments named in ``update_layout_set`` are applied with
    ``fig.update_layout``; all others are forwarded to ``px.histogram``.

    Args:
        tensor: Data passed to ``px.histogram`` as ``x``.
        renderer: Plotly renderer forwarded to ``Figure.show``.
        **kwargs: Plot options. Special keys:
            ``names``: one display name per data series, assigned to the
            figure's traces in order.
            ``barmode``: defaults to "overlay" when not supplied.
            ``bargap``: defaults to 0.0 when not supplied.
            ``margin``: an int is expanded to equal top/bottom/left/right margins.
    """
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    names = kwargs_pre.pop("names", None)
    # Check the caller's kwargs rather than kwargs_post: "barmode" is not routed
    # through update_layout, so checking kwargs_post would always overwrite a
    # caller-supplied barmode with "overlay".
    if "barmode" not in kwargs:
        kwargs_post["barmode"] = "overlay"
    kwargs_post.setdefault("bargap", 0.0)
    if isinstance(kwargs_post.get("margin"), int):
        kwargs_post["margin"] = dict.fromkeys("tblr", kwargs_post["margin"])

    fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)

    if names is not None:
        # With a marginal plot each series contributes two traces (histogram
        # plus marginal); without one there is a single trace per series.
        traces_per_series = 2 if kwargs_pre.get("marginal") else 1
        for i, trace in enumerate(fig.data):
            trace["name"] = names[i // traces_per_series]
    fig.show(renderer)


# TODO: Remove

def convert_title_to_filename(title: str):
    """Turn a plot title into a lowercase file stem.

    Spaces become dashes, parentheses are dropped, and the result is lowercased.
    """
    # One translation pass: map spaces to dashes and delete both parentheses.
    replacements = str.maketrans({" ": "-", "(": None, ")": None})
    return title.translate(replacements).lower()

def plot_graph_metric(
        df,
        metric,
        perf_metric_dict,
        title,
        left_y_title,
        y_range,
        x_axis_col='checkpoint',
        log_x=True,
        legend_font_size=16,
        axis_label_size=16,
        disable_title=False,
        metric_legend_name="Circuit Edges"  # Custom legend name for the metric line
    ):
    """Plot a metric and a performance metric against training checkpoint, then save a PDF.

    The metric is drawn on the left y-axis and the performance metric (legend
    name "Logit Diff") on a secondary right y-axis. The figure is shown and
    written to ``results/plots/<title-as-filename>.pdf``.

    Args:
        df: Dataframe containing ``x_axis_col`` and ``metric`` columns. It is
            not modified.
        metric: Name of the column plotted on the left axis.
        perf_metric_dict: Mapping from ``x_axis_col`` values to performance
            values; values missing from the mapping are linearly interpolated.
        title: Figure title; also used to derive the output filename.
        left_y_title: Title of the left y-axis.
        y_range: Upper bound of the left y-axis (the lower bound is 0).
        x_axis_col: Name of the column used for the x-axis.
        log_x: Whether the x-axis is logarithmic.
        legend_font_size: Font size of the legend entries.
        axis_label_size: Font size of the axis titles.
        disable_title: If true, the figure is drawn without a title (the
            filename is still derived from ``title``).
        metric_legend_name: Legend name shown for the metric line.

    Returns:
        None.
    """
    axis_title_style = dict(size=axis_label_size)
    display_title = None if disable_title else title

    # Map checkpoints to performance values and fill the gaps linearly.
    perf_metric = df[x_axis_col].map(perf_metric_dict).interpolate(method='linear')

    # `assign` and `rename` both return new frames, so the caller's dataframe
    # is left untouched.
    plot_df = df.assign(perf_metric=perf_metric).rename(
        columns={metric: metric_legend_name}
    )

    fig = px.line(
        plot_df,
        width=1200,
        x=x_axis_col,
        y=[metric_legend_name],
        title=display_title,
        log_x=log_x,
    )

    # Specify colors for each line and update traces.
    colors = {metric_legend_name: 'lightblue', 'perf_metric': 'black'}
    for trace in fig.data:
        trace.update(line=dict(color=colors[trace.name]))

    # Performance metric goes on the secondary (right) y-axis.
    fig.add_trace(
        go.Scatter(
            x=df[x_axis_col],
            y=perf_metric,
            name='Logit Diff',
            mode='lines',
            yaxis='y2',
            line=dict(color=colors['perf_metric']),
        )
    )

    fig.update_layout(
        xaxis=dict(
            title="Training Checkpoint",
            title_font=axis_title_style
        ),
        yaxis=dict(
            range=[0, y_range],
            title=left_y_title,
            title_font=axis_title_style
        ),
        yaxis2=dict(
            range=[0, 6],
            title="Logit Difference",
            overlaying="y",
            side="right",
            showgrid=False,
            title_font=axis_title_style
        ),
        legend=dict(
            font=dict(size=legend_font_size),
            title_text="Metrics",
        )
    )

    # Display and save the figure; create the output directory first so the
    # export does not fail on a fresh checkout.
    fig.show()
    output_dir = "results/plots/"
    os.makedirs(output_dir, exist_ok=True)
    filename = output_dir + convert_title_to_filename(title) + ".pdf"
    fig.write_image(filename, format='pdf', width=800, height=400, engine="kaleido")