import numpy as np
from scipy.special import ndtri

# Float64 machine epsilon. The samplers below use it to squeeze uniforms from
# [0, 1) into the open interval (0, 1) before applying `ndtri`.
MACHINE_EPS = np.finfo(np.float64).eps

# Default LFSR parameters, keyed by the register length m (number of bits).
# Each value is a tuple (g, a):
#   g -- integer offset (stride) between consecutive outputs, used by
#        LFSR.sample as `g * i`;
#   a -- positions inside the length-m window whose bits are summed mod 2 to
#        produce the next bit of the sequence in LFSR.__init__.
# NOTE(review): presumably copied from a published table of LFSR generators;
# the source of these constants is not visible here -- confirm before editing.
lfsr_list = {
    10: (115, [0, 3]),
    11: (291, [0, 2]),
    12: (172, [0, 1, 4, 6]),
    13: (267, [0, 1, 3, 4]),
    14: (332, [0, 1, 3, 5]),
    15: (388, [0, 1]),
    16: (283, [0, 2, 3, 5]),
    17: (514, [0, 3]),
    18: (698, [0, 7]),
    19: (706, [0, 1, 2, 5]),
    20: (1304, [0, 3]),
    21: (920, [0, 2]),
    22: (1336, [0, 1]),
    23: (1236, [0, 5]),
    24: (1511, [0, 1, 3, 4]),
    25: (1445, [0, 3]),
    26: (1906, [0, 1, 2, 6]),
    27: (1875, [0, 1, 2, 5]),
    28: (2573, [0, 3]),
    29: (2633, [0, 2]),
    30: (2423, [0, 1, 4, 6]),
    31: (3573, [0, 3]),
    32: (3632, [0, 2, 6, 7])
}

def random_digit_shift(U, rng):
    """Apply a random binary digital shift to every column of ``U``.

    With ``d = U.shape[1]``, a ``d x d`` matrix of fair coin flips is drawn
    from ``rng``. For column ``c``, the leading ``d`` binary digits of each
    entry are added mod 2 to row ``c`` of that matrix (the same flips are
    used for every row of ``U``); whatever is left after those ``d`` digits
    is carried over unchanged.

    Args:
        U: 2-D array; digits are extracted assuming entries lie in [0, 1).
        rng: object exposing ``binomial(n, p, size)``, e.g. a
            ``numpy.random.Generator``.

    Returns:
        Array with the shape and dtype of ``U`` holding the shifted values.
    """
    n_cols = U.shape[1]
    flips = rng.binomial(1, 0.5, [n_cols, n_cols])
    shifted = np.zeros_like(U)
    for col, col_flips in enumerate(flips):
        remainder = U[:, col]
        for depth, flip in enumerate(col_flips, start=1):
            # Peel off the current leading binary digit (0 or 1) ...
            digit = (2 * remainder) // 1
            # ... and rescale what is left back into [0, 1).
            remainder = (remainder - digit / 2) * 2
            flipped = (digit + flip) % 2
            shifted[:, col] = shifted[:, col] + flipped / 2 ** depth
        # Digits beyond the first n_cols are kept as they were.
        shifted[:, col] = shifted[:, col] + remainder / 2 ** n_cols
    return shifted

