import numpy as np
import pandas as pd
import re
import scipy
import copy

import warnings
from random import random
from math import log, ceil
from time import time, ctime
import csv
import json
import ConfigSpace as CS
import dask
import statsmodels.api as sm
import scipy.stats as sps
import warnings


def min_max_scaling(values):
	"""Linearly rescale ``values`` to the range [0, 1].

	The smallest value maps to 0 and the largest to 1. When all values are
	equal (zero range) every output is ``0``. An empty input yields ``[]``.

	Args:
		values: iterable of numbers (list, tuple, array or generator).

	Returns:
		list with one scaled value per input value, in input order.
	"""
	# materialise once so generators survive both the min() and max() passes
	values = list(values)
	if not values:
		return []
	min_val = min(values)
	range_val = max(values) - min_val
	if range_val == 0:
		return [0 for _ in values]
	return [(x - min_val) / range_val for x in values]

class BOHB:
	"""Hyperband-style hyperparameter search with KDE-guided sampling (a BOHB variant).

	``run_fixed_configs`` iterates over brackets ``s = s_max, ..., 0``. Each
	bracket runs successive halving: round ``i`` (starting at ``skip_first``)
	evaluates ``n * eta ** (skip_first - i)`` configurations for
	``r * eta ** i`` iterations each. After a round, if enough samples were
	evaluated, two ``statsmodels`` KDEs are fitted -- one to the best
	``best_percent`` of the round, one to the rest -- and later configurations
	in the same bracket are drawn by minimising the density ratio ``g(x) / l(x)``.

	The configurations generated for the first evaluated round of each bracket
	are cached in ``fixed_config_dict`` under the key ``'s_<bracket>'``. They are
	reused when that bracket is run again and can be exported/imported with the
	``*_fixed_config_dict*`` methods.
	"""

	def __init__(self, configspace, get_params_function, try_params_function, max_iter=81, min_budget=1,
				 eta=3, best_percent=0.15, random_percent=1/3, n_samples=64,
				 bw_factor=3, min_bandwidth=1e-3, n_proc=1, skip_first = 0):
		"""Set up the search.

		Args:
			configspace: ConfigSpace search space; its hyperparameters determine
				the KDE variable types.
			get_params_function: zero-argument callable returning a new
				configuration (used to fill the per-bracket cache and as fallback).
			try_params_function: callable ``(n_iterations, config, criteria)``
				returning a dict that contains ``criteria``, ``'time'`` and
				``'test_accuracy'`` (optionally ``'early_stop'``).
			max_iter: maximum iterations per configuration.
			min_budget: accepted for compatibility; not used by this class.
			eta: configuration downsampling rate between rounds.
			best_percent: fraction of a round's samples used for the "good" KDE.
			random_percent: stored; not used by this class.
			n_samples: candidates drawn from the good KDE per ``get_sample`` call.
			bw_factor: multiplier on continuous KDE bandwidths when sampling.
			min_bandwidth: lower bound applied to KDE bandwidths.
			n_proc: stored; not used by this class.
			skip_first: index of the first successive-halving round run per bracket.

		Raises:
			RuntimeError: for a hyperparameter with a ``sequence`` attribute that
				is not an ``OrdinalHyperparameter``.
		"""
		self.get_params = get_params_function
		self.try_params = try_params_function
		self.configspace = configspace

		# statsmodels KDE var_type string ('c' continuous, 'u' unordered, 'o' ordered)
		# and, per hyperparameter, the number of discrete levels (0 for continuous).
		self.kde_vartypes = ""
		vartypes = []
		for h in self.configspace.get_hyperparameters():
			print(f"{h=}")
			if hasattr(h, 'sequence'):
				if not isinstance(h, CS.hyperparameters.OrdinalHyperparameter):
					raise RuntimeError('This version on BOHB does not support ordinal hyperparameters. Please encode %s as an integer parameter!'%(h.name))
				self.kde_vartypes += 'o'
				vartypes.append(len(h.sequence))
			elif hasattr(h, 'choices'):
				self.kde_vartypes += 'u'
				vartypes.append(len(h.choices))
			else:
				self.kde_vartypes += 'c'
				vartypes.append(0)
		self.vartypes = np.array(vartypes, dtype=int)

		# store precomputed probs for the categorical parameters
		self.cat_probs = []

		self.max_iter = max_iter	# maximum iterations per configuration
		self.eta = eta	# configuration downsampling rate
		self.skip_first = skip_first

		self.logeta = lambda x: log(x) / log(self.eta)
		self.s_max = int(self.logeta(self.max_iter))
		# used to size the initial number of configurations of each bracket
		self.B = (self.s_max + 1) * self.max_iter

		self.best_percent = best_percent
		self.random_percent = random_percent
		self.n_samples = n_samples
		self.min_bandwidth = min_bandwidth
		self.bw_factor = bw_factor
		self.n_proc = n_proc

		self.kde_good = None
		self.kde_bad = None
		# survivors of the previous round, consumed by get_sample
		self.samples = np.array([])

		self.counter = 0	# total number of evaluations so far
		self.fixed_config_dict = dict()

	def __str__(self):
		return f"BOHB_Max_iter_{self.max_iter}_eta_{self.eta}"

	def run_fixed_configs(self, criteria = 'valid_accuracy', direction = None):
		"""Run all Hyperband brackets and return every evaluated result.

		Args:
			criteria: key of the metric in the dicts returned by ``try_params``.
				Two special families are supported, both collapsed per round to a
				scalar that overwrites ``result[criteria]``:

				* ``'wgh...'``: the metric is a pair; both components are min-max
				  scaled within the round and combined with weights taken from
				  the first two integers in the name (each multiplied by 0.1).
				* ``'dyn_win...'``: the metric is a pair ``(mu, sig)``; it becomes
				  ``mu + sig`` ('Max') / ``mu - sig`` ('Min') when the mean of the
				  round's ``sig`` values is <= their median, otherwise ``mu``.
			direction: ``'Max'`` ranks higher values first, ``'Min'`` lower first.

		Returns:
			list of result dicts in evaluation order (each augmented with ``'s'``,
			``'counter'``, ``'params'`` and ``'n_iteration'``), followed by the
			best result among the final rounds of all brackets, appended again.

		Raises:
			ValueError: if ``direction`` is not ``'Max'``/``'Min'`` (checked before
				any evaluation), or if a ``'wgh'`` criteria has fewer than two
				integers in its name.
		"""
		# fail fast instead of discovering a bad direction after all trials ran
		if direction not in ('Max', 'Min'):
			raise ValueError(f"Invalid direction '{direction}'.")

		results = []
		final_results = []

		# dealing with special criteria: weights come from the first two integers
		weights = None
		if criteria.startswith('wgh'):
			matches = re.findall(r'\d+', criteria)
			if len(matches) < 2:
				raise ValueError("Not enough numbers found in criteria.")
			wgh1 = float(matches[0]) * 0.1
			wgh2 = float(matches[1]) * 0.1
			print(f"wgh1 = {wgh1}, wgh2 = {wgh2}")
			weights = (wgh1, wgh2)

		for s in reversed(range(self.s_max + 1)):
			# initial number of configurations
			n = int(ceil(self.B / self.max_iter / (s + 1) * self.eta ** s))
			# initial number of iterations per config
			r = self.max_iter * self.eta ** (-s)

			# every bracket starts without a model and without survivors
			self.kde_good = None
			self.kde_bad = None
			self.samples = np.array([])

			for i in range(self.skip_first, s + 1):
				# Run each of the n configs for <iterations>
				# and keep best (n_configs / eta) configurations
				n_configs = n * self.eta ** (-i + self.skip_first)
				n_iterations = r * self.eta ** i

				print("\n*** {} configurations x {:.1f} iterations each".format(
					n_configs, n_iterations))

				criterias = []
				samples = []
				round_results = []
				for j in range(int(n_configs)):
					self.counter += 1
					config = self.get_sample(s, n, j)
					result = self.try_params(n_iterations, config, criteria)
					if isinstance(config, (int, np.integer, str)):
						samples.append(config)
					else:
						print(f"{config.get_array()=}")
						samples.append(config.get_array())

					assert isinstance(result, dict)
					assert criteria in result
					assert 'time' in result
					assert 'test_accuracy' in result

					print("\n{} seconds.".format(result['time']))
					criterias.append(result[criteria])

					result['s'] = s
					result['counter'] = self.counter
					result['params'] = config
					result['n_iteration'] = n_iterations

					results.append(result)
					round_results.append(result)
					# last round of successive halving
					if i == s:
						final_results.append(result)

				# Collapse special criteria to scalars and store them on the result
				# dicts; final_results holds the very same dicts, so no separate
				# (index-based) update is needed.
				criterias = self._collapse_round_criteria(
					criteria, criterias, direction, weights, n_iterations)
				for res, val in zip(round_results, criterias):
					res[criteria] = val

				indices = np.argsort(criterias)
				if direction == 'Max':
					indices = indices[::-1]

				# survivors for the next round, then refit the KDEs if enough data
				all_samples = np.array(samples)
				self.samples = all_samples[indices[:int(n_configs / self.eta)]]
				n_good = int(np.ceil(self.best_percent * len(samples)))
				if n_good > len(self.kde_vartypes) + 2:
					good_data = self.impute_conditional_data(all_samples[indices[:n_good]])
					bad_data = self.impute_conditional_data(all_samples[indices[n_good:]])

					# quick rule of thumb
					bw_estimation = 'normal_reference'
					self.kde_good = sm.nonparametric.KDEMultivariate(data=good_data, var_type=self.kde_vartypes, bw=bw_estimation)
					self.kde_bad = sm.nonparametric.KDEMultivariate(data=bad_data, var_type=self.kde_vartypes, bw=bw_estimation)
					self.kde_bad.bw = np.clip(self.kde_bad.bw, self.min_bandwidth, None)
					self.kde_good.bw = np.clip(self.kde_good.bw, self.min_bandwidth, None)

		# rank final results
		ranked = sorted(final_results, key=lambda x: x[criteria], reverse=(direction == 'Max'))

		print(f"ranked = {ranked}")
		print(" ****** the best one ***** ")
		print(ranked[0])
		# append the best one to the end of the results
		results.append(ranked[0])

		return results

	def _collapse_round_criteria(self, criteria, criterias, direction, weights, n_iterations):
		"""Reduce one round's raw criterion values to the scalars used for ranking.

		Args:
			criteria: criterion name; its prefix selects the reduction.
			criterias: raw values returned by ``try_params`` for the round, in order.
			direction: ``'Max'`` or ``'Min'``.
			weights: ``(wgh1, wgh2)`` for ``'wgh'`` criteria; otherwise unused.
			n_iterations: iterations of the round (only used in a log message).

		Returns:
			list aligned with ``criterias``; for ordinary criteria the input list
			itself is returned unchanged.
		"""
		if criteria.startswith('wgh'):
			wgh1, wgh2 = weights
			# normalize both components within the round, then combine with weights
			normed_lcrs = min_max_scaling([cta[0] for cta in criterias])
			normed_rvals = min_max_scaling([cta[1] for cta in criterias])
			return [wgh1 * lcr + wgh2 * rval for lcr, rval in zip(normed_lcrs, normed_rvals)]

		if criteria.startswith('dyn_win'):
			sigs = np.array([cta[1] for cta in criterias])
			if np.mean(sigs) <= np.median(sigs):	# most sigs are large
				if direction == 'Max':	# accuracy: mu + sig
					return [cta[0] + cta[1] for cta in criterias]
				return [cta[0] - cta[1] for cta in criterias]	# loss: mu - sig
			print(f"Epoch = {n_iterations} -- Most sigs are SMALL !!!!")
			return [cta[0] for cta in criterias]	# mu

		return criterias

	def _vector_to_configuration(self, vector):
		"""Round categorical entries of ``vector`` in place and wrap it as a ``CS.Configuration``."""
		for idx in range(len(vector)):
			hp = self.configspace.get_hyperparameter(
				self.configspace.get_hyperparameter_by_idx(idx)
			)
			if isinstance(hp, CS.hyperparameters.CategoricalHyperparameter):
				# categorical entries are rounded to the nearest integer index
				vector[idx] = int(np.rint(vector[idx]))
		return CS.Configuration(self.configspace, vector=vector)

	def get_sample(self, s, n, j):
		"""Return the next configuration to evaluate in bracket ``s``.

		Without a fitted model (``kde_good is None``):

		* if survivors of the previous round remain in ``self.samples``, one is
		  picked at random, removed, and returned (as-is for 1-D samples,
		  otherwise converted back to a ``CS.Configuration``);
		* otherwise ``fixed_config_dict['s_<s>'][j]`` is returned; on the first
		  call for the bracket, ``n`` configurations are generated with
		  ``get_params``, cached, and the first one is returned.

		With a fitted model, ``n_samples`` candidates are drawn around points of
		the good KDE and the one minimising ``g(x) / l(x)`` is returned. If no
		candidate is usable, a fresh configuration is drawn instead.

		Args:
			s: bracket index.
			n: number of configurations to pre-generate for the bracket cache.
			j: index of the configuration within the current round.
		"""
		if self.kde_good is None:
			if len(self.samples):
				idx = np.random.randint(0, len(self.samples))
				sample = self.samples[idx]
				self.samples = np.delete(self.samples, idx, axis=0)
				if self.samples.ndim == 1:	# integer, str, etc.
					return sample
				return self._vector_to_configuration(sample)	# CS
			key = f's_{s}'
			if key in self.fixed_config_dict:
				return self.fixed_config_dict[key][j]
			T = [self.get_params() for _ in range(n)]
			self.fixed_config_dict[key] = T
			return T[0]

		# Sample from the good data
		l_pdf = self.kde_good.pdf
		g_pdf = self.kde_bad.pdf

		def minimize_me(x):
			# density ratio g/l, both floored to avoid division by zero
			return max(1e-32, g_pdf(x)) / max(l_pdf(x), 1e-32)

		best = np.inf
		best_vector = None
		for _ in range(self.n_samples):
			idx = np.random.randint(0, len(self.kde_good.data))
			datum = self.kde_good.data[idx]
			vector = []
			for m, bw, t in zip(datum, self.kde_good.bw, self.vartypes):
				bw = max(bw, self.min_bandwidth)
				if t == 0:
					# continuous: normal around m (scaled bandwidth) truncated to [0, 1]
					bw = self.bw_factor * bw
					try:
						vector.append(sps.truncnorm.rvs(-m/bw, (1-m)/bw, loc=m, scale=bw))
					except Exception:
						warnings.warn("Truncated Normal failed for:\ndatum=%s\nbandwidth=%s\nfor entry with value %s"%(datum, self.kde_good.bw, m))
						warnings.warn("data in the KDE:\n%s"%self.kde_good.data)
						# keep vector aligned with datum so the KDE pdfs can still evaluate it
						vector.append(float(np.clip(m, 0.0, 1.0)))
				elif np.random.rand() < (1-bw):
					# discrete: keep the datum's level with probability (1 - bw)
					vector.append(int(m))
				else:
					# ... otherwise pick a uniformly random level
					vector.append(np.random.randint(t))
			val = minimize_me(vector)

			if not np.isfinite(val):
				warnings.warn('sampled vector: %s has EI value %s'%(vector, val))
				warnings.warn("data in the KDEs:\n%s\n%s"%(self.kde_good.data, self.kde_bad.data))
				warnings.warn("bandwidth of the KDEs:\n%s\n%s"%(self.kde_good.bw, self.kde_bad.bw))
				warnings.warn("l(x) = %s"%(l_pdf(vector)))
				warnings.warn("g(x) = %s"%(g_pdf(vector)))

				# right now, this happens because a KDE does not contain all values for a categorical parameter
				# this cannot be fixed with the statsmodels KDE, so for now, we are just going to evaluate this one
				# if the good_kde has a finite value, i.e. there is no config with that value in the bad kde, so it shouldn't be terrible.
				if np.isfinite(l_pdf(vector)):
					best_vector = vector
					break

			if val < best:
				best = val
				best_vector = vector

		if best_vector is None:
			warnings.warn("Sampling based optimization with %i samples failed -> using random configuration"%self.n_samples)
			if len(self.kde_vartypes) == 1:
				print("sample again")
				return self.get_params()
			return self.configspace.sample_configuration()

		if np.array(best_vector).ndim == 1 and np.array(best_vector).size == 1:
			# single hyperparameter: return the raw value
			print(f"@@@@@ new sample = {best_vector}")
			return best_vector[0]

		print(f"best_vector: {best_vector}, {best}, {l_pdf(best_vector)}, {g_pdf(best_vector)}")
		sample = self._vector_to_configuration(best_vector)
		print(f"@@@@@ new sample = {sample}")
		return sample

	def impute_conditional_data(self, array):
		"""Return a copy of ``array`` with NaN entries filled in.

		For each row, a NaN column is filled from a randomly chosen row that has a
		finite value in that column (copying that row's values for all of the
		current row's NaN columns). If no row has a finite value there, a random
		value valid for the column's type is drawn: uniform in [0, 1) for
		continuous columns, otherwise a random level index.

		Args:
			array: 1-D or 2-D numeric array of samples (rows = samples).
		"""
		return_array = np.empty_like(array)

		for i in range(array.shape[0]):
			datum = np.copy(array[i])
			nan_indices = np.argwhere(np.isnan(datum)).flatten()

			# Loop while any NaN remains. ``np.any(nan_indices)`` would be False
			# when the only NaN is in column 0 and leave it un-imputed.
			while nan_indices.size > 0:
				nan_idx = nan_indices[0]
				valid_indices = np.argwhere(np.isfinite(array[:, nan_idx])).flatten()

				if len(valid_indices) > 0:
					# pick one of them at random and overwrite all NaN values
					row_idx = np.random.choice(valid_indices)
					datum[nan_indices] = array[row_idx, nan_indices]
				else:
					# no point in the data has this value activated, so fill it with a valid but random value
					t = self.vartypes[nan_idx]
					if t == 0:
						datum[nan_idx] = np.random.rand()
					else:
						datum[nan_idx] = np.random.randint(t)

				nan_indices = np.argwhere(np.isnan(datum)).flatten()

			# works for both scalar rows (1-D input) and vector rows (2-D input)
			return_array[i] = datum
		return return_array

	def get_fixed_config_dict(self, config_space):
		"""Return the cached bracket configurations as plain dicts.

		Each cached configuration is converted with ``get_dictionary()``; keys are
		``'s_<bracket>'`` for brackets ``s_max`` down to ``skip_first``.
		``config_space`` is accepted for interface compatibility but not used.

		Raises:
			ValueError: if nothing has been cached yet.
			KeyError: if a bracket in that range has no cached configurations.
		"""
		if not self.fixed_config_dict:
			raise ValueError("config_dict is empty.")
		return {
			f's_{s}': [config.get_dictionary() for config in self.fixed_config_dict[f's_{s}']]
			for s in reversed(range(self.skip_first, self.s_max + 1))
		}

	def load_fixed_config_dict(self, file_path, config_space):
		"""Load cached bracket configurations from a JSON file.

		The file is expected in the layout produced by ``get_fixed_config_dict``:
		keys ``'s_<bracket>'`` mapping to lists of value dicts, each rebuilt as a
		``CS.Configuration`` of ``config_space``. ``fixed_config_dict`` is only
		replaced once every bracket has been loaded.
		"""
		with open(file_path, "r", encoding="utf-8") as json_file:
			loaded_configuration_dict = json.load(json_file)
		self.fixed_config_dict = {
			f's_{s}': [CS.Configuration(config_space, values=config) for config in loaded_configuration_dict[f's_{s}']]
			for s in reversed(range(self.skip_first, self.s_max + 1))
		}

	def get_fixed_config_dict_lcbench(self):
		"""Return the cached configurations unchanged (LCBench variant).

		Raises:
			ValueError: if nothing has been cached yet.
		"""
		if not self.fixed_config_dict:
			raise ValueError("config_dict is empty.")

		return self.fixed_config_dict

	def load_fixed_config_dict_lcbench(self, file_path):
		"""Load the cached configurations verbatim from a JSON file (LCBench variant).

		Raises:
			ValueError: if the file decodes to JSON ``null``.
		"""
		with open(file_path, "r", encoding="utf-8") as json_file:
			loaded = json.load(json_file)

		if loaded is None:
			raise ValueError("Error in loading configuration dictionary.")
		self.fixed_config_dict = loaded

	def record_to_csv(self, results, record_file='./record.csv'):
		"""Write ``results`` (e.g. the list from ``run_fixed_configs``) to a CSV file without an index column."""
		df = pd.DataFrame(results)
		df.to_csv(record_file, index=False)

