import numpy as np
import pandas as pd

# Adapted from https://github.com/nghiaho12/rigid_transform_3D/blob/master/rigid_transform_3D.py
# "Least-Squares Fitting of Two 3-D Point Sets", Arun, K. S. and Huang, T. S. and Blostein, S. D., IEEE Transactions on Pattern Analysis and Machine Intelligence, Volume 9 Issue 5, May 1987
# Input: two 3xN matrices A and B of corresponding points (one point per column)
# Returns R, t such that R @ A + t best fits B in the least-squares sense
# R = 3x3 rotation matrix (may be a reflection when correct_reflection=False)
# t = 3x1 column vector

def rigid_transform_3D(A, B, correct_reflection=True):
    """Find the least-squares rigid transform (R, t) mapping point set A onto B.

    Implements the SVD method of Arun, Huang & Blostein (1987): the rotation
    is obtained from the SVD of the cross-covariance of the centred point
    sets, and the translation aligns the centroids, so that ``R @ A + t``
    approximates ``B``.

    Args:
        A: 3xN array of source points, one point per column.
        B: 3xN array of target points, column-wise corresponding to ``A``.
        correct_reflection: if True and the SVD solution has det(R) < 0
            (a reflection), flip the right-singular vector belonging to the
            smallest singular value so that a proper rotation is returned.
            If False, the reflection is returned as-is.

    Returns:
        (R, t): a 3x3 matrix and a 3x1 column vector.

    Raises:
        AssertionError: if ``A`` and ``B`` have different shapes.
        Exception: if either input does not have exactly 3 rows.
    """
    assert A.shape == B.shape, f"A and B must have the same shape, got {A.shape} and {B.shape}"

    for label, points in (("A", A), ("B", B)):
        num_rows, num_cols = points.shape
        if num_rows != 3:
            raise Exception(f"matrix {label} is not 3xN, it is {num_rows}x{num_cols}")

    # Column-wise centroids, kept as 3x1 so they broadcast against 3xN arrays.
    centroid_A = np.mean(A, axis=1).reshape(-1, 1)
    centroid_B = np.mean(B, axis=1).reshape(-1, 1)

    # Cross-covariance of the centred point sets.
    H = (A - centroid_A) @ (B - centroid_B).T

    # NOTE: H is not rank-checked; for degenerate point sets (e.g. collinear
    # points) the optimal rotation is not unique and the result is arbitrary.
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # Special reflection case.
    if np.linalg.det(R) < 0 and correct_reflection:
        print("det(R) < 0, reflection detected!, correcting for it ...")
        # svd returns singular values in descending order, so row 2 of Vt is
        # the right-singular vector of the smallest singular value.
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    # Translation that carries the rotated centroid of A onto the centroid of B.
    t = -R @ centroid_A + centroid_B

    return R, t

def compute_RMSD(a, b):
    """Return the root-mean-square deviation between two point arrays.

    Squared differences are summed along the last axis (the coordinates of
    each point), averaged over all remaining entries, then square-rooted.
    """
    squared_dist = ((a - b) ** 2).sum(axis=-1)
    mean_squared = squared_dist.mean()
    return np.sqrt(mean_squared)

def kabsch_RMSD(new_coords, coords):
    """RMSD between two point sets after superimposing ``new_coords`` onto ``coords``.

    Both inputs hold one point per row (they are transposed into the 3xN
    layout ``rigid_transform_3D`` requires). The fit is run with
    ``correct_reflection=False``, so the superposition may use an improper
    (reflecting) transform.
    """
    mobile, reference = new_coords.T, coords.T
    rot, shift = rigid_transform_3D(mobile, reference, correct_reflection=False)
    aligned = (rot @ mobile) + shift
    # Back to one-point-per-row layout for the RMSD.
    return compute_RMSD(reference.T, aligned.T)

def below_threshold(x, threshold=5):
    """Percentage (0-100) of entries of ``x`` strictly below ``threshold``.

    NaN entries compare as False, so they count towards the total but never
    as "below".
    """
    hits = (x < threshold).sum()
    total = len(x)
    return 100 * hits / total

