import os
from torch.nn import MSELoss, Parameter
import torch
import torch.optim as optim
from sklearn.linear_model import LinearRegression
from trainkit.saving import load_object
import numpy as np
import wandb
import pandas as pd


def get_baselines(project="lr_baselines", target_var="train_loss"):
    """Return the best baseline run for every ``(vec, width)`` pair.

    The project table is loaded through ``get_project_data``, restricted to
    runs whose ``state`` is ``"finished"``, and extended with a ``vec`` column
    (built by ``rename_row``) and a ``label`` column (built by ``label_row``).
    Within each ``(vec, width)`` group only the row with the smallest
    ``target_var`` is kept.

    Args:
        project: Project name forwarded to ``get_project_data``.
        target_var: Column that is minimised within each group.

    Returns:
        A DataFrame holding one row per ``(vec, width)`` group.
    """
    dfb = get_project_data(project=project)
    # Work on an explicit copy: adding columns to a boolean-mask selection
    # makes pandas emit SettingWithCopyWarning and leaves it ambiguous which
    # frame is actually being written to.
    dfb = dfb[dfb["state"] == "finished"].copy()
    dfb["vec"] = dfb.apply(rename_row, axis=1)
    dfb["label"] = dfb.apply(label_row, axis=1)
    # NOTE: an epoch cut-off (``epoch <= 200``) used to be applied here and is
    # currently disabled.
    best_rows = dfb.groupby(["vec", "width"])[target_var].idxmin()
    return dfb.loc[best_rows]


def rename_text_vec(row, key="expr"):
    """Translate a pipe-separated vector string into its LaTeX label.

    Args:
        row: Record indexable by ``key`` (for example a DataFrame row).
        key: Field of ``row`` that holds the vector string.

    Returns:
        The LaTeX string registered for ``row[key]``; when nothing is
        registered for that value, ``row[key]`` is returned unchanged.
    """
    latex_by_vec = {
        "(0.5|0|0.5|0|0.5|0.5|0)":
            r"$\left(\frac{1}{2},0,\frac{1}{2},0,\frac{1}{2},\frac{1}{2},0\right)$",
        "(0.7|0|0.3|0|0.7|0.3|0)":
            r"$\left(\frac{2}{3},0,\frac{1}{3},0,\frac{2}{3},\frac{1}{3},0\right)$",
        "(0.33|0.33|0.34|0.33|0.33|0.34|0.33)":
            r"$\left(\frac{1}{3},\frac{1}{3},\frac{1}{3},\frac{1}{3},\frac{1}{3},\frac{1}{3},\frac{1}{3}\right)$",
        # NOTE(review): this key and its LaTeX rendering list different
        # entries -- confirm which of the two is the intended vector.
        "(0.8|0.2|0|0.2|0.8|0|0.33)":
            r"$\left(\frac{4}{5},0,\frac{1}{5},0,\frac{4}{5},\frac{1}{5},0\right)$",
        "(0.5|0.5|0|0.33|0.33|0.34|0)":
            r"$\left(\frac{1}{2},\frac{1}{2},0,\frac{1}{3},\frac{1}{3},\frac{1}{3},0\right)$",
    }
    raw = row[key]
    return latex_by_vec.get(raw, raw)


def label_row(row):
    """Return the display label for a run, derived from its ``vec`` string.

    Args:
        row: Record indexable by ``"vec"``.

    Returns:
        The registered display name when ``row["vec"]`` has one; otherwise
        ``row["vec"]`` with the substring ``" (BMM0) Adam"`` removed.
    """
    vec = row["vec"]
    display_names = {
        "none (BMM0) Adam": "Dense",
        "btt (BMM0) Adam": "BTT",
        # This vector is shown verbatim; "BTT / Monarch" was the label used
        # for it previously.
        "0.0-0.5-0.5-0.0-0.5-0.0-0.5 (BMM0) Adam": "0.0-0.5-0.5-0.0-0.5-0.0-0.5",
        "0.5-0.5-0.0-0.0-0.5-0.5-0.0 (BMM0) Adam": "TT",
    }
    if vec not in display_names:
        return vec.replace(" (BMM0) Adam", "")
    return display_names[vec]


def rename_row(row):
    """Build the identifier string of a run.

    The identifier is the concatenation of the expression text from
    ``get_expr``, the initialisation tag from ``get_init`` and, when
    ``row["optimizer"]`` equals ``"adamw"``, the suffix ``" Adam"``.

    Args:
        row: Record providing ``"optimizer"`` plus whatever fields
            ``get_expr`` and ``get_init`` read.

    Returns:
        The concatenated identifier.
    """
    label = get_expr(row) + get_init(row)
    if row["optimizer"] == "adamw":
        label += " Adam"
    return label


