from matplotlib import rc
import pandas as pd
import shap
import matplotlib.pyplot as plt
from keras.models import load_model


# CSV column names plotted, in display order.
_FEATURE_NAMES = ["x", "x2", "xSquared", "y", "y2", "ySquared", "z", "z2", "zSquared"]

# LaTeX axis labels, one per entry of _FEATURE_NAMES (same order).
# NOTE(review): the z-columns are labelled "c" -- confirm this renaming is intended.
_FEATURE_LABELS = [
    r"$x$", r"$2x$", r"$x^2$",
    r"$y$", r"$2y$", r"$y^2$",
    r"$c$", r"$2c$", r"$c^2$",
]


def _save_summary_plot(values, feats, output_filepath):
    """Draw an unsorted dot-style SHAP summary plot, save it, and close its figure."""
    shap.summary_plot(
        values,
        feats,
        show=False,
        plot_type="dot",
        sort=False,
        feature_names=_FEATURE_LABELS,
    )
    plt.savefig(output_filepath)
    # Close (rather than only clear) the figure so no pyplot state is left
    # behind for the next plot or for later callers.
    plt.close()


def plot_shap_vals(
        shapvals_filepath,
        data_filepath,
        indirect_output_filepath,
        direct_output_filepath,
        predictor_filepath,
        n_instances):
    """Plot indirect and direct SHAP influence distributions and save them as images.

    Two dot-style SHAP summary plots are produced for the first
    ``n_instances`` rows:

    * the *indirect* plot uses the precomputed SHAP values read from
      ``shapvals_filepath``;
    * the *direct* plot uses SHAP values computed here with
      ``shap.GradientExplainer`` on the Keras model at ``predictor_filepath``.

    Args:
        shapvals_filepath: CSV of precomputed SHAP values; must contain every
            column in ``_FEATURE_NAMES``.
        data_filepath: CSV of feature values; must contain the same columns.
        indirect_output_filepath: output path for the indirect-influence plot.
        direct_output_filepath: output path for the direct-influence plot.
        predictor_filepath: model file passed to ``keras.models.load_model``.
        n_instances: number of leading rows to plot and explain.

    Raises:
        KeyError: if either CSV lacks one of the expected feature columns.
    """
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
    # NOTE(review): the call below overrides the font.family set just above.
    # Only the font.sans-serif = ['Helvetica'] setting from it survives.
    # Confirm whether that line is still wanted.
    rc('font', **{'family': 'serif', 'serif': ['Palatino']})
    # The axis labels are LaTeX math strings, so render text through LaTeX.
    rc('text', usetex=True)

    # NOTE(review): initjs() only matters for notebook JS visualisations.
    # These plots are static images; kept to preserve existing behaviour.
    shap.initjs()

    feats_df = pd.read_csv(data_filepath)[_FEATURE_NAMES]
    shap_vals_df = pd.read_csv(shapvals_filepath)[_FEATURE_NAMES]

    # Positional prefix of rows. Assumes row i of both CSVs describes the
    # same instance -- TODO confirm with the producer of shap_values.csv.
    feats = feats_df.iloc[:n_instances].values
    shap_vals = shap_vals_df.iloc[:n_instances].values

    # Indirect influence: precomputed SHAP values.
    _save_summary_plot(shap_vals, feats, indirect_output_filepath)

    # Direct influence for the baseline predictor. The same instances serve
    # as both background data and the rows to explain.
    predictor = load_model(predictor_filepath)
    explainer = shap.GradientExplainer(predictor, feats, local_smoothing=0)
    # With ranked_outputs=1, only the top-ranked model output is explained;
    # element [0] holds its attributions. The ranked indexes are unused.
    direct_shap_values, _ranked_indexes = explainer.shap_values(feats, ranked_outputs=1)
    _save_summary_plot(direct_shap_values[0], feats, direct_output_filepath)


if __name__ == "__main__":
    # Only run the plotting pipeline when executed as a script, not on import.
    # All paths are relative, so they resolve against the current working directory.
    plot_shap_vals(
        shapvals_filepath="results/shap_values.csv",
        data_filepath="../../data/synthetic/sum_synthetic.csv",
        indirect_output_filepath="results/indirect_influence_distributions.png",
        direct_output_filepath="results/direct_influence_distributions.png",
        predictor_filepath="models/sum_predictor.h5",
        n_instances=3000,
    )




