#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 23 09:27:38 2024

@author: anonymous
"""

import pickle
import numpy as np
import tqdm 
import torch
import torch_geometric as tg


###############################################################################


def build_torch_graph_dataset(
    graphs: str, label: str, outfile: str, augment: bool
) -> None:
    """Read the graphs produced by graphcode from files in a folder,
    transform them into a list of torch_geometric graphs and pickle the
    list to a file.

    The folder must contain one text file per label, named ``0.txt``,
    ``1.txt``, ... Each file is parsed as:

    * first line: the number of vertex rows ``num``;
    * next ``num`` rows: one row per vertex; columns 0-2 are used as the
      vertex attributes (columns 0 and 1 are a persistence-diagram point);
    * remaining rows: the edge list, one edge per row (presumably
      ``source target`` vertex indices -- confirm against graphcode).

    Graphs whose vertex block does not parse to a 2-D table (e.g. a single
    vertex) or that have no edges are skipped, so the output list can be
    shorter than the label list.

    Parameters
    ----------
    graphs : str
        Path of the folder containing the graph files.
    label : str
        Path of the (pickled) file containing the class labels; its length
        determines how many graph files are read.
    outfile : str
        Name of the output file; the list of ``torch_geometric.data.Data``
        objects is written to it with ``pickle``.
    augment : bool
        If true, augments the vertex attributes with column0 / column1 and
        column1 - column0, i.e. the multiplicative and additive persistence
        of the points in the persistence diagram (5 attributes per vertex
        instead of 3).
    """
    # NOTE(review): pickle can execute arbitrary code when loading; only use
    # label files produced by a trusted pipeline.
    with open(label, 'rb') as fh:
        y = pickle.load(fh)
    graph_data = []

    print('\n')
    print('Read data', '\n')

    for i in tqdm.tqdm(range(len(y))):
        # Read each file once; np.loadtxt accepts a list of lines and applies
        # skiprows/max_rows to it exactly as it would to the file itself.
        with open(f'{graphs}/{i}.txt') as fh:
            lines = fh.readlines()

        num = int(np.loadtxt(lines, max_rows=1))
        bars = np.loadtxt(lines, skiprows=1, max_rows=num)
        # ndmin=2 keeps a one-edge list as shape (1, k); without it loadtxt
        # squeezes it to 1-D, the transpose is a no-op and edge_index would
        # be a 1-D tensor instead of shape (k, num_edges).
        edges = np.loadtxt(lines, skiprows=num + 1, ndmin=2).T

        # Skip graphs whose vertex block is not a 2-D table (a single vertex
        # row is squeezed to 1-D by loadtxt) and graphs without edges.
        if bars.ndim < 2 or edges.size == 0:
            continue

        col0, col1, col2 = bars[:, 0], bars[:, 1], bars[:, 2]
        if augment:
            vert = np.column_stack((col0, col1, col0 / col1, col1 - col0, col2))
        else:
            vert = np.column_stack((col0, col1, col2))

        G = tg.data.Data(
            x=torch.tensor(vert, dtype=torch.float),
            edge_index=torch.from_numpy(edges).long(),
            y=torch.tensor([y[i]]).long(),
        )
        graph_data.append(G)

    with open(outfile, 'wb') as fh:
        pickle.dump(graph_data, fh)



###############################################################################
# Script section: runs at module level, so importing this file rebuilds the
# dataset. NOTE(review): consider guarding with `if __name__ == "__main__":`.


# --- Pointcloud dataset ------------------------------------------------------

build_torch_graph_dataset(
    graphs='Data/Pointcloud_Graphs',
    label='Data/pointcloud_labels.txt',
    outfile='Data/torch_graph_dataset_pointclouds.txt',
    augment=True,
)


# --- Graph dataset (disabled; uncomment to build) ----------------------------

# build_torch_graph_dataset(
#     graphs='Data/Graph_Graphs',
#     label='Data/graph_labels.txt',
#     outfile='Data/torch_graph_dataset_graph.txt',
#     augment=True,
# )