def get_expr(row):
    """Return the text that describes the structure of a run.

    Selection rules, applied in order:

    * ``struct`` does not start with ``"einsum"``: return ``row["expr0"]``
      when ``struct`` is ``"ein_expr"``, otherwise ``struct`` itself.
    * ``fact`` is one of ``up``, ``btt``, ``upvec``, ``bttvec``: return the
      result of ``concat_and_format(row)``.
    * ``fact`` is ``btt3vec``: return ``row["theta"]``.
    * anything else: return ``row["expr0"]`` without its last three
      characters.

    Args:
        row: Record providing ``"struct"`` and, depending on the branch,
            ``"fact"``, ``"expr0"`` and ``"theta"``.

    Returns:
        The selected description.
    """
    structure = row["struct"]
    if not structure.startswith("einsum"):
        return row["expr0"] if structure == "ein_expr" else structure

    factorization = row["fact"]
    if factorization in ("up", "btt", "upvec", "bttvec"):
        return concat_and_format(row)
    if factorization in ("btt3vec", ):
        return row["theta"]
    return row["expr0"][:-3]


def get_init(row):
    """Return the tag that encodes the initialisation type of a run.

    Args:
        row: Record indexable by ``"init_type"``.

    Returns:
        ``" (BMM0)"``, ``" (BMM1)"`` or ``" (RSGD)"`` for the init types
        ``bmm0``, ``bmm1`` and ``rsgd`` respectively; ``"No"`` for any other
        value.
    """
    known_tags = (
        ("bmm0", " (BMM0)"),
        ("bmm1", " (BMM1)"),
        ("rsgd", " (RSGD)"),
    )
    init_type = row["init_type"]
    for name, tag in known_tags:
        if init_type == name:
            return tag
    return "No"


def do_coeff_analysis(df, target_var="train_loss_avg"):
    """Print the 20 smallest regression slopes across ``(vec, seed)`` groups.

    For every ``(vec, width)`` group the row minimising ``target_var`` is
    selected. ``compute_reg_coeffs`` is then applied to each ``(vec, seed)``
    group of that selection, and the resulting ``(intercept, slope)`` pairs
    are tabulated under the columns ``inter`` and ``coef`` and sorted by
    ``coef`` in ascending order.

    Args:
        df: DataFrame with ``vec``, ``width``, ``seed``, ``target_var`` and
            the columns read by ``compute_reg_coeffs``.
        target_var: Column minimised when picking one row per
            ``(vec, width)`` group.

    Returns:
        None; the first 20 rows of the sorted table are printed.
    """
    best_idx = df.groupby(["vec", "width"])[target_var].idxmin()
    best_rows = df.loc[best_idx]
    fits = best_rows.groupby(["vec", "seed"]).apply(compute_reg_coeffs)
    # Expand each (intercept, slope) tuple into two columns.
    table = fits.apply(pd.Series)
    table.columns = ["inter", "coef"]
    ranked = table.sort_values(by="coef", ascending=True)
    print(ranked.head(20))


def get_wandb_df(project, config_vars, track_vars, steps):
    """Collect run records from a Weights & Biases project into a DataFrame.

    Every record starts from the run's ``name`` and ``state`` plus the config
    entries whose key is in ``config_vars``.

    * When ``steps`` equals ``[-1]``, one record per run is produced and
      extended with the run-summary entries whose key is in ``track_vars``.
    * Otherwise the run history for ``track_vars`` is fetched and one record
      is produced for every entry of ``steps`` that is smaller than the
      number of history rows; each such record is extended with that
      history row.

    Args:
        project: Project path passed to ``wandb.Api().runs``.
        config_vars: Config keys to copy into each record.
        track_vars: Summary / history keys to copy into each record.
        steps: ``[-1]`` for summary values, or the history indices to keep.

    Returns:
        A DataFrame with one row per collected record.
    """
    api = wandb.Api()

    records = []
    for run in api.runs(project):
        base = {"name": run.name, "state": run.state}
        for cfg_key, cfg_val in run.config.items():
            if cfg_key in config_vars:
                base[cfg_key] = cfg_val

        if steps == [-1]:
            record = dict(base)
            for metric, value in run.summary._json_dict.items():
                if metric in track_vars:
                    record[metric] = value
            records.append(record)
            continue

        history = run.history(keys=track_vars, pandas=False)
        n_logged = len(history)
        if n_logged > 0:
            # Skip requested steps that this run never reached.
            for step in [ss for ss in steps if ss < n_logged]:
                record = dict(base)
                record.update(history[step])
                records.append(record)

    return pd.DataFrame(records)


