def load_weights(weights, filename):
    """Restore saved weights from ``filename`` into the structure of ``weights``.

    The file is expected to hold a pickled ``bytes`` object as written by
    ``save_weights``; those bytes are decoded with
    ``flax.serialization.from_bytes`` using ``weights`` as the target template.

    Args:
        weights: Template object whose structure the stored state is restored into.
        filename: Path of the file to read.

    Returns:
        The object produced by ``flax.serialization.from_bytes``.
    """
    import flax
    import pickle

    # NOTE(review): pickle.load can execute arbitrary code; only load files from
    # trusted sources.
    with open(filename, "rb") as f:  # ensure the handle is closed
        encoded_bytes = pickle.load(f)
    trained_weights = flax.serialization.from_bytes(target=weights, encoded_bytes=encoded_bytes)
    return trained_weights

def save_weights(weights, filename):
    """Serialize ``weights`` and write them to ``filename``.

    The object is converted to ``bytes`` with ``flax.serialization.to_bytes`` and
    that bytes object is then pickled to the file, which is the format
    ``load_weights`` reads back.

    Args:
        weights: Object accepted by ``flax.serialization.to_bytes``.
        filename: Destination path; an existing file is overwritten.
    """
    import flax
    import pickle

    bytes_output = flax.serialization.to_bytes(target=weights)
    # The context manager guarantees the data is flushed and the file closed.
    with open(filename, "wb") as f:
        pickle.dump(bytes_output, f)


class DemosTemplate:
    """
    Takes a template for a single demo and renders a sequence of demos from it.

    The template may contain the placeholders:
    [INPUT], [OUTPUT]
    Each (input, output) pair is rendered separately and the rendered demos are
    joined with ``delimiter``.
    """

    def __init__(self, template, delimiter='\n\n'):
        self.template = template
        self.delimiter = delimiter

    def fill(self, data):
        """
        Fills in the template once per demo and joins the results.

        Args:
            data: A pair ``(inputs, outputs)`` of equal-length lists of strings.
                Pairs are formed with ``zip``, so extra items in the longer list
                are ignored.

        Returns:
            The rendered demos separated by ``self.delimiter`` (no trailing
            delimiter); an empty string when there are no pairs.
        """
        # Note: [INPUT] is substituted first, so an input value that itself
        # contains the text "[OUTPUT]" will have it replaced as well.
        demos = (
            self.template.replace('[INPUT]', input_).replace('[OUTPUT]', output_)
            for input_, output_ in zip(*data)
        )
        return self.delimiter.join(demos)


class EvalTemplate:
    """
    Wraps an evaluation prompt template and fills in its placeholders.

    Recognized placeholders:
    [PROMPT] receives the prompt text.
    [full_DEMO] receives the already-rendered demos.
    [INPUT] receives the input of the example being evaluated.
    [OUTPUT] receives the output of the example being evaluated.
    """

    def __init__(self, template):
        self.template = template

    def fill(self, prompt='', full_demo='', input='', output=''):
        """
        Returns the template with every placeholder replaced by its value.

        Substitution runs in the fixed order PROMPT, full_DEMO, INPUT, OUTPUT, so
        a placeholder appearing inside an earlier value is also substituted by a
        later step.
        """
        ordered_pairs = (
            ('[PROMPT]', prompt),
            ('[full_DEMO]', full_demo),
            ('[INPUT]', input),
            ('[OUTPUT]', output),
        )
        rendered = self.template
        for placeholder, value in ordered_pairs:
            rendered = rendered.replace(placeholder, value)
        return rendered


def pig_latin_translator(original_sentence, end1='yay', end2='ay'):
    """Translate a whitespace-separated sentence into Pig Latin.

    Per token:
      * A single trailing '.', ',', '!' or '?' is detached and re-attached after
        translation. A token consisting only of such a character is passed
        through unchanged.
      * Words starting with a vowel get ``end1`` appended; if the word already
        ends with the first character of ``end1``, that character is not
        repeated (e.g. "any" -> "anyay").
      * Other words move everything before the first vowel (the whole word if
        it has no vowel) to the end and append ``end2``.
      * If the token's first character was uppercase, the result is passed
        through ``str.capitalize`` (first letter upper, the rest lower).

    Tokens are re-joined with single spaces.

    Args:
        original_sentence: Text to translate.
        end1: Suffix for vowel-initial words.
        end2: Suffix for consonant-initial words.

    Returns:
        The translated sentence.
    """
    vowels = set("aeiouAEIOU")
    translated_words = []
    for word in original_sentence.split():
        capitalized = word[0].isupper()
        punctuation = ""
        if word[-1] in ".,!?":
            punctuation = word[-1]
            word = word[:-1]
        if not word:
            # Token was a lone punctuation mark; nothing left to translate.
            translated_words.append(punctuation)
            continue
        if word[0] in vowels:
            # Avoid doubling the suffix's first letter (e.g. "any" -> "anyay").
            if end1 and word[-1] == end1[0]:
                translated_word = word + end1[1:]
            else:
                translated_word = word + end1
        else:
            # Index of the first vowel, or the whole word if it has none.
            start = next((i for i, letter in enumerate(word) if letter in vowels), len(word))
            translated_word = word[start:] + word[:start] + end2
        if capitalized:
            translated_word = translated_word.capitalize()
        translated_words.append(translated_word + punctuation)
    return " ".join(translated_words)


