'''
Script for generating the data for Figure 6.
Initializes with DPVI, runs the DP subsampled MCMC method 20 times, and pickles the resulting chains.
'''

from collections import OrderedDict as od
from matplotlib import pyplot as plt
import numpy as np
import numpy.random as npr
import pandas as pd
from scipy import stats
import sys, pickle, datetime
from barker_mog import run_dp_Barker

import X_corr


def _sample_mixture_data(n, thetas, sigma_x):
	'''Draw n points from .5*N(thetas[0], sigma_x^2) + .5*N(sum(thetas), sigma_x^2).

	Vectorised replacement for a per-point Python loop: each point picks the
	first component when its uniform draw exceeds 0.5, otherwise the second.
	Returns a float64 array of shape (n,).
	'''
	use_first = npr.rand(n) > 0.5
	means = np.where(use_first, thetas[0], np.sum(thetas))
	# loc is an array, so normal() returns one draw per entry of `means`
	return npr.normal(means, sigma_x)


def _load_or_create_xcorr_params(x_corr_filename, x_max, n_points, normal_variance):
	'''Return X_corr-MoG parameters, read from the cache file when possible.

	If the cache file is missing or unreadable, the parameters are learned with
	X_corr.get_x_corr_params, which is given the cache path (path_to_file).
	'''
	print('reading X_corr params from file')
	try:
		# NOTE: pickle can execute code on load; only use cache files you created.
		with open(x_corr_filename, 'rb') as fh:
			return pickle.load(fh)
	except FileNotFoundError:
		print('no existing file found, creating new x_corr parameter & saving to file')
	except (OSError, EOFError, pickle.UnpicklingError) as err:
		# A corrupt or unreadable cache is regenerated instead of aborting the run.
		print('could not read {} ({}), creating new x_corr parameter & saving to file'.format(x_corr_filename, err))
	return X_corr.get_x_corr_params(x_max=x_max, n_points=n_points,
			C=normal_variance, path_to_file=x_corr_filename)


def _save_results(base, to_pickle):
	'''Pickle to_pickle to the first free name among base.p, base(0).p, base(1).p, ...

	Each candidate is opened with mode 'xb' (exclusive creation), so an existing
	results file is never overwritten. Returns the path that was written.
	'''
	import os
	# Create the output directory up front so a long MCMC run is not lost
	# just because the results folder does not exist yet.
	directory = os.path.dirname(base)
	if directory:
		os.makedirs(directory, exist_ok=True)
	candidate = base + '.p'
	extension = 0
	while True:
		try:
			f = open(candidate, 'xb')
		except FileExistsError:
			# Always rebuild from the fixed base, so suffixes never eat into the name.
			candidate = '{}({}).p'.format(base, extension)
			extension += 1
			continue
		with f:
			pickle.dump(to_pickle, f)
		print('Wrote results to {}'.format(candidate))
		return candidate


def fig6_data():
	'''
	Generate the data behind Figure 6 and pickle it.

	Model:
		theta ~ N(0,diag(sigma_1^2, sigma_2^2))
		x_i ~ .5*N(theta_1, sigma_x^2) + .5*N(theta_1+theta_2, sigma_x^2)
		use fixed values
		sigma_1^2 = 10, sigma_2^2 = 1, sigma_x^2 = 2
		theta_1 = 0, theta_2 = 1

	Steps: draw N synthetic points from the model, initialise theta with DPVI,
	run the DP Barker MCMC sampler n_runs times from that initialisation, and
	pickle [dpvi_params, dp_mcmc_params, theta_chain, privacy_pars] under
	./results/.

	Returns:
		str: path of the pickle file that was written.
	'''
	####################################
	# draw samples
	N = 1000000
	thetas = [0, 1] # theta_1, theta_2
	data = _sample_mixture_data(N, thetas, np.sqrt(2)) # sigma_x^2 = 2

	####################################
	### Initialize with DPVI
	from moments_accountant import ma
	from dpvi import dpvi_mix_gaus
	k_dpvi = 2
	batch_size_dpvi = 100
	noise_sigma_dpvi = 10
	T_dpvi = 100
	clip_threshold = 1.0
	learning_rate = 0.001
	params_0_dpvi = 2*npr.randn(4)
	dpvi_delta = 1e-6
	# Privacy cost of the DPVI initialisation; the second argument is the
	# subsampling ratio batch_size_dpvi/N.
	dpvi_eps_cost = ma(noise_sigma_dpvi, batch_size_dpvi/N, T_dpvi, dpvi_delta)
	print('Initializing parameters with DPVI, privacy_cost : {}'.format(dpvi_eps_cost))
	params, _ = dpvi_mix_gaus(data, k_dpvi, params_0_dpvi, T_dpvi, batch_size_dpvi,
			clip_threshold, noise_sigma_dpvi, learning_rate)
	print('Initialization done!')
	# NOTE(review): assumes the first k_dpvi entries of params are the thetas — confirm in dpvi.
	theta_from_dpvi = params[:k_dpvi]

	####################################
	# MCMC for posterior estimation
	## Base path for the results file (date and uniqueness suffix are appended later)
	fname_base = './results/dp_mcmc_results_temped_n_runs_'

	batch_size = 1000
	T = 5000 # number of steps to run
	n_runs = 20
	prop_var = 0.01
	temp_scale = 100/N
	#temp_scale = 1.0

	# exact DP Barker: exact if correction to logistic is exact, and no clipping
	privacy_pars = {}
	privacy_pars['noise_scale'] = np.sqrt(2.0)
	privacy_pars['clip'] = [0, 0.99*np.sqrt(batch_size)/temp_scale/N]
	# Parameters for X_corr
	x_max = 10 # Sets the bound to grid [-x_max, x_max]
	n_points = 1000 # Number of grid points used
	normal_variance = np.round(privacy_pars['noise_scale']**2) # C in the paper
	x_corr_filename = './X_corr/X_corr_{}_{}_{}_torch.pickle'.format(n_points, x_max, normal_variance)
	xcorr_params = _load_or_create_xcorr_params(x_corr_filename, x_max, n_points, normal_variance)

	# Run the DP-MCMC n_runs times from the same DPVI initialisation.
	# Acceptance counts and sample variances are returned too but were never saved.
	theta_chains = []
	clip_counts = []
	for _ in range(n_runs):
		chain, _n_accepted, _sample_vars, clip_count = run_dp_Barker(T, prop_var,
				theta_from_dpvi, data, privacy_pars, xcorr_params, n_points,
				batch_size=batch_size, temp_scale=temp_scale,
				count_clipped=True)
		theta_chains.append(chain)
		clip_counts.append(clip_count)

	theta_chain = np.array(theta_chains)
	clip_counts = np.array(clip_counts)

	# Save results to a pickle file named after today's day and month
	date = datetime.date.today()
	fname_base += '{}_{}'.format(date.day, date.month)

	dpvi_params = {'eps': dpvi_eps_cost, 'delta': dpvi_delta}
	dp_mcmc_params = {'N':N, 'B': batch_size, 'T': T, 'temp': temp_scale,
			'prop_var':prop_var, 'clip_counts':clip_counts, 'n_runs':n_runs}
	to_pickle = [dpvi_params, dp_mcmc_params, theta_chain, privacy_pars]
	return _save_results(fname_base, to_pickle)