def get_project_data(project, steps=None, only_finished=True):
    """Load the run table of a project, using a CSV cache when one exists.

    The table is read from ``./logs/{project}.csv`` if that file exists.
    Otherwise it is fetched with ``get_wandb_df`` and written to that path.
    Note that an existing cache file is used whatever ``steps`` is.

    After loading, ``train_loss`` is cast to float and three derived columns
    are added: ``error`` (``100 - train_acc_avg``), ``test_error``
    (``100 - test_acc``) and ``flops/params``
    (``cola_flops / cola_params``).

    Args:
        project: Project name; also the stem of the cache file.
        steps: Steps forwarded to ``get_wandb_df``. ``None`` (the default)
            stands for ``[-1]``.
        only_finished: When true, keep only rows whose ``state`` is
            ``"finished"``.

    Returns:
        The (optionally filtered) DataFrame including the derived columns.
    """
    # A ``None`` sentinel replaces the former mutable list default.
    if steps is None:
        steps = [-1]

    filepath = f"./logs/{project}.csv"
    if os.path.exists(filepath):
        df = pd.read_csv(filepath)
    else:
        track_vars = ["train_loss", "test_acc", "train_acc", "train_loss_avg", "train_acc_avg", "epoch"]
        config_vars = ["lr", "width", "depth", "seed", "struct", "optimizer", "state", "init_type", "fact", "expr"]
        config_vars += ["cola_flops", "cola_params", "expr0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "theta"]
        config_vars += ["use_wrong_mult"]
        df = get_wandb_df(project=f"ap3635/{project}", config_vars=config_vars, track_vars=track_vars, steps=steps)
        # ``to_csv`` does not create missing directories, so make sure the
        # cache directory exists before writing.
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        df.to_csv(filepath, index=False)

    if only_finished:
        # Copy the selection so the column assignments below write to an
        # independent frame instead of a slice (SettingWithCopyWarning).
        df = df[df["state"] == "finished"].copy()

    df["train_loss"] = df["train_loss"].astype(float)
    df["error"] = 100. - df["train_acc_avg"]
    df["test_error"] = 100. - df["test_acc"]
    df["flops/params"] = df["cola_flops"] / df["cola_params"]
    return df


def concat_and_format(row):
    """Join the entries ``v0`` ... ``v6`` of ``row`` into one dash-separated string.

    Each entry is converted to ``float`` and then to ``str`` before joining,
    so an integer ``0`` is rendered as ``"0.0"``.

    Args:
        row: Record indexable by ``"v0"`` through ``"v6"``.

    Returns:
        The seven formatted values separated by ``"-"``.
    """
    return "-".join(str(float(row[f"v{idx}"])) for idx in range(7))


def compute_reg_coeffs(group):
    """Fit a linear regression of log ``train_loss`` on log ``cola_flops``.

    Args:
        group: Frame-like object with ``cola_flops`` and ``train_loss``
            columns that support ``.to_numpy()``.

    Returns:
        A tuple ``(intercept, slope)`` of the fitted line; the slope is
        converted to a Python float.
    """
    log_flops = np.log(group["cola_flops"].to_numpy())
    log_loss = np.log(group["train_loss"].to_numpy())
    model = LinearRegression()
    # The regressor expects a 2-D feature matrix, hence the added axis.
    model.fit(log_flops[:, None], log_loss)
    slope = float(model.coef_[0])
    return model.intercept_, slope


def fit_scale_sk(data, offset):
    """Fit ``log(y - offset)`` as a linear function of ``log(x)``.

    Args:
        data: Pair ``(x, y)``; ``x`` must support ``x[:, None]`` indexing
            (presumably 1-D NumPy arrays -- confirm with callers).
        offset: Value subtracted from ``y`` before taking the logarithm.

    Returns:
        NumPy array ``[intercept, -slope]`` of the fitted line.
    """
    inputs, targets = data
    log_inputs = np.log(inputs[:, None])
    log_targets = np.log(targets - offset)
    model = LinearRegression().fit(log_inputs, log_targets)
    # The slope is returned with flipped sign.
    return np.array([model.intercept_, -model.coef_[0]])


def fit_scale_law(data, lr=1e-1, n_steps=1_000, tol=1e-6):
    """Fit the three-parameter law of ``scale_laws_fn`` to ``data`` with L-BFGS.

    Both arrays are converted to float64 tensors. The parameters start at
    ``|1 + 0.01 * eps|`` with ``eps`` drawn from a standard normal, and are
    optimised by ``_fit_scale_law``.

    Args:
        data: Sequence whose first two items are the inputs and the targets.
        lr: Learning rate given to ``torch.optim.LBFGS``.
        n_steps: Maximum number of optimiser steps.
        tol: Loss value below which the optimisation stops.

    Returns:
        NumPy array with the three fitted parameters.
    """
    precision = torch.float64
    inputs = torch.tensor(data[0], dtype=precision)
    targets = torch.tensor(data[1], dtype=precision)

    # Small random perturbation around one; the absolute value keeps the
    # starting point non-negative.
    noise = 0.01 * torch.randn(3, dtype=precision)
    theta = Parameter(torch.abs(1 + noise))

    optimizer = optim.LBFGS([theta], lr=lr)
    fitted = _fit_scale_law((inputs, targets), optimizer, theta, n_steps, tol)
    return fitted.detach().numpy()


def _fit_scale_law(data, opt, theta, n_steps=1_000, tol=1e-6):
    """Run the optimisation loop that fits ``scale_laws_fn`` to ``data``.

    At every step the optimiser is advanced with a closure computing the
    mean-squared error between ``scale_laws_fn(theta, x)`` and ``y``. The
    loss is then evaluated once more, printed, and compared against ``tol``.

    Args:
        data: Pair ``(x, y)`` of tensors.
        opt: Optimiser whose ``step`` accepts a closure and which updates
            ``theta``.
        theta: Parameter tensor being optimised.
        n_steps: Maximum number of optimiser steps.
        tol: The loop stops as soon as the loss drops below this value.

    Returns:
        ``theta`` (the same object that was passed in).
    """
    inputs, targets = data
    criterion = MSELoss()

    def closure():
        # Clear old gradients, evaluate the objective and back-propagate.
        opt.zero_grad()
        objective = criterion(scale_laws_fn(theta, inputs), targets)
        objective.backward()
        return objective

    for step in range(n_steps):
        opt.step(closure)
        # Re-evaluate so the reported value reflects the updated parameters.
        current = closure().item()
        print(f"Step: {step:,d} | Loss: {current:1.5e}")
        if current < tol:
            break
    return theta


def scale_laws_fn_log(theta, x):
    """Evaluate the linear law ``theta[0] - theta[1] * x``.

    Args:
        theta: Indexable with at least two entries (intercept, slope).
        x: Input value(s).

    Returns:
        ``theta[0] - theta[1] * x``.
    """
    intercept = theta[0]
    slope = theta[1]
    return intercept - slope * x


def scale_laws_fn(theta, x):
    """Evaluate the offset power law ``theta[0] + theta[1] * x ** (-theta[2])``.

    Args:
        theta: Indexable with at least three entries (offset, scale,
            exponent).
        x: Input value(s).

    Returns:
        ``theta[0] + theta[1] * x ** (-theta[2])``.
    """
    offset = theta[0]
    decay = theta[1] * x**(-theta[2])
    return offset + decay


def get_details(filepath, key, col):
    """Select the metric entries tagged ``col`` and attach a default value.

    ``metrics.pkl`` and ``defaults.pkl`` are loaded from ``filepath``. Every
    metrics entry whose last element equals ``col`` is kept and extended
    with ``defaults[key]``.

    Args:
        filepath: Directory that contains both pickle files.
        key: Key looked up in the loaded defaults.
        col: Value the last element of an entry must equal to be kept.

    Returns:
        List of the kept entries, each with ``defaults[key]`` appended
        (entries are presumably tuples -- confirm against the writer).
    """
    metrics = load_object(os.path.join(filepath, "metrics.pkl"))
    defaults = load_object(os.path.join(filepath, "defaults.pkl"))
    return [entry + (defaults[key], ) for entry in metrics if entry[-1] == col]


def check_cond(filepath, key, val):
    """Tell whether the stored default for ``key`` equals ``val``.

    Args:
        filepath: Directory that contains ``defaults.pkl``.
        key: Key looked up in the loaded defaults.
        val: Value to compare against.

    Returns:
        ``True`` when ``defaults[key] == val`` is truthy, ``False`` otherwise.
    """
    defaults = load_object(os.path.join(filepath, "defaults.pkl"))
    return bool(defaults[key] == val)
