"""Sarkar and lca construction utils."""

import networkx as nx
import numpy as np
import torch

from utils.math_fns import cosh, arcosh, tanh, sinh, arsinh, arctanh
from utils.poincare import MIN_NORM, get_midpoint_o, get_midpoint, hyp_distance, hyp_dist_o
from utils.tree import get_leaves_root


# ################# CIRCLE INVERSIONS ########################

def get_inversion_circle(a, c):
    """Return the center and radius of the inversion circle that maps a to the origin.

    Args:
        a: tensor of points; norms are taken over the last dimension.
        c: value forwarded unchanged to ``get_midpoint_o``.

    Returns:
        Tuple ``(center, r)``. ``center`` lies along the direction of ``a``
        and ``r`` keeps a trailing dimension of size one.
    """
    midpoint = get_midpoint_o(a, c)
    midpoint_len = midpoint.norm(p=2, dim=-1, keepdim=True)
    a_len = a.norm(p=2, dim=-1, keepdim=True)

    # Distance from the origin to the circle center; the denominator is
    # clamped so it never drops below MIN_NORM.
    denominator = (2 * midpoint_len - a_len).clamp_min(MIN_NORM)
    center_len = midpoint_len ** 2 / denominator

    # Unit vector along a (norm clamped to avoid dividing by zero).
    direction = a / a_len.clamp_min(MIN_NORM)

    return center_len * direction, center_len - midpoint_len


def inversion(center, r, a):
    """Return the image of a under inversion in the circle (center, r).

    Formula here: https://mphitchman.com/geometry/section3-2.html

    Args:
        center: center of the inversion circle.
        r: radius of the inversion circle.
        a: points to invert; norms are taken over the last dimension.
    """
    offset = a - center
    # Clamp the distance to the center so the division below stays finite.
    dist = offset.norm(p=2, dim=-1, keepdim=True).clamp_min(MIN_NORM)
    factor = (r / dist) ** 2
    return center + factor * offset

def isometric_transform(a, x):
    """Reflect x by circle inversion through the orthogonal circle centered at a.

    The squared radius is ``||a||^2 - 1`` and the image of x is
    ``(r / ||x - a||)^2 * (x - a) + a``. Sums are taken over the last
    dimension. No clamping is applied, so ``x == a`` divides by zero.
    """
    radius_sq = torch.sum(a ** 2, dim=-1, keepdim=True) - 1.
    offset = x - a
    offset_sq = torch.sum(offset ** 2, dim=-1, keepdim=True)
    return radius_sq / offset_sq * offset + a

def reflection_center(mu):
    """Return the center of the inversion circle for mu: ``mu / ||mu||^2``.

    The squared norm is summed over the last dimension and is not clamped,
    so a zero vector divides by zero.
    """
    squared_norm = torch.sum(mu ** 2, dim=-1, keepdim=True)
    return mu / squared_norm

def reflect_at_zero(mu, x):
    """Map x under the isometry (inversion) that takes mu to the origin.

    The inversion center is obtained from ``reflection_center(mu)`` and the
    point is transformed with ``isometric_transform``.
    """
    return isometric_transform(reflection_center(mu), x)

# ################# SARKAR CONSTRUCTION ########################

def place_children(z, scaling, invert_lca):
    """Embed the two children of a node that is embedded at the origin.

    Assumes z is the embedding of the parent of the node at the origin.

    Args:
        z: 2-D point (indexable as ``z[0]``, ``z[1]``) giving the direction
            of the parent as seen from the origin.
        scaling: Euclidean radius at which both children are placed.
        invert_lca: if True, the children are placed at +90 and -90 degrees
            from the parent direction; otherwise at +120 and +240 degrees.

    Returns:
        List with two numpy arrays of shape (2,), one per child.
    """
    # arctan2 keeps the sign of the angle. arccos(z[0] / ||z||) only covers
    # [0, pi], so a parent in the lower half-plane would be mirrored and one
    # child would end up on the parent's own ray.
    theta_p = np.arctan2(z[1], z[0])
    if invert_lca:
        offsets = (0.5 * np.pi, -0.5 * np.pi)
    else:
        offsets = (2 * (np.pi / 3), 4 * (np.pi / 3))
    return [
        scaling * np.array([np.cos(theta_p + delta), np.sin(theta_p + delta)])
        for delta in offsets
    ]


