# imports
import sys
import os
import time

import numpy as np
from tqdm import tqdm
import faiss
import torch

# debug NANs
# Turns on JAX's NaN checking for the whole process, as a side effect of importing
# this module: JAX operations that produce a NaN will raise an error.
# NOTE(review): nothing visible in this module uses jax; confirm this global flag
# is meant to live here rather than in the entry-point/training script.
import jax
jax.config.update("jax_debug_nans", True)

def get_knn(embeddings, dist_fn, K=30, M=32):
    """
        Use FAISS HNSW index to compute (approximate) KNNs

        Params:
            embeddings:    numpy array, shape (N, E), of embeddings. The caller's
                           array is never modified: it is converted/copied to
                           contiguous float32 before any in-place normalization.
            dist_fn:       which distance function to use to compare embeddings,
                           either 'euclid' or 'cosine'
            K:             number of nearest neighbors to calculate for each query point
            M:             from the FAISS docs - the number of neighbors used in the graph.
                           A larger M is more accurate but uses more memory.

        Returns:
            neighbors:     numpy array, shape (N, K), where neighbors[i][j] = k means that
                           embeddings[k] is the j-th nearest neighbor of embeddings[i].
                           The query point itself is excluded. Entries are -1 where
                           FAISS returned fewer than K+1 results.
            distances:     numpy array, shape (N, K), where distances[i][j] is the distance
                           between i and its j-th nearest neighbor. For 'euclid' these
                           are FAISS METRIC_L2 values (squared L2 distances); for
                           'cosine' they are cosine similarities (larger = closer).

        Exits the process with status 2 if dist_fn is not recognised.
    """
    embd_size = embeddings.shape[-1]

    # choose metric
    if dist_fn == 'euclid':
        metric = faiss.METRIC_L2
        # FAISS operates on C-contiguous float32 data
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    elif dist_fn == 'cosine':
        metric = faiss.METRIC_INNER_PRODUCT
        # Always copy: faiss.normalize_L2 works in place and would otherwise
        # overwrite the caller's array.
        embeddings = np.array(embeddings, dtype=np.float32, order='C')
        faiss.normalize_L2(embeddings)  # normalization => inner product = cos. sim.
    else:
        print("unsure which faiss distance metric corresponds to distance "
              "function {}".format(dist_fn), file=sys.stderr)
        sys.exit(2)

    start = time.perf_counter()
    index = faiss.IndexHNSWFlat(embd_size, M, metric)
    index.add(embeddings)
    built = time.perf_counter()
    # ask for one extra result because the query point normally finds itself
    distances, nearest_neighbors = index.search(embeddings, K + 1)
    searched = time.perf_counter()

    # Remove the query point from its own result list. It is usually the first
    # hit, but with duplicate embeddings (or the approximate HNSW search) it may
    # appear later or not at all; if it is absent, drop the last (farthest) hit
    # instead so every row keeps exactly K results.
    n_queries = nearest_neighbors.shape[0]
    is_self = nearest_neighbors == np.arange(n_queries)[:, None]
    drop_col = np.where(is_self.any(axis=1), is_self.argmax(axis=1), K)
    keep = np.ones(nearest_neighbors.shape, dtype=bool)
    keep[np.arange(n_queries), drop_col] = False
    nearest_neighbors = nearest_neighbors[keep].reshape(n_queries, K)
    distances = distances[keep].reshape(n_queries, K)

    print("The time to build the faiss KNN index is ", built - start, file=sys.stderr)
    print("The time to search the faiss KNN index is ", searched - built, file=sys.stderr)
    sys.stderr.flush()

    return nearest_neighbors, distances

def mask_types(name, mask_char = '#'):

    """
        mask types in function names

        Replaces the contents of each template argument list "<...>" with
        mask_char, one mask per comma-separated argument. The angle brackets
        and commas themselves are kept, e.g.
            "std::map<int, float>::find"  ->  "std::map<#,#>::find"
            "f<vector<int>>(x)"           ->  "f<#<#>>(x)"
        Lists that start with "lambda_" (e.g. "<lambda_1>") are left unmasked.
        When the '<' and '>' counts are unbalanced, a bracket directly after
        "operator" (operator<, operator<<, operator>, operator->, operator>>)
        is treated as part of the operator name rather than a template bracket.

        params:
            name:         original function name (string)
            mask_char:    which character/string to use in the mask, default '#'

        returns:
            masked_name:  masked function name (string). If masking fails (e.g.
                          a '>' closes more brackets than were opened, or name
                          is not a string), a warning is printed to stderr and
                          name is returned unchanged.
    """

    pieces = []          # output fragments, joined once at the end
    mask_on = [False]    # stack: is the text at the current bracket depth masked?

    try:
        open_count = name.count('<')
        closed_count = name.count('>')

        for i, char in enumerate(name):
            if char == '<':
                pieces.append(char)
                if name[i+1:i+8] == 'lambda_':  # out-of-range slices simply come back short
                    mask_on.append(False)
                elif open_count > closed_count:
                    # surplus '<': it may belong to operator< / operator<<
                    if (name[i-8:i] == 'operator') or (name[i-9:i] == 'operator<'):
                        open_count -= 1
                        continue
                    else:
                        mask_on.append(True)
                        if name[i+1:i+2] != '<':
                            pieces.append(mask_char)
                else:
                    mask_on.append(True)
                    if name[i+1:i+2] != '<':
                        pieces.append(mask_char)

            elif char == '>':
                pieces.append(char)
                if closed_count - open_count > 0:
                    # surplus '>': it may belong to operator> / operator-> / operator>>
                    if ((name[i-8:i] == 'operator')
                            or (name[i-9:i] == 'operator-')
                            or (name[i-9:i] == 'operator>')):
                        closed_count -= 1
                        continue
                    else:
                        mask_on.pop()
                        # still inside a masked list and plain text follows: mask it too
                        if (name[i+1:i+2] not in ['<', '>', ',']) and mask_on[-1]:
                            pieces.append(mask_char)
                else:
                    mask_on.pop()
                    if (name[i+1:i+2] not in ['<', '>', ',']) and mask_on[-1]:
                        pieces.append(mask_char)

            elif (char == ',') and mask_on[-1]:
                # a new argument starts; mask it unless it opens a nested list
                pieces.append(char)
                if name[i+1:i+2] != '<':
                    pieces.append(mask_char)
            elif not mask_on[-1]:
                pieces.append(char)

        masked_name = ''.join(pieces)

    except Exception:
        # Best effort: fall back to the unmasked name. Catch Exception rather
        # than everything so KeyboardInterrupt / SystemExit still propagate.
        print(file=sys.stderr)
        print("WARNING: Error while masking name {}, reverting to original name".format(name), file=sys.stderr)
        return name

    return masked_name

