import os
import glob
import h5py
import torch
import numpy as np
import argparse
from halo import Halo
from collections import defaultdict
from sklearn.cluster import KMeans
from gudhi.representations.vector_methods import Atol
from gudhi.representations import PersistenceImage
from itertools import chain


def setup_cmdline_parsing():
    """Build the command-line parser for the diagram vectorization script.

    Returns:
        argparse.ArgumentParser with four argument groups: data loading/saving,
        general options, ATOL-specific options and PI-specific options.
    """
    parser = argparse.ArgumentParser()

    io_group = parser.add_argument_group('Data loading/saving arguments')
    io_group.add_argument("--dgms-inp-file", type=str, default="dgms.pt",
                          help="input file (read with torch.load) holding the diagrams")
    io_group.add_argument("--vecs-out-base", type=str, default="vecs",
                          help="output prefix; '_cat_nonorm.pt' / '_cat_norm.pt' are appended")

    general_group = parser.add_argument_group('General arguments')
    # main() only dispatches these two methods, so reject anything else up front
    # instead of failing after the (possibly slow) diagram loading.
    general_group.add_argument("--method", type=str, default='atol',
                               choices=['atol', 'pi'],
                               help="vectorization method")
    general_group.add_argument("--normalize", action='store_true', default=False,
                               help="also save per-feature min-max scaled vectors in [-1, 1]")

    atol_group = parser.add_argument_group('ATOL vectorization arguments')
    atol_group.add_argument("--vec-dim", type=int, default=10,
                            help="number of ATOL (KMeans) clusters = features per diagram")
    atol_group.add_argument("--subsample", type=int, default=50000,
                            help="number of diagrams randomly sampled to fit ATOL (<= 0: all)")

    pi_group = parser.add_argument_group('PI vectorization arguments')
    pi_group.add_argument("--resolution", type=int, default=10,
                          help="persistence image pixels per side")
    pi_group.add_argument("--bandwidth", type=float, default=1.0,
                          help="persistence image kernel bandwidth")

    return parser


def vectorize_pi(dgms, resolution=10, bandwidth=0.7, normalize=False):
    """Vectorize persistence diagrams as persistence images, one PI model per key.

    For every key ``k`` of ``dgms`` a gudhi ``PersistenceImage`` is fitted on
    all (adjusted) diagrams stored under ``k``. That model then transforms each
    simulation.

    Args:
        dgms: mapping ``key -> list of simulations``; each simulation is a list
            of numpy arrays, one persistence diagram each (assumed to be
            ``(num_points, 2)`` point arrays — TODO confirm with the producer).
        resolution: pixels per image side; every image has
            ``resolution * resolution`` entries.
        bandwidth: kernel bandwidth passed to ``PersistenceImage``.
        normalize: if True, also return per-feature min-max scaled copies
            mapped to ``[-1, 1]``.

    Returns:
        ``(vecs, vecs_norm)``. ``vecs[k]`` is a float32 tensor of shape
        ``(num_simulations, num_diagrams_per_simulation, resolution**2)``.
        ``vecs_norm`` holds same-shaped normalized tensors, or is ``None``
        when ``normalize`` is False.
    """
    def _adjust(dgm):
        # Drop the last point of each diagram (presumably an essential/infinite
        # class — confirm) and turn empty diagrams into a single [0, 0] point.
        # NOTE(review): a one-point diagram becomes empty after slicing and is
        # not replaced; kept as-is to match vectorize_atol.
        dgm = dgm.astype('float32')
        if dgm.size > 0:
            return dgm[:-1]
        return np.array([[0., 0.]], dtype=np.float32)

    vecs = defaultdict(list)
    vecs_norm = defaultdict(list) if normalize else None
    vec_dim = resolution * resolution

    for k, simulations in dgms.items():
        adjusted = [[_adjust(a) for a in simu] for simu in simulations]

        # Fit the image range on exactly the diagrams that get transformed, so
        # dropped points (or raw empty arrays) cannot distort the pixel grid.
        # The weight squares the second coordinate of each point.
        pi = PersistenceImage(resolution=[resolution, resolution],
                              bandwidth=bandwidth,
                              weight=lambda x: x[1] ** 2).fit(list(chain.from_iterable(adjusted)))

        per_simu = [
            torch.as_tensor(np.asarray(pi.transform(simu_adj), dtype=np.float32)).unsqueeze(0)
            for simu_adj in adjusted
        ]
        vecs[k] = torch.cat(per_simu)

        if normalize:
            num_timepoints = vecs[k].shape[1]
            flat = vecs[k].view(-1, vec_dim)
            min_d = flat.min(dim=0, keepdim=True).values
            max_d = flat.max(dim=0, keepdim=True).values
            # Features constant over all vectors would give 0/0 = NaN; a unit
            # span maps them to -1 instead (identical result otherwise).
            span = torch.where(max_d > min_d, max_d - min_d, torch.ones_like(max_d))
            vecs_norm[k] = (-1 + 2 * (flat - min_d) / span).view(-1, num_timepoints, vec_dim)

    return vecs, vecs_norm


