import scipy.io as sio
import numpy as np
import pandas as pd
from scipy.spatial import distance_matrix
from ot.gromov import gromov_wasserstein, entropic_gromov_wasserstein
from robust_gw import *
from partial_gw import *
import time

# Load the full shape, one partial cut of it, and the precomputed initial
# point matches from .mat files (paths are relative to the working directory).
shape_name = 'horse10'
partial_name = 'partial3'
full = sio.loadmat(f'SGP_dataset/null/{shape_name}.mat')
partial = sio.loadmat(f'SGP_dataset/selected/cuts/{shape_name}_{partial_name}.mat')
# NOTE: the misspelled name `init_mathces` is kept because the initialization
# code below reads it under this name.
init_mathces = sio.loadmat(f'init_point/matches_{shape_name}_{partial_name}.mat')


# Convert both point clouds to pairwise Euclidean distance matrices.
# `N.xyz` is a field of a MATLAB struct; loadmat wraps the struct in a 1x1
# object array, hence the [0][0]. Coordinates are presumably (num_points, 3).
Xs = full['N']['xyz'][0][0]
Xt = partial['N']['xyz'][0][0]
Ds, Dt = (distance_matrix(cloud, cloud) for cloud in (Xs, Xt))

# Problem sizes, uniform marginals and hyperparameters for robust GW (RGW).
n = Xs.shape[0]  # number of points in the full shape
m = Xt.shape[0]  # number of points in the partial shape

# Uniform marginals as (n, 1) and (m, 1) column vectors, each summing to 1.
a = np.ones([n, 1]) / n
b = np.ones([m, 1]) / m

# Hyperparameters passed positionally to robust_gw. Their precise meaning is
# defined by robust_gw itself (not visible here) -- TODO document there.
rho1, rho2 = 20, 1.0
eta = 50.0
t1, t2 = 50.0, 50.0
tau1, tau2 = 0.5, 0.5
rgw_maxiter = 500       # iteration cap handed to robust_gw (by name)
relative_error = 1e-5   # tolerance handed to robust_gw (presumably a relative stopping criterion)

# Initialization: a warm-start coupling built from precomputed point matches.
# `matches[j]` is the full-shape index matched to partial point j, stored
# 1-based (MATLAB style). Cast to an integer index dtype *before* shifting so
# fancy indexing works even when the .mat file stored the indices as doubles,
# and so unsigned storage cannot wrap around.
matches = init_mathces['matches'].ravel().astype(np.intp) - 1
init_alpha = np.ones([n, 1]) / n
init_beta = np.ones([m, 1]) / m

# Put one unit of mass at (matches[j], j) for every matched column j, then
# normalize so the coupling sums to 1.
init_X = np.zeros([n, m])
init_X[matches, np.arange(matches.shape[0])] = 1
init_X = init_X / np.sum(init_X)
# Tiny floor so no entry is exactly zero (presumably to keep logs/divisions in
# the solvers finite -- confirm against robust_gw / pu_gw_emd).
init_X = init_X + 1e-15

# Run RGW from the warm start and time it. perf_counter is monotonic and
# high-resolution, unlike time.time(), which can jump if the system clock is
# adjusted mid-run.
start = time.perf_counter()
coup_rgw, obj_list, alpha, beta = robust_gw(
    Ds, Dt, a, b, rgw_maxiter, rho1, rho2, eta, t1, t2, tau1, tau2,
    relative_error, init_X, init_alpha, init_beta,
)
end = time.perf_counter()
total_time_rgw = end - start  # elapsed seconds


# Run partial GW (pu_gw_emd) from the same warm-start coupling for comparison.
# The column-vector marginals are flattened to 1-D for its p/q arguments.
pgw_options = dict(nb_dummies=100, log=True, max_iter=500)
coup_pgw, log_pgw = pu_gw_emd(
    C1=Ds,
    C2=Dt,
    p=a.ravel(),
    q=b.ravel(),
    G0=init_X,
    **pgw_options,
)

# Compute accuracy: for each partial-shape point, compare the predicted
# full-shape index against the ground-truth correspondence.
gt_full = full['N']['gt'][0][0]
gt_partial = partial['N']['gt'][0][0]

# Ground truth for partial point j is the index of the full-shape point that
# carries the same label. Build a label -> first-index map once (first
# occurrence, matching the old np.where(...)[0][0]) so this is O(n + m)
# instead of O(n * m). Flattening makes it independent of whether loadmat
# returned row or column vectors.
gt_index_of = {}
for idx, label in enumerate(np.asarray(gt_full).ravel()):
    gt_index_of.setdefault(label, idx)
try:
    # One label per partial point; only the first m are used.
    gt = np.array(
        [gt_index_of[label] for label in np.asarray(gt_partial).ravel()[:m]],
        dtype=np.intp,
    )
except KeyError as err:
    raise ValueError(
        f"ground-truth label {err.args[0]!r} of the partial shape "
        "does not occur in the full shape"
    ) from err

# Predicted match for each partial point (column): the full-shape row holding
# the largest coupling mass (assumes couplings are (n, m) like init_X).
matches_rgw = np.argmax(coup_rgw, 0).ravel()
matches_pgw = np.argmax(coup_pgw, 0).ravel()
print("acc of init:" + str(np.mean(matches == gt)))
print("acc of pgw:" + str(np.mean(matches_pgw == gt)))
print("acc of rgw:" + str(np.mean(matches_rgw == gt)))
# Report the RGW runtime measured earlier (previously computed but unused).
print("time of rgw:" + str(total_time_rgw))