def normalize_fn_name(name, how):
    """
        Normalize a function name for label comparison.

        params:
            name:  function name (string)
            how:   'source' - keep only the text after the last backslash
                   'type'   - mask template arguments via mask_types
                   'all'    - apply 'source' first, then 'type'

        returns:
            the normalized name

        raises:
            ValueError if how is not one of the values above
    """
    if how not in ('source', 'type', 'all'):
        raise ValueError("how={} is undefined".format(how))

    result = name
    if how != 'type':
        # 'source' and 'all': strip everything up to and including the last '\'
        result = result.rpartition('\\')[2]
    if how != 'source':
        # 'type' and 'all': mask the template arguments
        result = mask_types(result)
    return result

def normalize_labels(labels, dataset, how, mapping=None):

    """
        normalize labels

        params:
            labels:  numpy array of integer labels to normalize (any shape)
            dataset: Dataset class with method get_name, that takes in an integer label
                     and retrieves the associated function name (string)
            how:     normalization strategy: one of 'source', 'type', or 'all'
            mapping: mapping of normalized function names to new integer labels.
                     New entries are added with mapping[new_name] = len(mapping);
                     a dict passed in is updated in place. Defaults to a fresh
                     empty dictionary on every call.

        returns:
            normalized_labels: array with the same shape (and dtype) as labels
            mapping
    """
    # A `{}` default would be created once and shared by every call, silently
    # accumulating names across unrelated calls.
    if mapping is None:
        mapping = {}

    orig_shape = labels.shape
    flat_labels = labels.flatten()
    new_labels = np.zeros_like(flat_labels)

    # The same label typically repeats many times (e.g. across neighbor lists),
    # so look up and normalize each distinct label only once. Assumes
    # dataset.get_name is deterministic for a given label.
    label_to_new = {}
    for i, label in tqdm(enumerate(flat_labels), total=len(flat_labels)):
        if label not in label_to_new:
            normalized_name = normalize_fn_name(dataset.get_name(label), how)
            # len(mapping) is evaluated before insertion -> ids are 0, 1, 2, ...
            label_to_new[label] = mapping.setdefault(normalized_name, len(mapping))
        new_labels[i] = label_to_new[label]

    new_labels = new_labels.reshape(orig_shape)

    return new_labels, mapping

def compute_mrr(labels, neighbors, normalize=None, dataset=None):
    """
        compute mean reciprocal rank (upper and lower bounds)

        A query's reciprocal rank is 1/r, where r is the 1-based position of the
        first neighbor sharing its label. If none of the K neighbors match, the
        true rank is unknown but greater than K, so that query contributes
        1/(K+1) to the upper bound and 0 to the lower bound.

        params:
            labels:     numpy array of shape (N,)
            neighbors:  numpy array of shape (N, K) where K is the number of neighbors.
                        Negative ids (FAISS uses -1 for missing results) never
                        count as a match.

        optional arguments for normalizing labels before evaluation (Assemblage data only):
            normalize:  either None, 'source', 'type', or 'all'
            dataset:    Dataset class with method get_name, that takes in an integer label
                        and retrieves the associated function name (string)

        returns:
            upper:  MRR upper bound
            lower:  MRR lower bound

    """
    labels = np.asarray(labels)
    neighbors = np.asarray(neighbors)

    neighbors_labels = np.take(labels, neighbors)

    if normalize is not None:
        assert dataset is not None, "must pass in a dataset to normalize labels"
        assert normalize in ['source', 'type', 'all'], \
            "normalize must be one of ['source', 'type', 'all']"

        # pass a fresh mapping explicitly so nothing leaks in from earlier calls
        neighbors_labels, mapping = normalize_labels(neighbors_labels, dataset, normalize, {})
        labels, _ = normalize_labels(labels, dataset, normalize, mapping)

    # np.take wraps a padded -1 id around to labels[-1]; mask those out so a
    # missing result can never be counted as a hit.
    valid = neighbors >= 0
    hits = (neighbors_labels == labels[:, None]) & valid
    any_hits = hits.any(axis=1)

    # compute upper and lower bounds on reciprocal rank
    num_neighbors = neighbors.shape[-1]
    true_rr = 1. / (np.argmax(hits, axis=1) + 1)  # argmax returns the index of the first True
    upper_bound = 1. / (num_neighbors + 1)
    lower_bound = 0.

    upper = np.mean(np.where(any_hits, true_rr, upper_bound))
    lower = np.mean(np.where(any_hits, true_rr, lower_bound))

    return upper, lower
