import matplotlib.pyplot as plt
import tick
from tick.hawkes import ModelHawkesExpKernLogLik, ModelHawkesExpKernLeastSq
from tick.hawkes import SimuHawkesExpKernels, SimuHawkesMulti, HawkesExpKern
from tick.hawkes import HawkesADM4, HawkesCumulantMatching
import numpy as np
import pandas as pd
from scipy.stats import gamma
from scipy.stats import expon
from scipy.stats import uniform
from scipy.stats import bernoulli
import pickle
import itertools
import tensorflow
import os.path

#########globals############

# Bounds of the uniform distribution used for the non-zero adjacency weights.
uni_adj_lb = 0.1
uni_adj_ub = 0.2
# Bounds of the uniform distribution used for the baseline intensities.
uni_bas_lb = 0.5
uni_bas_ub = 1.0
# Success probability of the Bernoulli draw made for each entry in generate_gamma.
bern_param = 0.3
# Decay value handed to every tick simulator, model and learner in this file.
decays = 1
# Estimated adjacency entries strictly above TH count as edges (MLE, LS, ADM4).
TH = 0.01
# Sample count and simulations-per-sample passed to run_prep / run_prep_sparse by test_method.
N = 1000
n_sim = 10
size = 1  # NOTE(review): no module-level reader visible in this file; functions use their own local `size`
deg = 1  # replaced in the __main__ block
# When True, estimate() returns the least-squares fit if the likelihood fit raises RuntimeError.
handle_mle_ls = False


def precision(gamma_true,gamma_star):
    """Return the fraction of predicted edges that are also true edges.

    Parameters
    ----------
    gamma_true : array-like
        Ground-truth structure; non-zero / True entries are edges.
    gamma_star : array-like
        Estimated structure, broadcastable against ``gamma_true``.

    Returns
    -------
    float or int
        ``|true AND predicted| / |predicted|``, or ``0`` when nothing is
        predicted (the ratio is undefined in that case).

    Raises
    ------
    ValueError
        If the two inputs cannot be broadcast together.  This used to be
        swallowed by a bare ``except`` and reported as a score of 0.
    """
    # asarray lets plain nested lists be scored as well as ndarrays.
    g_true = np.asarray(gamma_true) * 1
    g_star = np.asarray(gamma_star) * 1
    predicted = np.count_nonzero(g_star)
    if predicted == 0:
        return 0
    return np.count_nonzero(g_true * g_star) / predicted

def recall(gamma_true,gamma_star):
    """Return the fraction of true edges that were recovered.

    Parameters
    ----------
    gamma_true : array-like
        Ground-truth structure; non-zero / True entries are edges.
    gamma_star : array-like
        Estimated structure, broadcastable against ``gamma_true``.

    Returns
    -------
    float or int
        ``|true AND predicted| / |true|``, or ``0`` when there are no true
        edges (the ratio is undefined in that case).

    Raises
    ------
    ValueError
        If the two inputs cannot be broadcast together.  This used to be
        swallowed by a bare ``except`` and reported as a score of 0.
    """
    # asarray lets plain nested lists be scored as well as ndarrays.
    g_true = np.asarray(gamma_true) * 1
    g_star = np.asarray(gamma_star) * 1
    actual = np.count_nonzero(g_true)
    if actual == 0:
        return 0
    return np.count_nonzero(g_true * g_star) / actual

def F1_score(gamma_true,gamma_star):
    """Return the harmonic mean of precision and recall for an estimated structure.

    Parameters
    ----------
    gamma_true : array-like
        Ground-truth structure.
    gamma_star : array-like
        Estimated structure.

    Returns
    -------
    float or int
        ``2*p*r / (p + r)``, or ``0`` when both precision and recall are 0.
    """
    p = precision(gamma_true,gamma_star)
    r = recall(gamma_true,gamma_star)
    # Explicit guard instead of a bare except: only the 0/0 case is expected here.
    if p + r == 0:
        return 0
    return 2*p*r/(p+r)

# Scoring function handed to test_method by the __main__ block.
metric = F1_score


############MDLH#################

def generate_data_uniform_multi(end_time, decays, gamma, n_sim):
    """Simulate `n_sim` realisations of a Hawkes process with random parameters.

    Adjacency weights are drawn uniformly and masked by `gamma`; baselines are
    drawn uniformly.  Each realisation is simulated up to ``2 * end_time``; only
    events later than `end_time` are kept and they are shifted back by `end_time`.

    Returns the tuple ``(data, baseline, adjacency)``.
    """
    n_nodes = len(gamma)
    weights = np.random.uniform(low=uni_adj_lb, high=uni_adj_ub, size=(n_nodes, n_nodes))
    adjacency = gamma * weights
    baseline = np.random.uniform(low=uni_bas_lb, high=uni_bas_ub, size=n_nodes)

    process = SimuHawkesExpKernels(
        adjacency=adjacency,
        decays=decays,
        baseline=baseline,
        end_time=end_time*2,
        verbose=False,
    )
    simulations = SimuHawkesMulti(process, n_simulations=n_sim)
    simulations.simulate()

    data = simulations.timestamps.copy()
    for j in range(n_sim):
        realisation = data[j]
        for i in range(n_nodes):
            events = realisation[i]
            # keep the second half of the horizon and re-base it at zero
            realisation[i] = events[events > end_time] - end_time
    return data, baseline, adjacency

def generate_gamma(p):
    """Draw a random p x p boolean structure whose diagonal is always True.

    Every entry is an independent Bernoulli(bern_param) draw before the
    diagonal is forced on.
    """
    draws = bernoulli(bern_param).rvs((p, p))
    np.fill_diagonal(draws, 1)
    return draws == 1