# Per-method summary table of RMSD, centroid-distance and Kabsch-RMSD statistics.
def custom_description(data):
    """Build a per-method summary table of RMSD, centroid-distance and Kabsch metrics.

    Args:
        data: DataFrame whose first column is skipped (presumably an
            identifier -- TODO confirm against callers) and whose remaining
            numeric columns are named ``<Method>_<Metric>``, e.g.
            ``foo_RMSD``, ``foo_COM_DIST``, ``foo_KABSCH``. Names are split at
            the FIRST underscore, so method names must not contain ``_``.

    Returns:
        DataFrame indexed by method with three-level, LaTeX-formatted column
        headers (metric group, sub-group, statistic), rounded to 2 decimals.
        RMSD and COM_DIST report the mean, the 25/50/75th percentiles and the
        percentage of values below 5 and below 2. KABSCH reports the mean and
        the median.

    Raises:
        KeyError: if one of the ``RMSD``, ``COM_DIST`` or ``KABSCH`` metrics is
            absent for every method.
    """
    # Summarise only the metric columns. This keeps describe() consistent with
    # the threshold/median rows, which also skip the first column.
    metrics = data.iloc[:, 1:]

    def _as_row(series, label):
        # Turn a per-column Series into a one-row frame labelled ``label``.
        return series.to_frame(label).T

    stats = pd.concat([
        metrics.describe(),
        _as_row(metrics.apply(below_threshold, threshold=2, axis=0), '2A'),
        _as_row(metrics.apply(below_threshold, threshold=5, axis=0), '5A'),
        _as_row(metrics.median(), 'median'),
    ]).loc[['mean', '25%', '50%', '75%', '5A', '2A', 'median']]

    # Long format: one row per original column. rename_axis gives the label
    # column a fixed name even when data.columns carries its own name.
    long_form = stats.T.rename_axis('column').reset_index()
    # ``n`` must be a keyword: pandas >= 2.0 rejects it positionally.
    long_form[['Methods', 'Metrics']] = long_form['column'].str.split('_', n=1, expand=True)

    wide = pd.pivot(
        long_form,
        values=['mean', 'median', '25%', '50%', '75%', '5A', '2A'],
        index=['Methods'],
        columns=['Metrics'],
    )
    # Put the metric on the outer column level: (Metric, statistic).
    wide.columns = wide.columns.swaplevel(0, 1)

    # (source column, display header) pairs. Their order is the output order.
    layout = [
        (('RMSD', 'mean'), (r'Ligand RMSD $\downarrow$', ' ', 'mean')),
        (('RMSD', '25%'), (r'Ligand RMSD $\downarrow$', r'Percentiles $\downarrow$', '25%')),
        (('RMSD', '50%'), (r'Ligand RMSD $\downarrow$', r'Percentiles $\downarrow$', '50%')),
        (('RMSD', '75%'), (r'Ligand RMSD $\downarrow$', r'Percentiles $\downarrow$', '75%')),
        (('RMSD', '5A'), (r'Ligand RMSD $\downarrow$', r'% Below Threshold $\uparrow$', '5A')),
        (('RMSD', '2A'), (r'Ligand RMSD $\downarrow$', r'% Below Threshold $\uparrow$', '2A')),
        (('COM_DIST', 'mean'), (r'Centroid Distance $\downarrow$', ' ', 'mean')),
        (('COM_DIST', '25%'), (r'Centroid Distance $\downarrow$', r'Percentiles $\downarrow$', '25%')),
        (('COM_DIST', '50%'), (r'Centroid Distance $\downarrow$', r'Percentiles $\downarrow$', '50%')),
        (('COM_DIST', '75%'), (r'Centroid Distance $\downarrow$', r'Percentiles $\downarrow$', '75%')),
        (('COM_DIST', '5A'), (r'Centroid Distance $\downarrow$', r'% Below Threshold $\uparrow$', '5A')),
        (('COM_DIST', '2A'), (r'Centroid Distance $\downarrow$', r'% Below Threshold $\uparrow$', '2A')),
        (('KABSCH', 'mean'), ('KABSCH', r'RMSD $\downarrow$', 'mean')),
        (('KABSCH', 'median'), ('KABSCH', r'RMSD $\downarrow$', 'median')),
    ]

    table = wide[[source for source, _ in layout]]
    table.columns = pd.MultiIndex.from_tuples([header for _, header in layout])
    return table.round(2)