def vectorize_atol(dgms, vec_dim=10, subsample=-1, normalize=False):
    """Vectorize persistence diagrams with a gudhi ``Atol`` model.

    One ``Atol`` quantiser (KMeans with ``vec_dim`` clusters) is fitted on a
    random subset of all diagrams. It is then used to transform every simulation.

    Args:
        dgms: sequence of simulations; each simulation is a sequence of numpy
            arrays, one persistence diagram each.
        vec_dim: number of KMeans clusters, i.e. features per diagram.
        subsample: number of randomly chosen diagrams used for fitting. A value
            <= 0, or one not smaller than the total number of diagrams, uses
            every diagram.
        normalize: if True, also return per-feature min-max scaled vectors
            mapped to ``[-1, 1]``.

    Returns:
        ``(vecs, vecs_norm)``. ``vecs`` is a float32 tensor of shape
        ``(num_simulations, num_diagrams_per_simulation, vec_dim)``.
        ``vecs_norm`` has the same shape, or is ``None`` if ``normalize`` is False.
    """
    def _adjust(dgm):
        # Ensure that empty diagrams are defined as a [0, 0] point; otherwise
        # drop the last point (presumably an essential/infinite class — confirm).
        # NOTE(review): a one-point diagram becomes empty after slicing and is
        # not replaced.
        dgm = dgm.astype('float32')
        if dgm.size > 0:
            return dgm[:-1]
        return np.array([[0., 0.]], dtype=np.float32)

    pool = [dgm for simu in dgms for dgm in simu]
    order = torch.randperm(len(pool))
    # A subsample at least as large as the pool falls back to using every
    # diagram instead of failing (the old assert also vanished under -O).
    if 0 < subsample < len(pool):
        order = order[:subsample]
    fit_dgms = [_adjust(pool[j]) for j in order.tolist()]

    atol = Atol(quantiser=KMeans(n_clusters=vec_dim))
    atol.fit(X=fit_dgms)

    per_simu = []
    for simu in dgms:
        feats = atol.transform(X=[_adjust(a) for a in simu]).astype('float32')
        # from_numpy avoids torch.tensor([ndarray]), which is slow and warns.
        per_simu.append(torch.from_numpy(feats).unsqueeze(0))
    vecs = torch.cat(per_simu)

    vecs_norm = None
    if normalize:
        num_timepoints = vecs.shape[1]
        flat = vecs.view(-1, vec_dim)
        min_d = flat.min(dim=0, keepdim=True).values
        max_d = flat.max(dim=0, keepdim=True).values
        # Features constant over all vectors would give 0/0 = NaN; a unit span
        # maps them to -1 instead (identical result otherwise).
        span = torch.where(max_d > min_d, max_d - min_d, torch.ones_like(max_d))
        vecs_norm = (-1 + 2 * (flat - min_d) / span).view(-1, num_timepoints, vec_dim)

    return vecs, vecs_norm


def main():
    """Run the CLI: load diagrams, vectorize them, save concatenated tensors.

    The loaded object is iterated with ``.items()``, so it is expected to be a
    mapping ``key -> diagrams``. Per-key vectors are concatenated along dim 2 and
    written to ``<vecs-out-base>_cat_nonorm.pt``. With ``--normalize``, the
    normalized vectors also go to ``<vecs-out-base>_cat_norm.pt``.
    """
    parser = setup_cmdline_parsing()
    args = parser.parse_args()
    print(args)

    spinner = Halo(spinner='dots')

    spinner.start('Loading ...')
    # NOTE(review): torch.load unpickles the file; only load trusted inputs.
    dgms = torch.load(args.dgms_inp_file)
    spinner.succeed()

    spinner.start('Vectorizing dgms (with {})'.format(args.method))
    if args.method == 'atol':
        # An independent ATOL model is fitted for every key.
        vecs = {key: vectorize_atol(val, args.vec_dim, args.subsample, args.normalize)
                for key, val in dgms.items()}
    elif args.method == 'pi':
        raw, normed = vectorize_pi(dgms, resolution=args.resolution,
                                   bandwidth=args.bandwidth,
                                   normalize=args.normalize)
        vecs = {key: (raw[key], normed[key] if args.normalize else None)
                for key in dgms}
    else:
        # Without this guard an unknown method left vecs=None and crashed below.
        spinner.fail()
        parser.error("unknown --method {!r} (expected 'atol' or 'pi')".format(args.method))
    spinner.succeed()

    spinner.start('Saving to {}_cat_nonorm.pt'.format(args.vecs_out_base))
    vecs_cat = torch.cat(tuple(raw_vec for raw_vec, _ in vecs.values()), dim=2)
    torch.save(vecs_cat, args.vecs_out_base + "_cat_nonorm.pt")
    spinner.succeed()

    if args.normalize:
        spinner.start('Normalizing and saving to {}_cat_norm.pt'.format(args.vecs_out_base))
        vecs_cat = torch.cat(tuple(norm_vec for _, norm_vec in vecs.values()), dim=2)
        torch.save(vecs_cat, args.vecs_out_base + "_cat_norm.pt")
        spinner.succeed()

if __name__ == '__main__':
    # Execute the CLI only when run as a script, never on import.
    main()