def generate_sample_multi(p,n_sim,T):
    """Draw a random structure and simulate `n_sim` realisations from it.

    Returns ``(data, gamma, adjacency, baseline)``.
    """
    structure = generate_gamma(p)
    simulated = generate_data_uniform_multi(
        end_time=T, decays=decays, gamma=structure, n_sim=n_sim)
    data, baseline, adjacency = simulated
    return data, structure, adjacency, baseline

def theta_hat_eval(adj,bas,gamma):
    """Assemble one coefficient vector from per-subset estimates, following `gamma`.

    For each row i of `gamma` the row is read as a binary number `s` (entry 0 is
    the most significant bit) and used to look up `adj[s]` and `bas[s]`.  `u` is
    the number of selected entries before column i.  `bas[s][u]` becomes baseline
    entry i, and row `u` of `adj[s]` is scattered, in order, into the selected
    columns of row i of a p x p matrix.

    Parameters
    ----------
    adj, bas : indexable by the integer code `s`
        Callers in this file pass the outputs of est_eval / est_eval_sparse.
    gamma : p x p structure; truthy entries mark the selected columns of each row.

    Returns
    -------
    1-D array with p + p*p entries: the baseline followed by the flattened matrix.
    """
    p = len(gamma)
    a = np.zeros((p,p))
    b = np.zeros(p)
    for i in range(p):
        # s: row i of gamma encoded as an integer, first entry = highest bit
        s = 0
        for x in gamma[i]:
            s = s*2 + x
        # u: how many selected columns come before column i
        u = 0
        for j in range(i):
            if (gamma[i][j]):
                u+=1
        b[i] = bas[s][u]
        # w walks through row u of adj[s] while j walks through the selected columns
        w = 0
        for j in range(p):
            if (gamma[i][j] == 1):
                a[i][j] = adj[s][u][w]
                w += 1
    coeffs = np.concatenate([b,a.flatten()]) #p+p^2 entries
    # The stacked array has p+1 rows (the p rows of `a`, then `b`), so the sum has
    # p+1 entries.  np.nonzero returns the positions whose sum is NON-zero, and
    # those positions (0..p) of coeffs are overwritten with 1e-100.
    # NOTE(review): this selects non-zero sums, whereas the original comment on the
    # next line talks about sums equal to 0 - confirm which one is intended.
    coeffs[np.nonzero(np.sum(np.vstack((a,b.reshape(1,-1))), axis=1))[0]] = 1e-100 #if sum of influence is 0, make baseline almost 0
    return coeffs

def q(theta_hat,theta,data,gamma_hat,gamma):
    """Return exp(-loss(theta_hat)) / exp(-loss(theta)) for a model fitted on `data`.

    `gamma_hat` and `gamma` are accepted but not used.
    """
    model = ModelHawkesExpKernLogLik(decay=decays, n_threads=0)
    model.fit(data)
    # NOTE(review): the original author asked whether this should use the true theta
    numerator = np.exp(-model.loss(coeffs=theta_hat))
    denominator = np.exp(-model.loss(coeffs=theta))
    return numerator / denominator

def COMP(gamma,prep):
    """Return the log of the mean ratio `q` over all prepared samples for `gamma`.

    `prep` is the five-element list built by run_prep / run_prep_sparse:
    ``[data_list, adj_list, bas_list, gamma_list, theta_list]``.  For each sample
    the per-subset estimates are assembled for `gamma` with theta_hat_eval and
    compared, through `q`, with that sample's true coefficient vector.
    """
    data_list, adj_list, bas_list, _gamma_list, theta_list = prep
    ratios = [
        q(
            theta_hat=theta_hat_eval(adj=adj_list[k], bas=bas_list[k], gamma=gamma),
            theta=theta_list[k],
            data=data_list[k],
            gamma_hat=None,  # unused by q
            gamma=None,      # unused by q
        )
        for k in range(len(adj_list))
    ]
    return np.log(np.mean(ratios))

def run_prep(p,N,n_sim,T):
    """Collect N fitted simulated datasets and pickle one COMP value per candidate row.

    Phase 1 repeatedly draws a random model, simulates `n_sim` realisations and
    runs est_eval on each, until N realisations have been accepted.
    Phase 2 enumerates, for every node i, all 2**(p-1) boolean rows with entry i
    forced to True, and stores COMP(gamma, prep) for the identity structure whose
    row i is replaced by that candidate.

    The nested list ``COMP_list[i][s]`` is written to 'COMP_<p>_<T>.pkl' in the
    current working directory.  Nothing is returned.
    """
    #generates data! TODO: change to accept data. and compute COMP.
    #requires: COMP
    data_list = []
    adj_list = []
    bas_list = []
    gamma_list = []
    theta_list = []
    c = 0  # number of accepted realisations so far
    while(c < N):
        print("data")
        try:
            data,gamma,adjacency,baseline = generate_sample_multi(p,n_sim,T)
        except Exception:
            # Any failure is retried silently; if generation always fails this never ends.
            #print("Oops! large spr")
            continue
        # true coefficient vector in the same layout theta_hat_eval produces
        theta = np.concatenate([baseline,adjacency.flatten()])
        print("fit")
        for d in data:
            try:
                adj,bas = est_eval(d)
            except Exception:
                # A failed fit drops the remaining realisations of this model as well.
                #print("Oops! can't fit")
                break
            c+=1
            print(c)
            data_list.append(d)
            adj_list.append(adj)
            bas_list.append(bas)
            gamma_list.append(gamma)
            theta_list.append(theta)
            if (c == N):
                break
                
    prep = [data_list,adj_list,bas_list,gamma_list,theta_list] #TODO: change to data
    
    COMP_list = []
    for i in range(p):
        COMP_list.append([])
        for s in range(2**(p-1)):
            # Decode s into a row: its bits fill the columns from last to first,
            # skipping column i, which is then forced to 1.
            r = np.zeros(p)
            k = s
            for j in range(p):
                if (p-j-1==i):
                    continue
                r[p-j-1] = k%2
                k = k//2
            r[i] = 1
            r = r==1
            #print(r*1)
            # identity structure with row i replaced by the candidate
            gamma = np.diag(np.ones(p)) == 1
            gamma[i] = r
            C = COMP(gamma,prep)
            COMP_list[i].append(C)
            
        
    with open('COMP_'+str(p)+'_'+str(T)+'.pkl', 'wb') as output:
        pickle.dump(COMP_list, output, pickle.HIGHEST_PROTOCOL)