class LFSR:
    """Linear-feedback shift register producing uniforms in [0, 1).

    The full bit sequence of length ``2**m - 1`` is precomputed in the
    constructor; ``sample`` then reads overlapping ``m``-bit windows from it.

    Attributes:
        m: number of bits per output.
        a: window positions whose bits are XORed to form the next bit.
        g: offset (stride) between the windows of consecutive outputs.
        period: ``2**m - 1``, the length of the stored bit sequence.
        b: boolean array of length ``period`` holding the bit sequence.
        cur: index of the next output to be produced by ``sample``.
    """

    def __init__(self, m, a=None, g=None, b0=None) -> None:
        """
        m: number of bits
        a: list of nonzero polynomial coefficients, e.g. [0, 3]
        g: int, offset
        if a, g are not given, use lfsr_list
        b0: (m, ) in {0, 1} initial state, i.e. b_0; default is (1, 0, ..., 0)

        Raises KeyError if a default is needed and m is not in lfsr_list.
        """
        if a is None or g is None:
            try:
                default_g, default_a = lfsr_list[m]
            except KeyError:
                raise KeyError(
                    f"no built-in LFSR parameters for m={m}; "
                    "pass `a` and `g` explicitly"
                ) from None
            if a is None:  # polynomial coefficients
                a = default_a
            if g is None:  # offset
                g = default_g

        self.m = m
        # Always stored, including when the caller supplies `a`; the
        # recurrence below reads it.
        self.a = a
        self.g = g
        self.period = 2 ** m - 1
        self.b = np.zeros(self.period, dtype=bool)
        if b0 is None:
            self.b[0] = 1
        else:
            self.b[:len(b0)] = b0

        taps = np.asarray(self.a, dtype=np.intp)
        for i in range(m, self.period):
            # Next bit = parity (XOR) of the tapped bits in the previous
            # m-bit window.
            self.b[i] = np.logical_xor.reduce(self.b[i - m:i][taps])

        self.cur = 0

    def sample(self, n):
        """Return the next ``n`` outputs as a float array of shape ``(n,)``.

        Output number ``i`` packs the bits ``b[(g*i + k) % period]`` for
        ``k = 0..m-1`` (most significant first) into an integer and divides
        it by ``2**m``. Advances the internal counter by ``n``.
        """
        steps = np.arange(self.cur, self.cur + n, dtype=np.int64)
        offsets = np.arange(self.m, dtype=np.int64)
        # Row r holds the m bit positions read for output number steps[r].
        idx = (self.g * steps[:, None] + offsets) % self.period
        # Bit k of a window carries weight 2**(m - 1 - k).
        weights = np.int64(1) << (self.m - 1 - offsets)
        v = self.b[idx].astype(np.int64) @ weights
        self.cur += n
        return v / 2 ** self.m

    def reset(self):
        """Rewind so the next ``sample`` call starts again from output 0."""
        self.cur = 0

def LMC(instance, stepsize, n, seed, cud=False):
    """Run ``n`` Langevin Monte Carlo steps starting from the origin.

    Each step applies
    ``theta <- theta - lr * U_grad(theta) + sqrt(2 * lr) * eta``.

    Args:
        instance: object exposing ``d`` (dimension) and ``U_grad(theta)``.
        stepsize: scalar step size, or a sequence with at least ``n``
            entries giving the step size of each iteration.
        n: number of iterations.
        seed: seed for ``numpy.random.default_rng``.
        cud: if False, ``eta`` is standard normal noise from the seeded
            generator. If True, ``eta`` is built from an ``LFSR`` stream
            with ``int(log2(n))`` bits: the LFSR output is shifted modulo 1
            by a fixed random vector and pushed through ``ndtri``.

    Returns:
        Array stacking the ``n`` iterates (shape ``(n, d)`` when ``U_grad``
        returns a length-``d`` vector).

    Raises:
        AssertionError: if an iterate contains Inf or NaN.
    """
    d = instance.d
    if np.isscalar(stepsize):
        stepsize = np.ones(n) * stepsize
    theta_t = np.zeros(d)
    rng = np.random.default_rng(seed=seed)
    traj = []
    if cud:
        lfsr = LFSR(int(np.log2(n)))
    # Drawn unconditionally: moving this under `if cud` would change the
    # random stream the non-CUD path sees for a given seed.
    shift = rng.random(d)
    for i in range(n):
        g = instance.U_grad(theta_t)
        if cud:
            # NOTE(review): a single LFSR uniform is broadcast against the
            # d-dimensional shift, so all coordinates share it -- confirm
            # this is intended.
            u = lfsr.sample(1)
            u = (u + shift) % 1
            # Map [0, 1) into (0, 1) so ndtri never returns +/-inf.
            eta = ndtri(u * (1 - MACHINE_EPS) + .5 * MACHINE_EPS)
        else:
            eta = rng.standard_normal(d)
        lr = stepsize[i]
        update = -lr * g + np.sqrt(2 * lr) * eta
        theta_t = theta_t + update
        traj.append(theta_t)
        # np.isinf instead of comparing with np.Inf: that alias was removed
        # in NumPy 2.0.
        assert not np.any(np.isinf(theta_t)), "Inf detected"
        assert not np.any(np.isnan(theta_t)), "NaN detected"
    traj = np.array(traj)
    return traj

