import time
import jax
import jax.numpy as jnp
from jax import random
from jax import lax, vmap, jit
from functools import partial
import numpy as np
from scipy.stats import wasserstein_distance

def _resample_one_group(param, prob, key):
    """Draw ``param.shape[0]`` rows of ``param`` using row probabilities ``prob``.

    ``jax.random.choice`` samples with replacement by default, so this is a
    multinomial resampling of one group's particles.
    """
    n_rows = param.shape[0]
    return random.choice(key, param, shape=(n_rows,), p=prob)


# Batched over a leading group axis: group_resample(params, probs, keys).
group_resample = vmap(_resample_one_group)

# Adapted from https://www.jeremiecoullon.com/2020/11/10/mcmcjax3ways/
# Only ``drift_fn`` is static: ``dt`` is a traced scalar when this kernel is
# called from inside a ``lax.fori_loop`` body, and tracers cannot be static
# (hashable) jit arguments. A static float ``dt`` would also recompile per value.
@partial(jit, static_argnums=(3,))
def transition_kernel_sde(key, param, t, drift_fn, dt):
    """Take one Euler-Maruyama step for a single particle.

    Computes ``param + dt * drift_fn(param, t) - sqrt(2 * dt) * xi`` with
    ``xi ~ N(0, I)`` of the same shape as ``param``.

    Args:
        key: PRNG key; it is split, and the second half drives ``xi``.
        param: current particle.
        t: time passed to ``drift_fn``.
        drift_fn: hashable callable ``drift_fn(param, t)`` (static under jit).
        dt: step size (may be a traced value).

    Returns:
        ``(next_key, new_param)`` where ``next_key`` is the first half of the split.
    """
    key, subkey = random.split(key)
    drift = drift_fn(param, t)
    noise = random.normal(key=subkey, shape=param.shape)
    param = param + dt * drift - jnp.sqrt(2 * dt) * noise
    return key, param

# Only ``drift_fn`` is static: ``dt`` may be a traced scalar (e.g. read from an
# array inside ``lax.fori_loop``), and tracers cannot be static jit arguments.
@partial(jit, static_argnums=(3,))
def transition_kernel_ode(key, param, t, drift_fn, dt):
    """Take one explicit Euler step ``param + dt * drift_fn(param, t)``.

    Args:
        key: PRNG key, returned unchanged (kept for signature parity with the
            SDE kernel so the two are interchangeable).
        param: current particle.
        t: time passed to ``drift_fn``.
        drift_fn: hashable callable ``drift_fn(param, t)`` (static under jit).
        dt: step size (may be a traced value).

    Returns:
        ``(key, new_param)``.
    """
    drift = drift_fn(param, t)
    return key, param + dt * drift

def diffusion_sampler(keys, x0s, dts, cfg, score_fn, *, verbose=False):
    """Run ``diffusion_sampler_jax`` and optionally report wall-clock time.

    Args:
        keys, x0s, dts, cfg, score_fn: forwarded unchanged to
            ``diffusion_sampler_jax``.
        verbose: if True, wait for the result to be computed and print the
            elapsed time (including any compilation on first call).

    Returns:
        The final particles returned by ``diffusion_sampler_jax``.
    """
    start_time = time.monotonic()
    results = diffusion_sampler_jax(keys, x0s, dts, cfg, score_fn)
    if verbose:
        # JAX dispatches asynchronously; block so the timing covers the
        # actual computation rather than just the dispatch.
        jax.block_until_ready(results)
        print("Elapsed time for sampling: {:.1f}s".format(time.monotonic() - start_time))
    return results

@partial(jit, static_argnums=(3, 4))
def diffusion_sampler_jax(keys, x0s, dts, cfg, score_fn):
    """Run ``n_steps`` sampler steps on every particle of every group.

    Args:
        keys: per-particle PRNG keys, shape [n_groups, n_particles, 2].
        x0s: initial particles, shape [n_groups, n_particles, dim].
        dts: step sizes, shape [n_steps]. With ``ts = cumsum(dts)`` the schedule
            is walked backwards: the first step uses ``ts[-1]`` and ``dts[-1]``.
        cfg: static (hashable) config; ``cfg.reverse`` and ``cfg.noise`` are read.
        score_fn: static callable ``score_fn(z, t)`` applied to one particle.

    Modes:
        * ``reverse`` and ``noise``: SDE step with drift ``2 * score + z``.
        * ``reverse`` without noise: ODE step with drift ``score + z``.
        * not ``reverse`` (Langevin): always an SDE step, drift ``score``,
          with the score evaluated at ``t = 0``.

    Returns:
        Final particles, same shape as ``x0s``.
    """
    reverse = cfg.reverse
    # Langevin mode ignores cfg.noise: it always takes noisy steps.
    noise = cfg.noise if reverse else True
    n_steps = dts.shape[0]
    ts = jnp.cumsum(dts)
    if not reverse:
        # Langevin uses the t = 0 score. This must happen BEFORE the per-step
        # schedule is built; previously the zeroing came after the schedule
        # was stacked and had no effect.
        ts = jnp.zeros_like(ts)
    # Loop iteration i reads the i-th entry counted from the end of the schedule.
    ts_sched = jnp.flip(ts)
    dts_sched = jnp.flip(dts)

    if not reverse:
        drift_fn = lambda z, t: score_fn(z, t)
    elif noise:
        drift_fn = lambda z, t: 2 * score_fn(z, t) + z
    else:  # reverse ODE
        drift_fn = lambda z, t: score_fn(z, t) + z
    kernel = transition_kernel_sde if noise else transition_kernel_ode

    def particle_step(i, carry):
        step_keys, params = carry
        t = ts_sched[i]
        dt = dts_sched[i]
        per_particle = lambda key, param: kernel(key, param, t, drift_fn, dt)
        # Outer vmap over groups, inner vmap over particles within a group.
        return vmap(vmap(per_particle))(step_keys, params)

    _, samples = lax.fori_loop(0, n_steps, particle_step, (keys, x0s))
    return samples