def estimate(data,s):
    """Fit a likelihood Hawkes learner on the subset of series encoded by `s`.

    `s` is read as a binary mask over the series of `data`: its lowest bit
    refers to the last series.  A least-squares fit on the selected series
    provides the starting point (plus 1e-5) of the likelihood fit.

    Returns the fitted likelihood learner.  When the module flag
    `handle_mle_ls` is set and the likelihood fit raises RuntimeError, the
    least-squares learner is returned instead.
    """
    n_series = len(data)

    # peel the bits off from the low end, then flip so index 0 is the high bit
    keep = []
    code = s
    for _ in range(n_series):
        keep.append(code % 2 == 1)
        code = code // 2
    keep.reverse()
    dat = [series for series, flag in zip(data, keep) if flag]

    ls = HawkesExpKern(decays, penalty='none', gofit='least-squares',
                       step=10, max_iter=10000000, tol=1e-5)
    ls.fit(dat)
    mle = HawkesExpKern(decays, penalty='none', gofit='likelihood',
                        step=1, max_iter=10000000, tol=1e-5, solver='gd')

    if not handle_mle_ls:
        mle.fit(events=dat, start=ls.coeffs+1e-5)
        return mle

    try:
        mle.fit(events=dat, start=ls.coeffs+1e-5)
    except RuntimeError as e:
        print("RuntimeError:", e.args[0])
        print("Estimating maximum likelihood failed, setting maximum likelihood estimate to least-squares estimate.\n")
        return ls
    return mle


def est_eval(data):
    """Fit every non-empty subset of series and return the estimates indexed by subset code.

    Returns ``(adj, bas)``: two lists of length ``2**p`` whose entry `s` holds
    the adjacency / baseline estimated by ``estimate(data, s)``.  Entry 0 is a
    placeholder ``0``.  A subset whose fit raises RuntimeError gets arrays
    filled with 1e-100.
    """
    n_series = len(data)
    adj, bas = [0], [0]
    for code in range(1, 2**n_series):
        try:
            fitted = estimate(data, code)
            adjacency, baseline = fitted.adjacency, fitted.baseline
        except RuntimeError as e:
            print("RuntimeError:", e.args[0])
            print("Estimating maximum likelihood failed, setting maximum likelihood estimate to almost zeros (1e-100).\n")
            adjacency = np.zeros((n_series, n_series))+1e-100
            baseline = np.zeros((n_series))+1e-100
        adj.append(adjacency)
        bas.append(baseline)
    return adj, bas


def load_COMP(p,T):
    """Unpickle and return the contents of 'COMP_<p>_<T>.pkl' from the working directory."""
    filename = '_'.join(['COMP', str(p), str(T)]) + '.pkl'
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def MDLH(data,T,hyperparams=None):
    """Estimate the structure by minimising loss + COMP separately for every row.

    For each node i all 2**(p-1) candidate rows (entry i forced to True) are
    scored with ``L.loss(theta) + COMP_list[i][s]``, where theta is assembled by
    theta_hat_eval from the per-subset estimates of est_eval; the lowest score
    wins.  COMP values are read from 'COMP_<p>_<T>.pkl' via load_COMP.

    Parameters
    ----------
    data : list of event-time arrays, one per node.
    T : only used to locate the COMP file.
    hyperparams : not used.

    Returns
    -------
    p x p boolean array.
    """
    #requires : est_eval, load_COMP which needs us to have run run_prep?,  theta_hat_eval
    p = len(data)
    adj, bas = est_eval(data)
    L = ModelHawkesExpKernLogLik(decay=decays,n_threads=0)
    L.fit(data);
    COMP_list = load_COMP(p,T)
    gamma_hat = np.diag(np.ones(p)) == 1
    for i in range(p):
        score = 1e100
        # NOTE(review): if no candidate scores below 1e100 (e.g. NaN), r_hat stays []
        # and the row assignment after the loop will presumably fail - confirm.
        r_hat = []
        for s in range(2**(p-1)):
            # Decode s into a row: its bits fill the columns from last to first,
            # skipping column i, which is then forced to 1.
            r = np.zeros(p)
            k = s
            for j in range(p):
                if (p-j-1==i):
                    continue
                r[p-j-1] = k%2
                k = k//2
            r[i] = 1
            r = r==1
            #print(r*1)
            # identity structure with row i replaced by the candidate
            gamma = np.diag(np.ones(p)) == 1
            gamma[i] = r
            theta = theta_hat_eval(adj=adj,bas=bas,gamma=gamma)
            NML = L.loss(theta) + COMP_list[i][s]
            if (NML<score):
                score = NML
                r_hat = r
        gamma_hat[i] = r_hat
    return gamma_hat


##################MDL Sparse#######################