def SGLD(instance, stepsize, n, seed, M=1, cud=False):
    """Run ``n`` stochastic-gradient Langevin steps from a random start.

    The initial point is standard normal. Each step draws ``M`` distinct
    indices from ``range(instance.N)`` and applies
    ``theta <- theta - lr * U_grad(theta, idx) + sqrt(2 * lr) * eta``.

    Args:
        instance: object exposing ``d`` (dimension), ``N`` (passed to
            ``rng.choice`` as the population) and ``U_grad(theta, idx)``.
        stepsize: scalar step size, or a sequence with at least ``n``
            entries giving the step size of each iteration.
        n: number of iterations.
        seed: seed for ``numpy.random.default_rng``.
        M: number of indices drawn without replacement per step.
        cud: if False, ``eta`` is standard normal noise from the seeded
            generator. If True, ``eta`` is built from an ``LFSR`` stream
            with ``int(log2(n))`` bits: the LFSR output is shifted modulo 1
            by a fixed random vector and pushed through ``ndtri``.

    Returns:
        Array stacking the ``n`` iterates (shape ``(n, d)`` when ``U_grad``
        returns a length-``d`` vector).

    Raises:
        AssertionError: if an iterate contains Inf or NaN.
    """
    d = instance.d
    if np.isscalar(stepsize):
        stepsize = np.ones(n) * stepsize
    rng = np.random.default_rng(seed=seed)
    theta_t = rng.standard_normal(d)
    traj = []
    if cud:
        lfsr = LFSR(int(np.log2(n)))
    # Drawn unconditionally: moving this under `if cud` would change the
    # random stream the non-CUD path sees for a given seed.
    shift = rng.random(d)
    for i in range(n):
        idx = rng.choice(instance.N, M, replace=False)
        g = instance.U_grad(theta_t, idx)
        if cud:
            # NOTE(review): a single LFSR uniform is broadcast against the
            # d-dimensional shift, so all coordinates share it -- confirm
            # this is intended.
            u = lfsr.sample(1)
            u = (u + shift) % 1
            # Map [0, 1) into (0, 1) so ndtri never returns +/-inf.
            eta = ndtri(u * (1 - MACHINE_EPS) + .5 * MACHINE_EPS)
        else:
            eta = rng.standard_normal(d)
        lr = stepsize[i]
        update = -lr * g + np.sqrt(2 * lr) * eta
        theta_t = theta_t + update
        traj.append(theta_t)
        # np.isinf instead of comparing with np.Inf: that alias was removed
        # in NumPy 2.0.
        assert not np.any(np.isinf(theta_t)), "Inf detected"
        assert not np.any(np.isnan(theta_t)), "NaN detected"
    traj = np.array(traj)
    return traj

def stepsize_schedule(maxit, start=0.01, end=0.0001, gamma=0.33):
    """Return a polynomially decaying step-size schedule of length ``maxit``.

    The schedule has the form ``a * (b + t) ** (-gamma)`` for
    ``t = 0, ..., maxit - 1``, with ``a`` and ``b`` chosen so that the first
    value equals ``start`` and the last equals ``end``.

    Args:
        maxit: number of step sizes to generate.
        start: step size at ``t = 0``.
        end: step size at ``t = maxit - 1``.
        gamma: decay exponent.

    Returns:
        1-D float array of length ``maxit``. When ``maxit <= 1`` or
        ``start == end`` the schedule is constant at ``start``.
    """
    # Degenerate cases: the closed form below divides by zero when
    # start == end, and yields 0 * inf = nan when maxit == 1.
    if maxit <= 1 or start == end:
        return np.full(max(maxit, 0), float(start))
    b = (maxit - 1) / ((start / end) ** (1 / gamma) - 1)
    a = start * b ** gamma
    return a * (b + np.arange(maxit)) ** (-gamma)