def generate_inverse_problem_gm(key, dim, dim_y, target, scale=1, kappa_range=None, snr_range=None):
    """Draw a random linear-Gaussian inverse problem and register it on ``target``.

    The forward operator is ``A = U diag(s) V / scale``, where ``U`` and ``V``
    come from the SVD of a Gaussian (dim_y x dim) matrix and ``s`` are freshly
    drawn singular values sorted in descending order. A ground truth ``x`` is
    sampled from ``target.prior_dist`` and observed as ``y = A x + eps`` with
    ``eps ~ N(0, Sigma_y)`` and ``Sigma_y`` diagonal.

    Args:
        key: JAX PRNG key.
        dim: dimension of the unknown ``x``.
        dim_y: number of observations.
        target: object exposing ``prior_dist.sample(key, sample_shape=())`` and
            ``set_observation(y, A, Sigma_y)``.
        scale: ``A`` is divided by this.
        kappa_range: optional ``(lo, hi)``. If given, kappa ~ U(lo, hi), the
            singular values are drawn in ``[1/kappa, 1]`` and the extremes are
            pinned to ``1`` and ``1/kappa``. Requires ``dim_y == dim``.
        snr_range: optional ``(lo, hi)``. If given, the noise variance is
            ``min(s)**2 / snr`` with snr ~ U(lo, hi); otherwise it is
            ``u * max(s)**2`` with u ~ U(0.2, 1). ``s`` is taken before
            dividing by ``scale``. Requires ``dim_y == dim``.

    Returns:
        ``(y, A, Sigma_y, x)``.
    """
    A = random.normal(key, (dim_y, dim))
    U, D, V = jnp.linalg.svd(A, full_matrices=True)
    n_sv = D.shape[0]  # min(dim_y, dim)
    key, subkey = random.split(key)
    if kappa_range is not None or snr_range is not None:
        assert dim_y == dim, "Only support non-degegenerate observation for setting kappa or SNR; otherwise, by definition kappa = infty and SNR=0 since lambda_min=0"
    if kappa_range is None:
        # Uniform(0, 1) singular values in descending order.
        diag = -jnp.sort(-random.uniform(subkey, D.shape))
    else:
        kappa = random.uniform(subkey, (1,), minval=kappa_range[0], maxval=kappa_range[1])[0]
        # NOTE(review): `subkey` is consumed a second time here. This is kept
        # as-is so existing seeds reproduce the same problems.
        diag = -jnp.sort(-random.uniform(subkey, D.shape, minval=1 / kappa, maxval=1.0))
        diag = diag.at[0].set(1)
        diag = diag.at[-1].set(1 / kappa)
    # Keep only the leading n_sv singular vectors so the product is
    # (dim_y x dim) both when dim_y <= dim and when dim_y > dim.
    A = U[:, :n_sv] @ (jnp.diag(diag) @ V[:n_sv, :]) / scale
    # A = jnp.concatenate([U, jnp.zeros((dim_y, dim - dim_y))], axis=1) # for orthogonal A
    # A = jnp.concatenate([jnp.eye(dim_y), jnp.zeros((dim_y, dim - dim_y))], axis=1)
    key, subkey = random.split(key)
    init_sample = target.prior_dist.sample(subkey, sample_shape=())
    key, subkey = random.split(key)
    if snr_range is None:
        var_observations = random.uniform(subkey, (1,), minval=0.2, maxval=1.0)[0] * jnp.ones(dim_y) * jnp.max(diag) ** 2
    else:
        snr = random.uniform(subkey, (1,), minval=snr_range[0], maxval=snr_range[1])[0]
        var_observations = jnp.ones(dim_y) * jnp.min(diag) ** 2 / snr
    std = var_observations ** 0.5
    init_obs = A @ init_sample
    key, subkey = random.split(key)
    init_obs = init_obs + random.normal(subkey, init_obs.shape) * std
    Sigma_y = jnp.diag(var_observations)
    target.set_observation(init_obs, A, Sigma_y)
    return init_obs, A, Sigma_y, init_sample

def sliced_wasserstein(key, sample1, sample2, n_slices=100):
    """Monte-Carlo sliced 1-Wasserstein distance between two point clouds.

    ``n_slices`` random unit directions are drawn from ``key``. Both samples
    (rows are points) are projected onto every direction, and SciPy's 1-D
    Wasserstein distance is averaged over the directions.
    """
    dim = sample1.shape[1]
    directions = random.normal(key, shape=(n_slices, dim))
    directions = directions / jnp.linalg.norm(directions, axis=-1, keepdims=True)
    # Move projections to host memory once instead of slicing a device array per row.
    proj1 = np.asarray(directions @ sample1.T)
    proj2 = np.asarray(directions @ sample2.T)
    per_slice = []
    for row1, row2 in zip(proj1, proj2):
        per_slice.append(wasserstein_distance(u_values=row1, v_values=row2))
    return np.mean(per_slice)