def generate_gamma_sparse(p,deg):
    """Draw a random p x p boolean structure with at most `deg` off-diagonal entries per row.

    For every row a count ``u`` is drawn uniformly from ``0..deg``; ``u`` ones are
    shuffled into a vector of length ``p - 1`` and that vector is written onto
    the row's off-diagonal positions.  The diagonal is always True.

    The earlier loop skipped one draw when it reached the diagonal, so column
    ``i + 1`` of row ``i`` could never be selected and a row could end up with
    fewer entries than drawn.  All ``p - 1`` draws are now placed.

    Parameters
    ----------
    p : int
        Number of nodes.
    deg : int
        Maximum number of off-diagonal True entries per row.

    Returns
    -------
    numpy.ndarray of bool, shape (p, p)
    """
    gamma = np.eye(p, dtype=bool)
    columns = np.arange(p)
    for i in range(p):
        # same RNG calls, in the same order, as before: choice, then shuffle
        u = np.random.choice(range(0,deg+1))
        r = np.zeros(p-1)
        r[:u] = 1
        np.random.shuffle(r)
        # the p-1 draws map one-to-one onto the columns other than i
        gamma[i, columns != i] = r == 1
    return gamma


def generate_sample_multi_sparse(p,n_sim,T,deg):
    """Draw a degree-limited random structure and simulate `n_sim` realisations from it.

    Returns ``(data, gamma, adjacency, baseline)``.
    """
    structure = generate_gamma_sparse(p, deg)
    simulated = generate_data_uniform_multi(
        end_time=T, decays=decays, gamma=structure, n_sim=n_sim)
    data, baseline, adjacency = simulated
    return data, structure, adjacency, baseline


def est_eval_sparse(data,deg):
    """Fit every subset of 1 to deg+1 series and return the estimates keyed by subset code.

    The code of a subset is the sum of ``2**c`` over its members ``c``.

    Returns ``(adj, bas)``: two dicts mapping each code to the adjacency /
    baseline estimated by ``estimate(data, code)``.  A subset whose fit raises
    RuntimeError gets arrays filled with 1e-100.
    """
    n_series = len(data)
    subsets = [
        combo
        for subset_size in range(1, deg+2)
        for combo in itertools.combinations(range(n_series), subset_size)
    ]

    adj, bas = dict(), dict()
    for combo in subsets:
        code = sum(2**c for c in combo)
        try:
            fitted = estimate(data, code)
            # TODO (original author): move the handling of degenerate estimates here
            adj[code] = fitted.adjacency
            bas[code] = fitted.baseline
        except RuntimeError as e:
            print("RuntimeError:", e.args[0])
            print("Estimating maximum likelihood failed, setting maximum likelihood estimate to almost zeros (1e-100).\n")
            adj[code] = np.zeros((n_series, n_series))+1e-100
            bas[code] = np.zeros((n_series))+1e-100
    return adj, bas


def run_prep_sparse(p,N,n_sim,T,deg):
    """Degree-limited counterpart of run_prep: collect N fitted datasets and pickle COMP values.

    Phase 1 repeatedly draws a degree-limited random model, simulates `n_sim`
    realisations and runs est_eval_sparse on each, until N are accepted.
    Phase 2 enumerates, for every node i, the codes `s` of all index
    combinations of size 0..deg and stores COMP(gamma, prep) under
    ``COMP_list[i][s]`` (a list of dicts).

    The result is written to 'COMP_<p>_<T>_sparse_<deg>.pkl' in the current
    working directory.  Nothing is returned.
    """
    data_list = []
    adj_list = []
    bas_list = []
    gamma_list = []
    theta_list = []
    c = 0  # number of accepted realisations so far
    while(c < N):
        print("data")
        try:
            data,gamma,adjacency,baseline = generate_sample_multi_sparse(p,n_sim,T,deg)
        except Exception:
            # Any failure is retried silently; if generation always fails this never ends.
            #print("Oops! large spr")
            continue
        # true coefficient vector in the same layout theta_hat_eval produces
        theta = np.concatenate([baseline,adjacency.flatten()])
        print("fit")
        for d in data:
            try:
                adj,bas = est_eval_sparse(d,deg)
            except Exception:
                # A failed fit drops the remaining realisations of this model as well.
                #print("Oops! can't fit")
                break
            c+=1
            print(c)
            data_list.append(d)
            adj_list.append(adj)
            bas_list.append(bas)
            gamma_list.append(gamma)
            theta_list.append(theta)
            if (c == N):
                break
                
    prep = [data_list,adj_list,bas_list,gamma_list,theta_list]
    
    # every combination of 0..deg indices out of range(p)
    tot = []
    for u in range(0,deg+1):
        for w in itertools.combinations(range(p),u):
            tot.append(w)
                
    COMP_list = []
    for i in range(p):
        COMP_list.append(dict())
        for w in tot:
            # code of the combination (note: `c` is reused here as a loop variable)
            s = 0
            for c in w:
                s += 2**c
            # Decode s into a row: its bits fill the columns from last to first,
            # skipping column i, which is then forced to 1.  Only p-1 bits are consumed.
            r = np.zeros(p)
            k = s
            for j in range(p):
                if (p-j-1==i):
                    continue
                r[p-j-1] = k%2
                k = k//2
            r[i] = 1
            r = r==1
            #print(r*1)
            # identity structure with row i replaced by the candidate
            gamma = np.diag(np.ones(p)) == 1
            gamma[i] = r
            C = COMP(gamma,prep)
            COMP_list[i][s] = C
            
        
    with open('COMP_'+str(p)+'_'+str(T)+'_'+'sparse'+'_'+str(deg)+'.pkl', 'wb') as output:
        pickle.dump(COMP_list, output, pickle.HIGHEST_PROTOCOL)

