#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 10:16:41 2024

@author: anonymous
"""

from random import sample
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
from scipy import sparse
from typing import Tuple, List
import argparse
import gril.gril as gril
import torch
import torch.nn as nn
from scipy.spatial import Delaunay
import pickle
import tqdm

from utils import *



###############################################################################



class MultiPersLandscapeValLayer(nn.Module):
    """Torch module that evaluates GRIL landscapes on a fixed sample grid.

    Parameters
    ----------
    res : float
        Resolution of the GRIL landscape; the full grid has ``int(1 / res)``
        divisions per axis before subsampling.
    hom_rank : int
        Degree of homology, forwarded to ``gril.MultiPers``.
    step : int
        Stride used to subsample the grid in ``sample_grid``.
    l : int
        Length of the worms in GRIL, forwarded to ``gril.MultiPers``.
    """

    def __init__(self, res, hom_rank, step, l):
        super().__init__()
        self.res = res
        self.step = step
        self.l = l
        # Grid points at which forward() evaluates the landscapes.
        self.sample_pts = self.sample_grid()
        self.hom_rank = hom_rank
        # NOTE(review): MultiPers always receives step=2, independent of the
        # ``step`` argument (which only controls sample_grid) — confirm this
        # is intended.  Ranks 1 through 5 are requested.
        self.mpl = gril.MultiPers(
            hom_rank=hom_rank,
            l=l,
            step=2,
            res=res,
            ranks=list(range(1, 6)),
        )
        self.mpl.set_max_jobs(40)

    def sample_grid(self):
        """Return the subsampled grid as ``(i, j)`` tuples, ``i`` varying fastest."""
        n_div = int(1.0 / self.res)
        axis = range(0, n_div, self.step)
        return [(i, j) for j in axis for i in axis]

    def forward(self, pers_inp):
        """Evaluate the GRIL landscapes of ``pers_inp`` at ``self.sample_pts``."""
        return self.mpl.compute_landscape(self.sample_pts, pers_inp)



###############################################################################



def compute_GRIL_landscapes(res, hom_rank, step, l, infile, outfile):
    """Compute GRIL landscapes for every GRIL input stored in a pickle file.

    Parameters
    ----------
    res : float
        Resolution of the GRIL landscape.
    hom_rank : int
        Degree of homology.
    step : int
        Stride of the landscape sample grid.
    l : int
        Length of the worms in GRIL.
    infile : str
        Path of the pickle file containing the GRIL inputs; relative paths
        are resolved against ``'..'``.
    outfile : str
        Path of the pickle file the landscapes are written to; relative paths
        are resolved against ``'..'``.  The file receives one NumPy array of
        shape ``(n_inputs, 5, side, side)``.
    """
    num_landscapes = 5  # must match ranks=list(range(1, 6)) in the GRIL layer

    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(os.path.join('..', infile), 'rb') as f:
        data = pickle.load(f)
    print('\n')
    print('data loaded', '\n')

    # Points per axis, computed exactly like MultiPersLandscapeValLayer.sample_grid.
    # The former ceil(1 / (res * step)) disagrees with it whenever 1 / res is
    # not an integer (e.g. res=0.07), which made the reshape below fail.
    side = len(range(0, int(1.0 / res), step))

    # Build the layer once and reuse it for every input.
    gril_layer = MultiPersLandscapeValLayer(res=res, hom_rank=hom_rank, step=step, l=l)

    lcps = []
    for i in tqdm.tqdm(range(len(data))):
        output = gril_layer(data[i])
        # (side, side, num_landscapes) -> (num_landscapes, side, side)
        lcps.append(np.transpose(output[0].reshape((side, side, num_landscapes)), (2, 0, 1)))

    with open(os.path.join('..', outfile), 'wb') as f:
        pickle.dump(np.array(lcps), f)
    

    
###############################################################################


    
def show_GRIL_landscapes(res, hom_rank, step, l, infile):
    """Plot the GRIL landscapes of one randomly chosen input from a pickle file.

    Parameters
    ----------
    res : float
        Resolution of the GRIL landscape.
    hom_rank : int
        Degree of homology.
    step : int
        Stride of the landscape sample grid.
    l : int
        Length of the worms in GRIL.
    infile : str
        Path of the pickle file containing the GRIL inputs; relative paths
        are resolved against ``'..'``.

    Raises
    ------
    ValueError
        If the file contains no GRIL inputs.
    """
    num_landscapes = 5  # must match ranks=list(range(1, 6)) in the GRIL layer

    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(os.path.join('..', infile), 'rb') as f:
        data = pickle.load(f)
    print('\n')
    print('data loaded', '\n')

    if len(data) == 0:
        raise ValueError(f'no GRIL inputs found in {infile!r}')

    i = np.random.randint(0, len(data))
    # Points per axis, computed exactly like MultiPersLandscapeValLayer.sample_grid.
    side = len(range(0, int(1.0 / res), step))

    gril_layer = MultiPersLandscapeValLayer(res=res, hom_rank=hom_rank, step=step, l=l)
    img = gril_layer(data[i])
    # (side, side, num_landscapes) -> (num_landscapes, side, side)
    img = np.transpose(img[0].reshape((side, side, num_landscapes)), (2, 0, 1))

    _, axes = plt.subplots(1, num_landscapes)
    for k, ax in enumerate(axes):
        ax.imshow(img[k, :, :], cmap='jet', origin='lower')
        # Raw string: '\l' is not a valid Python escape sequence.
        ax.set_title(rf'$\lambda_{k + 1}$')
        ax.set_axis_off()
    plt.show()



###############################################################################



# Default parameters for the GRIL landscape pipeline.
res = 0.01        # resolution of the landscape grid
hom_rank = 1      # degree of homology
step = 6          # stride of the landscape sample grid
l = 2             # length of the worms in GRIL
infile = 'Data/gril_inputs.txt'       # pickled GRIL inputs (resolved against '..')
outfile = 'Data/gril_landscapes.txt'  # pickled landscapes (resolved against '..')

# Run the pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    compute_GRIL_landscapes(res, hom_rank, step, l, infile, outfile)
    # show_GRIL_landscapes(res, hom_rank, step, l, infile)
