import pandas as pd
import numpy as np
from utils import PROBLEM_FEATURES, ACTIONS
from mtticc.TICC_solver import TICC
from prefixspan import PrefixSpan
import random
import multiprocessing as mp
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score
import pickle
import csv
import os
os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
from absl import app


def _silhouette_for_k(n_clusters, data):
    """Fit K-means with ``n_clusters`` clusters and return ``[n_clusters, mean silhouette]``.

    Defined at module level (not nested inside ``main``) so that
    ``multiprocessing.Pool.map`` can pickle it for the worker processes.
    """
    clusterer = KMeans(n_clusters=n_clusters, random_state=10)
    cluster_labels = clusterer.fit_predict(data)
    return [n_clusters, silhouette_score(data, cluster_labels)]


def _user_segment(rows, feature_list, actions):
    """Convert one user's rows (in file order) into a trajectory dict.

    The dict follows a subset of the D4RL dataset layout:
      'observations'      -- feature values of every step,
      'actions'           -- action columns of every step,
      'rewards'           -- the 'inferred_rew' column,
      'next_observations' -- observation of the following step of the SAME
                             user; the last step gets a row of -1 sentinels,
      'terminals'         -- True only on the user's last step.
    (D4RL's 'timeouts' and 'infos/*' entries are not produced.)
    """
    observations = rows[feature_list].values
    n_steps = len(observations)
    # Shift within this user's rows rather than looking up global index i+1,
    # which would pick up another user's row if rows are not contiguous.
    sentinel = np.full((1, len(feature_list)), -1)
    return {
        'observations': observations,
        'actions': rows[actions].values,
        'rewards': rows['inferred_rew'].values,
        'next_observations': np.vstack([observations[1:], sentinel]),
        'terminals': np.arange(n_steps) == n_steps - 1,
    }


def _build_segments(user_rows, users, feature_list, actions):
    """Return one ``_user_segment`` dict per user in ``users``, in that order.

    ``user_rows`` maps a userID to that user's DataFrame rows.
    """
    return [_user_segment(user_rows[user], feature_list, actions) for user in users]


def main(_):
    """Cluster students by initial state and save per-cluster training trajectories.

    Steps:
      1. Load ``../raw_data/ITS_withBehaviorPolicy.csv``.
      2. Drop features whose value at the reference row is identical for every
         user (system-initialised, so they cannot separate students).
      3. Choose the number of clusters in [2, 20] (capped at n_users - 1) that
         maximises the K-means silhouette score.
      4. Run TICC with that many clusters on each user's initial state and
         save the assignments to ``../cluster_data/ITS.txt``.
      5. For training users (``userID < 201000``), save trajectory dicts per
         cluster to ``../processed_data/train_cluster_{c}.npy`` and for all
         training users to ``../processed_data/train.npy``.

    Args:
        _: Unused; ``absl.app.run`` passes the leftover argv here.

    Raises:
        ValueError: if there are fewer than 3 users, or every feature is
            constant across users (nothing to cluster on).
    """
    from functools import partial

    # load data
    file_name = 'ITS_withBehaviorPolicy'
    data_path = '../raw_data/{}.csv'.format(file_name)
    raw_data = pd.read_csv(data_path)
    feature_list = PROBLEM_FEATURES
    print('finish loading data')

    user_list = list(raw_data['userID'].unique())
    # One pass over the data instead of one boolean mask per user; row order
    # within each user's group is preserved.
    user_rows = {uid: rows for uid, rows in raw_data.groupby('userID', sort=False)}

    # Exclude system-initialised features: same value for every student.
    # NOTE(review): positional row 7 is used for this check while row 1 is
    # used for clustering below -- confirm both positions are intended.
    # Users with fewer than 8 rows raise IndexError here.
    init_state = pd.DataFrame(
        [user_rows[u][feature_list].values[7] for u in user_list])
    constant_idx = {idx for idx in range(init_state.shape[1])
                    if init_state[idx].nunique(dropna=False) <= 1}
    # Keep the original feature order so the clustering input is reproducible.
    popu_feature = [f for idx, f in enumerate(feature_list) if idx not in constant_idx]
    if not popu_feature:
        raise ValueError('every feature is constant across users; nothing to cluster on')

    # Cluster on the personally varied features.
    init_popu_state = [list(user_rows[u][popu_feature].values[1]) for u in user_list]

    # Number of clusters: best mean silhouette over k in [min_c, max_c].
    # The silhouette score is only defined for 2 <= k <= n_samples - 1.
    min_c, max_c = 2, min(20, len(init_popu_state) - 1)
    if max_c < min_c:
        raise ValueError(
            'need at least 3 users to choose a cluster count, got {}'.format(len(init_popu_state)))
    with mp.Pool(5) as pool:
        sil_score = pool.map(partial(_silhouette_for_k, data=init_popu_state),
                             range(min_c, max_c + 1))
    # max() returns the first maximal entry, i.e. the smallest k on ties.
    num_clusters = max(sil_score, key=lambda k_score: k_score[1])[0]
    print('selected num of clusters: {} !'.format(num_clusters))

    # TICC (uses the module-level mtticc.TICC_solver.TICC import).
    raw_save_path = '../raw_data/ITS_data.txt'
    np.savetxt(raw_save_path, init_popu_state, delimiter=',')
    ticc = TICC(window_size=1, number_of_clusters=num_clusters, lambda_parameter=11e-2, beta=0,
                maxIters=100, threshold=2e-5, write_out_file=False,
                prefix_string="output_folder/", num_proc=1)
    cluster_assignment, cluster_MRFs = ticc.fit(input_file=raw_save_path)
    np.savetxt('../cluster_data/ITS.txt', cluster_assignment, fmt='%d', delimiter=',')

    # Training split by user id (est: Spring20). All rows of a user share one
    # userID, so the per-user groups above are exactly the training rows.
    train_data = raw_data.loc[raw_data['userID'] < 201000]
    train_user_list = list(train_data['userID'].unique())
    train_user_set = set(train_user_list)

    # Save training trajectories per cluster. cluster_assignment is aligned
    # with user_list (one TICC row per user).
    for c in range(num_clusters):
        user_in_c = [user for user, cluster_idx in zip(user_list, cluster_assignment)
                     if cluster_idx == c and user in train_user_set]
        seg_in_c = _build_segments(user_rows, user_in_c, feature_list, ACTIONS)
        print("successfully saved {} users in cluster {}!".format(len(seg_in_c), c))
        with open('../processed_data/train_cluster_{}.npy'.format(c), 'wb') as f:
            np.save(f, seg_in_c)

    # Save all training trajectories.
    seg = _build_segments(user_rows, train_user_list, feature_list, ACTIONS)
    print("successfully saved {} users!".format(len(seg)))
    with open('../processed_data/train.npy', 'wb') as f:
        np.save(f, seg)

if __name__ == '__main__':
	# Script entry point: absl's app.run handles command-line flags, then calls main() with the leftover argv (ignored by main's `_` parameter).
	app.run(main)