def load_COMP_sparse(p,T,deg):
    """Unpickle and return the contents of 'COMP_<p>_<T>_sparse_<deg>.pkl' from the working directory."""
    filename = '_'.join(['COMP', str(p), str(T), 'sparse', str(deg)]) + '.pkl'
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def MDLH_sparse(data,T,hyperparams):
    """Degree-limited counterpart of MDLH.

    For each node i the candidate rows are those decoded from the codes of all
    index combinations of size 0..deg; each is scored with
    ``L.loss(theta) + COMP_list[i][s]`` and the lowest score wins.  COMP values
    are read from 'COMP_<p>_<T>_sparse_<deg>.pkl' via load_COMP_sparse.

    Parameters
    ----------
    data : list of event-time arrays, one per node.
    T : only used to locate the COMP file.
    hyperparams : the degree bound `deg`.

    Returns
    -------
    p x p boolean array.
    """
    deg = hyperparams
    p = len(data)
    adj, bas = est_eval_sparse(data,deg)
    L = ModelHawkesExpKernLogLik(decay=decays,n_threads=0)
    L.fit(data);
    COMP_list = load_COMP_sparse(p,T,deg)
    gamma_hat = np.diag(np.ones(p)) == 1
    
    # every combination of 0..deg indices out of range(p)
    tot = []
    for u in range(0,deg+1):
        for w in itertools.combinations(range(p),u):
            tot.append(w)
    
    for i in range(p):
        score = 1e100
        # NOTE(review): if no candidate scores below 1e100 (e.g. NaN), r_hat stays []
        # and the row assignment after the loop will presumably fail - confirm.
        r_hat = []
        for w in tot:
            # code of the combination; also the key into COMP_list[i]
            s = 0
            for c in w:
                s += 2**c
            # Decode s into a row: its bits fill the columns from last to first,
            # skipping column i, which is then forced to 1.  Only p-1 bits are consumed.
            r = np.zeros(p)
            k = s
            for j in range(p):
                if (p-j-1==i):
                    continue
                r[p-j-1] = k%2
                k = k//2
            r[i] = 1
            r = r==1
            #print(r*1)
            # identity structure with row i replaced by the candidate
            gamma = np.diag(np.ones(p)) == 1
            gamma[i] = r
            theta = theta_hat_eval(adj=adj,bas=bas,gamma=gamma)
            NML = L.loss(theta) + COMP_list[i][s]
            if (NML<score):
                score = NML
                r_hat = r
        gamma_hat[i] = r_hat
    return gamma_hat


###############MLE########################
def MLE(data,T,hyperparams):
    """Estimate the structure with a penalised likelihood fit.

    `hyperparams` is the pair ``(penalty, C)``.  A least-squares fit provides
    the starting point (plus 1e-5) of the likelihood fit.  `T` is not used.

    Returns the boolean array ``adjacency > TH``.
    """
    penalty, C = hyperparams
    warm_start = HawkesExpKern(decays, gofit='least-squares', step=10,
                               max_iter=10000, tol=1e-5, solver='gd')
    warm_start.fit(data)
    learner = HawkesExpKern(decays, C=C, penalty=penalty, gofit='likelihood',
                            step=10, max_iter=10000, tol=1e-5, solver='gd')
    learner.fit(events=data, start=warm_start.coeffs+1e-5)
    return learner.adjacency > TH

###################LS########################
def LS(data,T,hyperparams):
    """Estimate the structure with a penalised least-squares fit.

    Parameters
    ----------
    data : list of event-time arrays, one per node.
    T : not used; kept so all estimators share one signature.
    hyperparams : pair ``(penalty, C)`` forwarded to the learner.

    Returns
    -------
    Boolean array ``adjacency > TH``.
    """
    penalty,C = hyperparams
    # `penalty` used to be unpacked and then dropped, so every fit ran with the
    # learner's default penalty whatever the caller asked for.
    ls = HawkesExpKern(decays, gofit='least-squares', penalty=penalty, C=C,
                       step=10, max_iter=10000, tol=1e-10, solver='gd')
    ls.fit(data)
    gamma_hat = ls.adjacency > TH
    return gamma_hat

####################ADM4##########################
def ADM4(data,T,hyperparams):
    """Estimate the structure with the ADM4 learner.

    `hyperparams` is the pair ``(C, lasso_nuclear_ratio)``.  `T` is not used.

    Returns the boolean array ``adjacency > TH``.
    """
    strength, ratio = hyperparams
    adm4 = HawkesADM4(decay=decays, C=strength, lasso_nuclear_ratio=ratio)
    adm4.fit(data)
    return adm4.adjacency > TH

#############common#################
def save_results(results, path):
    """Pickle `results` to the file named ``path + 'results.pkl'``.

    `path` is used as a plain string prefix; no separator is inserted.
    """
    target = path + 'results.pkl'
    with open(target, 'wb') as handle:
        pickle.dump(results, handle, pickle.HIGHEST_PROTOCOL)

def load_results(path):
    """Unpickle and return the contents of ``path + 'results.pkl'``.

    `path` is used as a plain string prefix; no separator is inserted.
    """
    source = path + 'results.pkl'
    with open(source, 'rb') as handle:
        return pickle.load(handle)

def evaluate(gamma_true_list,gamma_star_list,metric=F1_score):
    """Score every estimated structure against the true structure at the same index.

    Returns a list with one score per entry of `gamma_star_list`.
    """
    return [
        metric(gamma_true_list[idx], gamma_star_list[idx])
        for idx in range(len(gamma_star_list))
    ]

