from core import *
from tmps import *
from evaluation import *
#from scipy.stats import beta
from pr4a_wrapper import *

# Explicit imports follow the project star imports so that no star import can
# shadow them. math, random and time were previously reachable only through
# those star imports.
import math
import random
import statistics
import sys
import time
import timeit

import numpy as np
import scipy.stats as st

def evaluate(F, S, A, kp, ka, c, verbose=True):
    """Score an assignment by the utilities it gives to papers.

    The utility of paper ``p`` is ``dot(F[:, p], S[:, p])``: the paper's
    column of the assignment ``F`` weighted by its column of ``S``.

    Args:
        F: assignment matrix indexed ``[reviewer, paper]``.
        S: similarity matrix indexed ``[reviewer, paper]``.
        A: authorship/assignment matrix. Only its number of columns is used
            here (it fixes how many papers are scored).
        kp, ka: not used by the active code. They are kept for interface
            compatibility; the disabled ``core_check`` call took them.
        c: selects whether the disabled core check would run. Both values
            currently produce ``alpha=1, infin=0``.
        verbose: when True (the default, matching the historical behavior),
            print the minimum paper utility.

    Returns:
        ``(uw, ew, alpha, infin)``: the mean paper utility, the minimum paper
        utility, and the placeholder ``alpha`` / ``infin`` values.

    Raises:
        ValueError: if ``A`` has no columns (the minimum of an empty array).
    """
    npap = A.shape[1]
    u_pap = np.fromiter(
        (np.dot(F[:, p], S[:, p]) for p in range(npap)),
        dtype=float,
        count=npap,
    )

    uw = np.mean(u_pap)
    ew = u_pap.min()
    if verbose:
        print(ew)

    # The core-stability check is disabled, so both values of ``c`` yield the
    # same placeholders. If it is re-enabled, the per-reviewer utilities it
    # needs can be computed vectorized as ``u_rev = (A == 1) @ u_pap``:
    #    alpha, infin = core_check(F, S, A, kp, ka, u_rev, nrev, npap)
    alpha = 1
    infin = 0
    return uw, ew, alpha, infin


# Experiment setup. Usage: python <this script> <choice>, where <choice> is
#   1 -> ICLR 2018
#   2 -> CVPR 2017
#   3 -> CVPR 2018

DATASET_CHOICES = ('1', '2', '3')
if len(sys.argv) < 2 or sys.argv[1] not in DATASET_CHOICES:
    # Previously an unknown choice fell through and crashed with a NameError
    # on S_init; fail early with a usage message instead.
    sys.exit('usage: ' + sys.argv[0] + ' {1|2|3}  (1=iclr 18, 2=cvpr 17, 3=cvpr 18)')
choice = sys.argv[1]
print(choice)

it = 100  # number of random trials
# kp / ka are passed to core, tmps and pr4a and embedded in file names.
ka = 3
kp = 3

# Per-trial results: one row per trial holding evaluate()'s
# (uw, ew, alpha, infin) tuple for each algorithm.
results = []
temp_core = np.zeros((it, 4))
temp_tmps = np.zeros((it, 4))
temp_mm = np.zeros((it, 4))

# Counters and samples gathered from the alpha / infin columns.
prob_mm = 0
prob_tmps = 0
alpha_mm = []
non_infin_mm = 0
alpha_tmps = []
non_infin_tmps = 0

# Sorted per-paper utilities, one array per trial.
utilities_core = []
utilities_tmps = []
utilities_mm = []

if choice == '1':
    dataset_name = 'iclr2018'
    # NOTE(review): allow_pickle=True can run code embedded in the file; only
    # load trusted data files. The context manager closes the NpzFile handle,
    # which was previously left open.
    with np.load('data/' + dataset_name + '_authorship,kp=' + str(kp) + ',ka=' + str(ka) + '.npz', allow_pickle=True) as scores:
        A_init = scores["author_matrix"]
        S_init = scores["similarity_matrix"]
    nrev, npap = A_init.shape
else:
    if choice == '2':
        dataset_name = 'cvpr_scores'
        S_init = np.load('data/' + dataset_name + '.npy', allow_pickle=True)
    else:
        dataset_name = 'cvpr_scores_18'
        S_init = np.load('data/' + dataset_name + '.npy', allow_pickle=True)
        # Min-max normalize scores to [0, 1].
        S_init = (S_init - np.amin(S_init)) / (np.amax(S_init) - np.amin(S_init))
    nrev, npap = S_init.shape

    # Greedy initial assignment: each paper, in index order, takes its most
    # similar reviewer who has not been given a paper yet. Once every reviewer
    # is occupied, the remaining papers stay unassigned.
    Occupied = np.zeros(nrev)
    A_init = np.zeros((nrev, npap))
    for p in range(npap):
        for reviewer in np.flip(np.argsort(S_init[:, p])):
            if Occupied[reviewer] <= 0:
                A_init[reviewer, p] = 1
                Occupied[reviewer] += 1
                break