# Plotting parameters
def set_up_plotting():
    """Configure global matplotlib settings for large, bold, serif figures.

    Sets the font family to serif with "Times New Roman" as the first choice,
    and enlarges titles, labels, tick labels, legend text, markers and lines.
    The changes are made to the global ``rcParams``, so they apply to every
    figure created afterwards in this process. Safe to call more than once.

    Returns:
        The ``matplotlib.pyplot`` module, for convenience.
    """
    import matplotlib.pyplot as plt

    LABEL_FONTSIZE = 24
    MARKER_SIZE = 10
    AXIS_FONTSIZE = 26
    TITLE_FONTSIZE = 26
    LINEWIDTH = 6

    plt.rcParams["font.family"] = "serif"
    # Move Times New Roman to the front of the serif fallback list. Removing any
    # existing entry first keeps repeated calls from growing the list with duplicates.
    other_serifs = [name for name in plt.rcParams["font.serif"] if name != "Times New Roman"]
    plt.rcParams["font.serif"] = ["Times New Roman"] + other_serifs

    plt.rc('figure', titlesize=TITLE_FONTSIZE)   # font size of the figure title (suptitle)
    plt.rc('axes', titlesize=TITLE_FONTSIZE)     # font size of the axes title
    plt.rc('axes', labelsize=AXIS_FONTSIZE)      # font size of the x and y labels
    plt.rc('xtick', labelsize=LABEL_FONTSIZE)    # font size of the x tick labels
    plt.rc('ytick', labelsize=LABEL_FONTSIZE)    # font size of the y tick labels
    plt.rc('legend', fontsize=LABEL_FONTSIZE)    # legend font size
    plt.rc('lines', markersize=MARKER_SIZE)      # default marker size
    plt.rc('lines', linewidth=LINEWIDTH)         # default line width
    plt.rc('font', weight='bold')                # bold fonts everywhere

    return plt


from math import ceil
import math
import numpy as np


def get_U_hat(d, norm_ord=1):
    """Build the row-normalized (d-1) x d matrix used to lift sphere points to R^d.

    Before normalization, row ``i`` (0-based) is
    ``[1] * (i + 1) + [-(i + 1)] + [0] * (d - i - 2)``, so every row sums to zero.
    Each row is then divided by its ``norm_ord`` vector norm.

    Args:
        d: Number of columns (the permutation dimension).
        norm_ord: ``ord`` passed to ``numpy.linalg.norm`` for the per-row
            normalization (default 1: sum of absolute values).

    Returns:
        A float ndarray of shape ``(d - 1, d)``.
    """
    rows = np.arange(d - 1)
    # Ones on and below the main diagonal: U[i, j] = 1 for j <= i.
    U = np.tril(np.ones((d - 1, d)))
    # Superdiagonal entry balances the row so it sums to zero: U[i, i+1] = -(i+1).
    U[rows, rows + 1] = -(rows + 1)
    return U / np.linalg.norm(U, axis=1, ord=norm_ord)[:, None]


def PolarToCartesian(r, psis):
    """Convert hyperspherical coordinates (radius plus angles) to Cartesian ones.

    With ``n = len(psis)`` the result has ``n + 1`` entries. Entry ``k < n`` is
    ``r * sin(psis[0]) * ... * sin(psis[k-1]) * cos(psis[k])``; the final entry
    is the same sine product with no cosine factor.
    """
    n_angles = len(psis)
    coords = np.empty(n_angles + 1)
    sin_prefix = r  # r times the product of the sines of all angles handled so far
    for k in range(n_angles + 1):
        if k < n_angles:
            coords[k] = sin_prefix * np.cos(psis[k])
            sin_prefix = sin_prefix * np.sin(psis[k])
        else:
            coords[k] = sin_prefix
    return coords


from scipy.integrate import quad
from scipy.special import beta
from scipy.stats.qmc import Sobol
from scipy.optimize import root_scalar


def f_last(psi):
    return 0.5 * np.pi

def f_mid(psi, j, d):
    assert 1 <= j < d-2
    return 1./beta(0.5*(d-j-1), 0.5) * np.power(np.sin(psi), d-j-2)

def cdf_F(psi, j, d):
    assert j <= d-2, "j must be <= d-2"

    if j == d - 2:
        return quad(f_last, 0, psi)
    else:
        return quad(f_mid, 0, psi, args=(j,d))


def SobolPermutations(num_samples, dimension, seed=3244, verbose=True):
    '''
    num_samples: the number of permutations to sample
    dimension: the number of players, i.e., the dimension of the permutation

    '''

    sampled_permutations = []
    U_hat = get_U_hat(dimension)

    sampler = Sobol(d=dimension-2, scramble=True, seed=seed)
    sobol_sequence = sampler.random_base2(m=ceil(np.log2(num_samples)))
    for sobol_point in sobol_sequence:
        psis = np.zeros(dimension-2)
        for j in range(dimension-2):

            target = sobol_point[j]        
            sol = root_scalar(lambda x, *args:cdf_F(x, *args)[0] - target, args=(j+1, dimension), bracket=(0, np.pi))
            psis[j] = sol.root

        y = PolarToCartesian(1, psis)        
        # print(f'shape of y is {y.shape}, shape of U_hat is {U_hat.shape}')
        z = U_hat.T @ y
        # print(f'shape of z is {z.shape}')
        sampled_permutations.append(np.argsort(z))

    if verbose and num_samples != len(sampled_permutations):
        print(f'requested num_samples is {num_samples}, number of sampled permutations is {len(sampled_permutations)}, returning the first {num_samples} sampled permutations.')
        print('It is advised to sample a number that is an exact power of 2, of permutations to enjoy the theoretical properties of Sobol sequence.')

    return sampled_permutations[:num_samples]