def run_method(data_list, gamma_list, method,hyperparams,metric=F1_score,if_print=False):
    """Run `method` on every dataset and return the estimated structures.

    Fixes three defects of the previous version: the result list was reset on
    every iteration (so only one estimate was ever returned), the first dataset
    was used for every iteration, and the score compared an estimate with the
    whole list of true structures.

    Parameters
    ----------
    data_list : list of datasets, each handed to `method` on its own.
    gamma_list : list of true structures, aligned with `data_list`.
    method : callable ``method(data, T, hyperparams)`` returning a structure.
    hyperparams : forwarded to `method` unchanged.
    metric : callable ``metric(gamma_true, gamma_hat)``; only used when printing.
    if_print : when True, print the score of every estimate.

    Returns
    -------
    list with one estimated structure per dataset, in input order.

    Notes
    -----
    `T` is read from the module-level global of that name, which this file only
    sets in the ``__main__`` block.
    """
    gamma_hat_list = []
    for i, data in enumerate(data_list):
        gamma_hat = method(data,T,hyperparams)
        gamma_hat_list.append(gamma_hat)
        if if_print:
            print(metric(gamma_list[i],gamma_hat))
    return gamma_hat_list


#################main###################


def _penalty_grid_search(data, gamma_true_list, estimator, penalty_list, C_list, metric):
    """Grid-search (penalty, C) for `estimator`; return (best mean score, graphs of the best run)."""
    best_score = 0
    best_graphs = None
    last_graphs = None
    for penalty in penalty_list:
        penalty_best = 0
        penalty_graphs = None
        for C in C_list:
            if (penalty == 'none'):
                C = None
            last_graphs = run_method(data, gamma_true_list, estimator, [penalty, C], metric, True)
            acc = np.mean(evaluate(gamma_true_list, last_graphs, metric))
            if (acc > penalty_best):
                print(acc)
                penalty_best = acc
                penalty_graphs = last_graphs
            if (penalty == 'none'):
                # nothing to tune without a penalty: a single run is enough
                break
        print(penalty + " " + str(penalty_best))
        if (best_score < penalty_best):
            best_score = penalty_best
            best_graphs = penalty_graphs
    # If no run scored above 0, fall back to the last run so a graph is still returned.
    return best_score, (best_graphs if best_graphs is not None else last_graphs)


def test_method(data, gamma_true_list, metric, output_path, method='MDLH', p=7, T=1000, deg=1):
    """Run one estimator on `data`, record its mean score and return the estimated graphs.

    The results dict stored at `output_path` (see load_results / save_results)
    is loaded, updated with the score of the chosen method and saved again.

    Changes from the previous version:
    * the 'LS' branch called ``load_results`` with two arguments and
      ``save_results`` with one, both of which raise TypeError;
    * for the grid-searched methods ('MLE', 'LS', 'ADM4') the returned graphs
      are those of the best-scoring setting rather than of the last one tried;
    * an unknown `method` raises ValueError instead of failing on an unbound
      local at the return statement.

    Parameters
    ----------
    data : list of datasets.
    gamma_true_list : list of true structures, aligned with `data`.
    metric : callable ``metric(gamma_true, gamma_hat)``.
    output_path : string prefix of the results pickle.
    method : one of 'MLE', 'LS', 'ADM4', 'MDLH', 'MDLH_sparse'.
    p, T : used to locate (and, if missing, to create) the COMP file of the MDLH methods.
    deg : degree bound for 'MDLH_sparse'.

    Returns
    -------
    Integer array stacking the estimated structures.

    Raises
    ------
    ValueError
        If `method` is not one of the supported names.
    """
    if method=='MLE':
        penalty_list = ['l1','none','l2','elasticnet']
        C_list = [500,1000,2000,5000,10000,20000,50000,100000]
        results = load_results(output_path)
        bestest, gamma_hat_list = _penalty_grid_search(
            data, gamma_true_list, MLE, penalty_list, C_list, metric)
        # as before, the key is only written when some setting scored above 0
        if (bestest > 0):
            results['likelihood'] = bestest
        save_results(results, output_path)

    elif method == 'LS':
        penalty_list = ['l1','l2','elasticnet','none']#,'nuclear']
        C_list = [1,2,5,10,20,50,100,200,500,1000,2000,5000,10000,20000,50000,100000]
        results = load_results(output_path)
        bestest, gamma_hat_list = _penalty_grid_search(
            data, gamma_true_list, LS, penalty_list, C_list, metric)
        # as before, the key is only written when some setting scored above 0
        if (bestest > 0):
            results['least-squares'] = bestest
        save_results(results, output_path)

    elif method == 'ADM4':
        C_list = [1,2,5,10,20,50,100,200,500,1000,2000,5000,10000,20000,50000,100000]
        ratio_list = [0,0.1,0.5,0.9,1]
        results = load_results(output_path)

        best = 0
        best_graphs = None
        for ratio in ratio_list:
            for C in C_list:
                hyperparams = [C,ratio]
                gamma_hat_list = run_method(data, gamma_true_list,ADM4,hyperparams,metric,True)
                res = evaluate(gamma_true_list,gamma_hat_list,metric)
                acc = np.mean(res)
                if (acc > best):
                    print(acc)
                    best = acc
                    best_graphs = gamma_hat_list
                print(str(C) + ' '  + str(ratio) + " " + str(acc))

        results['ADM4'] = best
        save_results(results, output_path)
        if best_graphs is not None:
            gamma_hat_list = best_graphs

    elif method == 'MDLH':
        # the COMP table is expensive; build it only when the file is missing
        if not os.path.exists('COMP_'+str(p)+'_'+str(T)+'.pkl'):
            run_prep(p,N,n_sim,T)
        results = load_results(output_path)
        print("MDLH")
        gamma_hat_list = run_method(data, gamma_true_list,MDLH,None,metric, True)
        res = evaluate(gamma_true_list,gamma_hat_list,metric)
        acc = np.mean(res)
        results['MDLH'] = acc
        save_results(results, output_path)

    elif method == 'MDLH_sparse':
        # the COMP table is expensive; build it only when the file is missing
        if not os.path.exists('COMP_'+str(p)+'_'+str(T)+'_'+'sparse'+'_'+str(deg)+'.pkl'):
            run_prep_sparse(p,N,n_sim,T,deg)
        results = load_results(output_path)
        gamma_hat_list = run_method(data, gamma_true_list,MDLH_sparse,deg,metric,True)
        res = evaluate(gamma_true_list,gamma_hat_list,metric)
        acc = np.mean(res)
        results['MDLH_sparse'] = acc
        print(acc)
        save_results(results, output_path)

    else:
        raise ValueError("unknown method: %r" % (method,))

    return np.array(gamma_hat_list).astype(int)