def _paper_utilities(F, S, npap):
    """Return ``dot(F[:, p], S[:, p])`` for each of the first ``npap`` papers, as floats."""
    return np.fromiter(
        (np.dot(F[:, p], S[:, p]) for p in range(npap)),
        dtype=float,
        count=npap,
    )


# Number of reviewers drawn without replacement in every trial.
# random.sample raises ValueError if fewer reviewers exist.
size_sample = 100
for l in range(it):
    print(l)
    start = time.perf_counter()
    ########Sample#########
    # Keep a random subset of reviewer rows, in their original order. This is
    # exactly what np.delete of the complement produced, without building the
    # complement set.
    kept_rows = np.sort(random.sample(range(A_init.shape[0]), size_sample))
    A = A_init[kept_rows]
    S = S_init[kept_rows]
    # Drop papers whose column is all zero among the sampled reviewers.
    no_matched = np.flatnonzero(A.sum(axis=0) == 0)
    S = np.delete(S, no_matched, 1)
    A = np.delete(A, no_matched, 1)
    nrev, npap = A.shape

    F_core = core(S, A, kp, ka)
    utilities_core.append(np.sort(_paper_utilities(F_core, S, npap)))

    F_tmps = tmps(S, A, kp, ka)
    utilities_tmps.append(np.sort(_paper_utilities(F_tmps, S, npap)))

    F_mm = pr4a(S, A, kp, ka)
    utilities_mm.append(np.sort(_paper_utilities(F_mm, S, npap)))

    temp_core[l, :] = evaluate(F_core, S, A, kp, ka, 1)

    temp_tmps[l, :] = evaluate(F_tmps, S, A, kp, ka, 0)
    if temp_tmps[l, 3] == 0:
        alpha_tmps.append(temp_tmps[l, 2])
        non_infin_tmps += 1
    if temp_tmps[l, 2] > 1:
        prob_tmps += 1

    # Row l only: the previous `temp_mm[l:]` broadcast the tuple over every
    # remaining row.
    temp_mm[l, :] = evaluate(F_mm, S, A, kp, ka, 0)
    if temp_mm[l, 3] == 0:
        alpha_mm.append(temp_mm[l, 2])
        non_infin_mm += 1
    if temp_mm[l, 2] > 1:
        prob_mm += 1

    end = time.perf_counter()
    print(end - start)

# z-value of a two-sided 95% normal confidence interval.
Z_95 = 1.96


def _ci95_halfwidth(values, n):
    """Return the 95% CI half-width of the mean of ``values`` over ``n`` samples.

    Uses ``np.std`` with its default ddof=0, as the original reporting did.
    """
    return Z_95 * np.std(values, axis=0) / math.sqrt(n)


def _save_rows(path, rows):
    """Write each 1-D array in ``rows`` as one space-separated line.

    Produces the same text as ``np.savetxt(path, rows)`` when all rows share
    a length. It also works when they differ: the number of matched papers
    can vary between trials, and ``np.savetxt`` cannot turn such a ragged
    list into a 2-D array.
    """
    with open(path, 'w') as fh:
        for row in rows:
            np.savetxt(fh, np.atleast_2d(row))


# Rows: mean, then 95% CI half-width, for core, tmps and mm in that order.
for per_trial in (temp_core, temp_tmps, temp_mm):
    results.append(per_trial.mean(axis=0))
    results.append(_ci95_halfwidth(per_trial, it))

results2 = [prob_tmps, prob_mm]
for n_finite, alphas in ((non_infin_tmps, alpha_tmps), (non_infin_mm, alpha_mm)):
    results2.append(n_finite)
    if n_finite > 0:
        alphas = np.asarray(alphas)
        results2.append(alphas.mean())
        # The CI covers the n_finite alpha samples actually collected, not all
        # `it` trials.
        results2.append(_ci95_halfwidth(alphas, n_finite))

np.savetxt(f'1-kp={kp},ka={ka}-{dataset_name}.csv', results, delimiter=',', fmt='%f')
np.savetxt(f'2-kp={kp},ka={ka}-{dataset_name}.csv', results2, delimiter=',', fmt='%f')

_save_rows('utilities_core.txt', utilities_core)
_save_rows('utilities_tmps.txt', utilities_tmps)
_save_rows('utilities_mm.txt', utilities_mm)