def sarkar(tree, tau=1.0, invert_lca=False):
    """Compute 2-D embeddings for the nodes of a tree, walking it breadth-first.

    The root is placed at the origin and its two children on the x-axis at
    +/- ``scaling``. For every other node p, an inversion that maps p to the
    origin is applied to its parent, the children are placed around the
    origin by ``place_children``, and the same inversion maps them back.

    Args:
        tree: networkx directed graph; every node yielded by
            ``nx.bfs_successors`` must have exactly two successors, and node
            ids are used as row indices into the result.
        tau: controls the placement radius, ``(e^tau - 1) / (e^tau + 1)``.
        invert_lca: forwarded to ``place_children``.

    Returns:
        numpy array of shape ``(tree.number_of_nodes(), 2)``.
    """
    reversed_tree = nx.reverse_view(tree)
    _, root = get_leaves_root(tree)
    embeddings = np.zeros((tree.number_of_nodes(), 2))
    embeddings[root] = np.zeros(2)
    scaling = (np.exp(tau) - 1) / (np.exp(tau) + 1)

    def as_row(vec):
        # numpy vector -> torch tensor of shape (1, dim)
        return torch.from_numpy(vec).view(1, -1)

    for node, (left, right) in nx.bfs_successors(tree, root):
        if node == root:
            embeddings[left] = scaling * np.array([1, 0])
            embeddings[right] = scaling * np.array([-1, 0])
            continue

        node_row = as_row(embeddings[node])
        # The parent of `node` is its neighbor in the reversed view.
        grandparent = list(reversed_tree.neighbors(node))[0]
        grandparent_row = as_row(embeddings[grandparent])

        # inversion that maps parent to the origin
        center, radius = get_inversion_circle(node_row, 1.0)
        z = inversion(center, radius, grandparent_row).view(-1).numpy()

        placed = place_children(z, scaling, invert_lca)
        # Apply the same inversion to carry the children back.
        embeddings[left] = inversion(center, radius, as_row(placed[0])).view(-1).numpy()
        embeddings[right] = inversion(center, radius, as_row(placed[1])).view(-1).numpy()
    return embeddings


# ################# CONTINUOUS LCA CONSTRUCTION ########################

def euc_reflection(x, a):
    """
    Euclidean reflection (also hyperbolic) of x
    along the geodesic that goes through a and the origin
    (a straight line).

    Computed as ``2 * proj - x`` where ``proj`` is the projection of x onto
    the direction of a. The squared norm of a is clamped at MIN_NORM before
    dividing. Sums are taken over the last dimension.
    """
    dot_xa = torch.sum(x * a, dim=-1, keepdim=True)
    a_sq = torch.sum(a ** 2, dim=-1, keepdim=True).clamp_min(MIN_NORM)
    projection = dot_xa * a / a_sq
    return 2 * projection - x


def max_norm_first(x, y):
    """Reorder tensors (x, y) so that the first tensor has the maximum norm.

    Norms are L2 norms over the last dimension and the choice is made per
    row. On a tie the original order is kept.

    Args:
        x: tensor of points.
        y: tensor of points with the same leading dimensions as x.

    Returns:
        Tuple ``(first, second)`` where, row by row, ``first`` is whichever
        of x / y has the larger norm and ``second`` is the other one.
    """
    norm_x = x.norm(dim=-1, p=2, keepdim=True)
    norm_y = y.norm(dim=-1, p=2, keepdim=True)
    # Select rows instead of computing x + idx * (y - x): the arithmetic form
    # is not exact in floating point (e.g. 1.0 + (1e-20 - 1.0) == 0.0), so it
    # can corrupt the swapped values.
    swap = norm_y > norm_x
    return torch.where(swap, y, x), torch.where(swap, x, y)


def hyp_lca1(a, b, return_coord=True, c=1.0):
    """
    Computes projection of the origin on the geodesic between a and b.

    Args:
        a, b: tensors of points; they are first reordered by
            ``max_norm_first``.
        return_coord: if True return the projected point, otherwise return
            ``hyp_distance`` between that point and the zero tensor.
        c: forwarded to ``get_inversion_circle`` and ``get_midpoint``.
    """
    first, second = max_norm_first(a, b)
    circle_center, circle_radius = get_inversion_circle(first, c)

    # Images of the second point and of the origin under the inversion.
    second_img = inversion(circle_center, circle_radius, second)
    origin_img = inversion(
        circle_center, circle_radius, torch.zeros_like(circle_center)
    )

    # Mirror the origin's image across the line through second_img and the
    # origin, then take the midpoint of the image and its mirror.
    origin_img_mirror = euc_reflection(origin_img, second_img)
    mid = get_midpoint(origin_img, origin_img_mirror, c=c)
    # Keep whichever of (mid, second_img) is returned second by max_norm_first.
    _, mid = max_norm_first(mid, second_img)

    # Invert back.
    proj = inversion(circle_center, circle_radius, mid)
    if return_coord:
        return proj
    return hyp_distance(proj, torch.zeros_like(proj))

def _halve(x):
    """Return the point on the geodesic segment from o to x at half the distance.

    Computed as ``x / (1 + sqrt(1 - ||x||^2))`` with the squared norm summed
    over the last dimension.
    """
    squared_norm = torch.sum(x ** 2, dim=-1, keepdim=True)
    denominator = 1. + torch.sqrt(1 - squared_norm)
    return x / denominator

def hyp_lca2(a, b, return_coord=True, c=1.0):
    """
    Computes projection of the origin on the geodesic between a and b.

    More optimized than hyp_lca1.

    Args:
        a, b: tensors of points. Unlike hyp_lca1 they are not reordered by
            norm.
        return_coord: if True return the projected point, otherwise return
            ``hyp_dist_o`` of that point.
        c: accepted for signature parity with hyp_lca1; this implementation
            does not read it.
    """
    # Inversion that takes a to the origin; under it the origin's image is a.
    inv_center = reflection_center(a)
    b_img = isometric_transform(inv_center, b)

    # Mirror the origin's image (a) across the line through b_img and the
    # origin, then map the mirrored point back.
    mirrored = euc_reflection(a, b_img)
    origin_mirror = isometric_transform(inv_center, mirrored)

    # Halve the mapped-back point.
    proj = _halve(origin_mirror)
    if return_coord:
        return proj
    return hyp_dist_o(proj)