##################generate data##################
def generate_data_uniform(end_time, decays, gamma):
    """Simulate a single realisation of a Hawkes process with random parameters.

    Adjacency weights are drawn uniformly and masked by `gamma`; baselines are
    drawn uniformly.  The process is simulated up to ``2 * end_time``; only
    events later than `end_time` are kept and they are shifted back by `end_time`.

    Returns the tuple ``(data, baseline, adjacency)``.
    """
    n_nodes = len(gamma)
    weights = np.random.uniform(low=uni_adj_lb, high=uni_adj_ub, size=(n_nodes, n_nodes))
    adjacency = gamma * weights
    baseline = np.random.uniform(low=uni_bas_lb, high=uni_bas_ub, size=n_nodes)

    process = SimuHawkesExpKernels(
        adjacency=adjacency,
        decays=decays,
        baseline=baseline,
        end_time=end_time*2,
        verbose=False,
    )
    simulation = SimuHawkesMulti(process, n_simulations=1)
    simulation.simulate()

    data = simulation.timestamps[0].copy()
    for i in range(n_nodes):
        events = data[i]
        # keep the second half of the horizon and re-base it at zero
        data[i] = events[events > end_time] - end_time
    return data, baseline, adjacency

def generate_sample(p,T):
    """Draw a random structure and simulate one realisation from it.

    Returns ``(data, gamma, adjacency, baseline)``.
    """
    structure = generate_gamma(p)
    simulated = generate_data_uniform(end_time=T, decays=decays, gamma=structure)
    data, baseline, adjacency = simulated
    return data, structure, adjacency, baseline

def generate_dataset(p,T,size, path=""):
    """Generate `size` samples that can be fitted and pickle them.

    A sample is accepted only if both generate_sample and est_eval finish
    without raising; failures are retried silently.  The list
    ``[data, gamma, adjacency, baseline, T]`` (the first four being lists over
    samples) is written to ``path + 'dataset_<p>_<T>.pkl'``.
    """
    datasets, structures, adjacencies, baselines = [], [], [], []
    accepted = 0
    while accepted < size:
        try:
            sample_data, sample_gamma, sample_adj, sample_bas = generate_sample(p, T)
        except Exception:
            continue
        try:
            # the estimates are discarded; this only checks that fitting works
            est_eval(data=sample_data)
        except Exception:
            continue
        accepted += 1
        print(accepted)
        datasets.append(sample_data)
        structures.append(sample_gamma)
        adjacencies.append(sample_adj)
        baselines.append(sample_bas)

    target = path + 'dataset_' + str(p) + '_' + str(T) + '.pkl'
    payload = [datasets, structures, adjacencies, baselines, T]
    with open(target, 'wb') as output:
        pickle.dump(payload, output, pickle.HIGHEST_PROTOCOL)

def load_dataset(p,T, path=""):
    """Unpickle and return the contents of ``path + 'dataset_<p>_<T>.pkl'``."""
    source = path + '_'.join(['dataset', str(p), str(T)]) + '.pkl'
    with open(source, 'rb') as handle:
        return pickle.load(handle)


###################generate data sparse########################

def generate_sample_sparse(p,T,deg):
    """Draw a degree-limited random structure and simulate one realisation from it.

    Returns ``(data, gamma, adjacency, baseline)``.
    """
    structure = generate_gamma_sparse(p, deg)
    simulated = generate_data_uniform(end_time=T, decays=decays, gamma=structure)
    data, baseline, adjacency = simulated
    return data, structure, adjacency, baseline


def generate_dataset_sparse(p,T,size,deg, path=""):
    """Generate `size` degree-limited samples and pickle them.

    Samples whose generation raises are retried silently.  The list
    ``[data, gamma, adjacency, baseline, T]`` (the first four being lists over
    samples) is written to ``path + 'dataset_<p>_<T>_sparse_<deg>.pkl'``.
    """
    datasets, structures, adjacencies, baselines = [], [], [], []
    accepted = 0
    while accepted < size:
        try:
            sample_data, sample_gamma, sample_adj, sample_bas = generate_sample_sparse(p, T, deg)
        except Exception:
            continue
        # The est_eval_sparse fit check is disabled in this variant, so every
        # generated sample is accepted.
        accepted += 1
        print(accepted)
        datasets.append(sample_data)
        structures.append(sample_gamma)
        adjacencies.append(sample_adj)
        baselines.append(sample_bas)

    target = path + 'dataset_' + str(p) + '_' + str(T) + '_' + 'sparse' + '_' + str(deg) + '.pkl'
    payload = [datasets, structures, adjacencies, baselines, T]
    with open(target, 'wb') as output:
        pickle.dump(payload, output, pickle.HIGHEST_PROTOCOL)

def load_dataset_sparse(p,T,deg, path):
    """Unpickle and return the contents of ``path + 'dataset_<p>_<T>_sparse_<deg>.pkl'``."""
    source = path + '_'.join(['dataset', str(p), str(T), 'sparse', str(deg)]) + '.pkl'
    with open(source, 'rb') as handle:
        return pickle.load(handle)

###########################################
###############Convert formats#################

