import collections
from typing import Optional, List, Dict, Any

import numpy as np
from tqdm import tqdm


# helpers
def normalize_frame(W, axis=0):
    """Rescale W so that every vector along `axis` has unit Euclidean norm.

    With the default ``axis=0`` each column of a 2-D W is divided by its
    own 2-norm. Zero-norm vectors are not special-cased.
    """
    lengths = np.linalg.norm(W, axis=axis, keepdims=True)
    return W / lengths


def psd_sqrt(C):
    """Return a matrix square root of C built from its eigendecomposition.

    C is decomposed with ``np.linalg.eigh`` and reassembled from the same
    eigenvectors with each eigenvalue replaced by the square root of its
    absolute value.
    """
    eigvals, eigvecs = np.linalg.eigh(C)
    # abs() keeps (slightly) negative eigenvalues from turning into NaNs
    roots = np.sqrt(np.abs(eigvals))
    # scaling column j by roots[j] is the same as eigvecs @ diag(roots)
    return (eigvecs * roots) @ eigvecs.T


def compute_error(C, Ctarget, ord=2):
    """Norm of the difference ``C - Ctarget``.

    `ord` is forwarded to ``np.linalg.norm``; the default of 2 gives the
    operator (spectral) norm for matrix inputs.
    """
    residual = C - Ctarget
    return np.linalg.norm(residual, ord=ord)


def get_g_opt(
    W,
    Css,
):
    """Compute optimal G.

    Solves the linear system

        (W^T W)**2 @ g = diag(W^T (Css^{1/2} - I) W)

    where ``**2`` is the element-wise square of the Gram matrix and
    ``Css^{1/2}`` comes from `psd_sqrt`.

    Parameters
    ----------
    W : 2-D array whose first dimension matches the size of `Css`.
    Css : square matrix passed to `psd_sqrt`.

    Returns
    -------
    1-D array with one entry per column of W.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the element-wise squared Gram matrix is singular.
    """
    n_rows = W.shape[0]
    gram = W.T @ W
    Css_12 = psd_sqrt(Css)
    rhs = np.diag(W.T @ (Css_12 - np.eye(n_rows)) @ W)
    # Solve the system directly instead of forming an explicit inverse;
    # this matches the offline update in context_adaptation_experiment.
    return np.linalg.solve(gram ** 2, rhs)


def context_adaptation_experiment(
        Css_list: List[np.ndarray],
        N: int,
        K: int,
        eta_w: float,
        eta_g: float,
        batch_size: int,
        n_context_samples: int,
        n_samples: int,
        g0: np.ndarray,
        W0: np.ndarray,
        alpha=1.,
        online=True,
        normalize_w=False,
        seed=None,
        error_ord=2,
        verbose: bool = True,
        ) -> Dict[str, Any]:
    """Online and offline algorithms.

    Draws a random sequence of context indices into `Css_list` and, for each
    context, runs `n_samples` update steps on the gains ``g`` and weights
    ``W``, logging diagnostics after every step.

    Parameters
    ----------
    Css_list : covariance matrices, one per context.
    N, K : kept for backward compatibility; both are overwritten by
        ``W0.shape`` and are otherwise unused.
    eta_w : step size for the W update.
    eta_g : step size for the g update (online mode only).
    batch_size : number of samples drawn per step (online mode only).
    n_context_samples : number of contexts to visit; 0 means
        ``len(Css_list)``.
    n_samples : number of update steps per visited context.
    g0, W0 : initial gains and weights; both are copied, not modified.
    alpha : coefficient of the identity in ``alpha*I + W diag(g) W^T``.
    online : if True use the sampled update, otherwise the offline update
        in which g is solved for directly at every step.
    normalize_w : if True, re-normalize W with `normalize_frame` after
        every step.
    seed : forwarded to ``np.random.default_rng``.
    error_ord : norm order forwarded to `compute_error`.
    verbose : if True, wrap the context loop in a tqdm progress bar.

    Returns
    -------
    A ``defaultdict(list)`` with per-step lists under ``'g'``, ``'g_norm'``,
    ``'W_norm'``, ``'dg_norm'``, ``'dW_norm'`` and ``'error'``, plus the
    entries ``'W0'``, ``'W'`` (final weights), ``'N'``, ``'K'``, ``'eta_w'``,
    ``'n_samples'``, ``'g0'`` and ``'seed'``.
    """
    rng = np.random.default_rng(seed)

    # make sources
    Css12_list = [psd_sqrt(Css) for Css in Css_list]
    n_contexts = len(Css_list)

    # NOTE: the N and K arguments are superseded by the shape of W0.
    N, K = W0.shape
    W = W0.copy()
    g = g0.copy()

    In = np.eye(N)

    results = collections.defaultdict(list)

    T = n_contexts if n_context_samples == 0 else n_context_samples
    iterator = tqdm(range(T)) if verbose else range(T)

    # Prepend and append the same random contexts (up to 10) for plotting
    # later. The pad is capped at T // 2 so that runs with T < 20 no longer
    # request a negative number of middle contexts; for T >= 20 the draws
    # are unchanged.
    n_pad = min(10, T // 2)
    pre_post = rng.integers(0, n_contexts, (n_pad,))
    middle = rng.integers(0, n_contexts, (T - 2 * n_pad,))
    contexts = np.concatenate([pre_post, middle, pre_post])

    for t in iterator:
        ctx = contexts[t]
        Css, Css12 = Css_list[ctx], Css12_list[ctx]

        for _ in range(n_samples):

            if online:
                # draw sample and compute primary neuron steady-state
                s = Css12 @ rng.standard_normal((N, batch_size))  # sample data
                # same as W @ np.diag(g) @ W.T without building diag(g)
                WGWT = (W * g[None, :]) @ W.T
                F = np.linalg.solve(alpha*In + WGWT, In)
                r = F @ s  # primary neuron steady-state

                # compute interneuron input/output and update g
                z = W.T @ r  # interneuron steady-state input
                n = g[:, None] * z  # interneuron steady-state output (pre-update g)
                w_norm = np.linalg.norm(W, axis=0)**2  # squared column norms
                dg = z**2 - w_norm[:, None]
                g = g + eta_g * np.mean(dg, -1)

                # update W (uses the freshly updated g in the decay term)
                rnT = r @ n.T / batch_size
                dW = rnT - W * g[None, :]
                W = W + eta_w * dW
                # response covariance under the pre-update F
                Crr = F @ Css @ F.T

            else:
                # solve for g directly; **2 is an element-wise square
                WTW = W.T @ W
                g = np.linalg.solve(WTW**2, np.diag(W.T @ Css12 @ W - WTW))
                WG = W * g[None, :]
                WGWT = WG @ W.T  # same as W @ np.diag(g) @ W.T
                F = np.linalg.solve(alpha*In + WGWT, In)
                Crr = F @ Css @ F.T

                # update W
                dg = 0.  # g is not updated incrementally in this mode
                dW = (Crr @ WG) - WG
                W = W + eta_w * dW

            if normalize_w:
                W = normalize_frame(W)

            results['g'].append(g)
            results['g_norm'].append(np.linalg.norm(g))
            results['W_norm'].append(np.linalg.norm(W))
            results['dg_norm'].append(np.linalg.norm(dg))
            results['dW_norm'].append(np.linalg.norm(dW))
            results['error'].append(compute_error(Crr, In, error_ord))

    results.update({
        'W0': W0,
        'W': W,
        'N': N,
        'K': K,
        'eta_w': eta_w,
        'n_samples': n_samples,
        'g0': g0,
        'seed': seed,
    })
    return results

