import argparse
from functools import partial
import numpy as np
import torch
import dgl
from .sparsify import sparsify_fully_connnected_graph
from ..data import RoadDataset
from ..model.utils import euclidean_dist, cosine_similarity, gaussian_kernel
from ..utils import choose_device


def sparsify_args(argv_list=None):
    """Parse the options for building a sparsified SVI-similarity graph.

    Args:
        argv_list: Explicit list of argument strings to parse. When ``None``,
            argparse reads ``sys.argv[1:]`` instead.

    Returns:
        ``argparse.Namespace`` with the attributes ``dataset``, ``data_dir``,
        ``match_distance``, ``sim_metric``, ``num_edges``, ``gpu`` and ``dump``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='singapore')
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument(
        '--match-distance',
        type=int,
        default=50,
        help='Distance threshold to match SVI to road. (Choose from [30, 40, 50].)',
    )
    parser.add_argument(
        '--sim-metric',
        type=str,
        default='cosine',
        choices=['euclidean', 'cosine', 'gaussian_kernel'],
        help='The similarity metric to use.',
    )
    # -1 is a sentinel: the caller substitutes 36 * n * log(n) edges.
    parser.add_argument(
        '--num-edges',
        type=int,
        default=-1,
        help='Number of edges to keep. Default to use 36n log n.',
    )
    parser.add_argument('--gpu', type=int, default=2, help='-1 for cpu')
    parser.add_argument('--dump', type=str, default='sim_graph.pkl')
    # parse_args(None) already falls back to sys.argv[1:], so a single call
    # covers both the explicit-list and the command-line case.
    return parser.parse_args(argv_list)


def build_sparsified_sim_graph(argv: list = None):
    """Build a sparsified SVI-similarity graph and save it to disk with DGL.

    Pipeline:
      1. Parse options with ``sparsify_args``.
      2. Load ``RoadDataset`` and move its ``svi_emb`` to the selected device.
      3. Compute the dense pairwise similarity matrix with the chosen metric.
      4. Keep ``--num-edges`` edges via ``sparsify_fully_connnected_graph``
         and write the graph to ``--dump`` with ``dgl.save_graphs``.

    Args:
        argv: Optional list of command-line strings forwarded to
            ``sparsify_args``; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The graph returned by ``sparsify_fully_connnected_graph``.

    Raises:
        NotImplementedError: if ``--sim-metric`` is ``euclidean`` or
            ``gaussian_kernel`` (only ``cosine`` is implemented here).
        ValueError: if the metric is unknown, or if ``--dataset`` has no
            registered cache prefix.
    """
    args = sparsify_args(argv)
    sim_metric = args.sim_metric

    # Validate every option that does not need the data *before* loading the
    # dataset and computing the dense n x n similarity matrix, so that bad
    # options fail fast instead of after the expensive work.
    if sim_metric == 'cosine':
        # NOTE(review): chunk_size presumably bounds peak memory of the
        # pairwise computation -- confirm against model.utils.cosine_similarity.
        sim_func = partial(cosine_similarity, chunk_size=1000)
    elif sim_metric in ('euclidean', 'gaussian_kernel'):
        raise NotImplementedError(
            f"Similarity metric '{sim_metric}' is not implemented yet; use 'cosine'."
        )
    else:
        # Unreachable through argparse `choices`, but kept explicit instead of
        # an `assert` (asserts are stripped under `python -O`).
        raise ValueError(f"Unknown similarity metric: '{sim_metric}'.")

    # Dataset name -> prefix passed to sparsify_fully_connnected_graph.
    cache_prefixes = {'singapore': 'sg', 'nyc': 'nyc'}
    try:
        cache_prefix = cache_prefixes[args.dataset]
    except KeyError:
        raise ValueError(
            f"No cache prefix registered for dataset '{args.dataset}'; "
            f"expected one of {sorted(cache_prefixes)}."
        ) from None

    data = RoadDataset(args.dataset, args.data_dir, match_distance=args.match_distance)
    device = choose_device(args.gpu)
    # NOTE(review): assumes svi_emb is a torch tensor of per-node embeddings
    # (one row per SVI/node) -- confirm against RoadDataset.
    svi_emb = data.svi_emb.to(device)
    sim = sim_func(svi_emb, svi_emb).cpu().numpy()

    num_nodes = sim.shape[0]
    if args.num_edges == -1:
        # Default sparsified-graph size: 36 * n * ln(n) edges (np.log is the
        # natural logarithm).
        args.num_edges = int(36 * num_nodes * np.log(num_nodes))

    g = sparsify_fully_connnected_graph(sim, args.num_edges, cache_prefix)
    # dgl.save_graphs writes DGL's own binary format (read back with
    # dgl.load_graphs), regardless of the '.pkl' default file name.
    dgl.save_graphs(args.dump, [g])
    return g