def data_to_dataframe(data):
    """Flatten per-alarm event times into a long-format DataFrame.

    Parameters
    ----------
    data : sequence indexed by alarm id (0 .. len(data)-1)
        Entry ``alarm_id`` is an iterable of that alarm's timestamps.

    Returns
    -------
    pandas.DataFrame
        Columns ``device_id`` (always 0), ``alarm_id`` and ``start_timestamp``,
        one row per event, ordered by alarm id and then by position in the input.

    Notes
    -----
    The rows are collected first and the frame is built once.  The previous
    version concatenated a one-row frame per event, which copies the whole
    frame every time and is quadratic in the number of events.
    """
    columns = ['device_id', 'alarm_id', 'start_timestamp']
    alarm_ids = []
    timestamps = []
    for alarm_id in range(len(data)):
        for start_timestamp in data[alarm_id]:
            alarm_ids.append(alarm_id)
            timestamps.append(start_timestamp)  #TODO: add discretization parameter
    return pd.DataFrame(
        {
            'device_id': [0] * len(alarm_ids),
            'alarm_id': alarm_ids,
            'start_timestamp': timestamps,
        },
        columns=columns,
    )

def data_from_dataframe(alarms, T=1000):
    """Convert a long-format alarm table into per-alarm arrays of rescaled timestamps.

    Timestamps of each alarm are sorted, cast to float and rescaled so that the
    largest timestamp in the whole table maps to `T`.

    Each alarm id is used directly as an index into the result list, so the ids
    must be valid indices of a list whose length is the number of distinct ids.

    Returns ``(data, p)`` where `p` is the number of distinct alarm ids.
    """
    alarm_ids = alarms['alarm_id'].unique()
    p = len(alarm_ids)
    data = [None] * p
    dmax = alarms['start_timestamp'].max()
    for alarm_id in alarm_ids:
        stamps = alarms.loc[alarms['alarm_id'] == alarm_id, 'start_timestamp']
        data[alarm_id] = np.array(sorted(stamps)).astype(float)*T/dmax
    return data, p

import os.path

###########################################
import argparse
if __name__ == "__main__":
    # Command-line entry point: load a dataset (pickle or csv + true graph), run
    # the selected estimator, print the recorded scores and save the estimated
    # graph(s) as .npy files.

    def _str_to_bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
        non-empty value - including "False" - used to be read as True.
        """
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    method_names = ['MLE', 'LS', 'ADM4', 'MDLH', 'MDLH_sparse']

    argparser = argparse.ArgumentParser()
    # Input is either a pkl file (-read-pkl True) or an alarms csv plus a true-graph npy file.
    argparser.add_argument("-read-pkl", type=_str_to_bool, help="if true read data from .pkl, if false read from csv.", default=False)
    argparser.add_argument("-p", type=int, help="number of predictors (relevant if reading from pkl)", default=7)
    argparser.add_argument("-T", type=int, help="Max Timestamp.", required=True)
    argparser.add_argument("-pkl-path", type=str, help="pkl file name", default="./synthetic_data_server/mdlh_data/dataset_7_700.pkl")
    argparser.add_argument("-alarms", type=str, help="alarms csv file path", default="./synthetic_data_server/mdlh_data/alarms_7_700_0.csv")
    argparser.add_argument("-true-graph", type=str, help="true_graph file path", default="./synthetic_data_server/mdlh_data/true_graph_7_700_0.npy")
    argparser.add_argument("-output", type=str, help="path to output", default="./output/mdlh_results/mdlh.npy")
    # `choices` rejects unknown methods even under `python -O`, where an assert would be skipped.
    argparser.add_argument("-method", type=str, choices=method_names, help="method one of: 'MLE', 'LS', 'ADM4', 'MDLH', 'MDLH_sparse'", default="MDLH")
    argparser.add_argument("-deg", type=int, help="Hyperparameter for MDLH sparse, max degree of node in graph", default=-1)
    argparser.add_argument("-handle-mle-ls", type=_str_to_bool, help="Handle maximum likelihood exception by returning the least squares estimate if True, and if false (default) return almost zeros estimates (1e-100).", default=False)

    #TODO:assert output path exists
    #read data
    args = argparser.parse_args()
    # module-level flag read by estimate()
    handle_mle_ls = args.handle_mle_ls

    if args.read_pkl:
        # pickle layout: [data list, gamma list, ...] as written by generate_dataset*
        with open(args.pkl_path, 'rb') as pkl_file:
            dataset = pickle.load(pkl_file)
        gamma_true_list = dataset[1]
        data = dataset[0]
        p=args.p
    else:
        alarms = pd.read_csv(args.alarms).sort_values(by='start_timestamp')
        gamma_true_list=np.load(args.true_graph)
        #get parameters from data
        data, p = data_from_dataframe(alarms, args.T)
        gamma_true_list = gamma_true_list.astype(bool)
        # wrap the single dataset so both input modes yield lists
        data=[data]
        gamma_true_list=[gamma_true_list]

    if args.deg == -1:
        # default: largest column sum of the first true graph
        deg = np.max(np.sum(gamma_true_list[0], axis=0))
        print("Setting degree to:", deg)
    else:
        deg = args.deg

    results = dict()
    save_results(results, args.output) #initialize results dict
    # module-level T is read by run_method()
    T=args.T
    
    graphs = test_method(data, gamma_true_list, metric, args.output, args.method, p, args.T, deg)
    results = load_results(args.output)
    print("results:", results)

    for i in range(len(graphs)):
        if args.read_pkl:
            # one file per dataset: "<output without its last 4 characters><i>.npy"
            np.save(args.output[:-4]+str(i)+".npy", graphs[i])
        else:
            np.save(args.output, graphs[i])
        print("true_graph:", gamma_true_list[i])
        print("estimated_graph:", graphs[i])
