import numpy as np
from jax import jit, random

# Fixed-seed JAX PRNG key defined once at module level; the test below draws
# its input array from this key.
key = random.PRNGKey(76777677)


def norm(X):
    """Center ``X`` along axis 0 and scale it by its axis-0 standard deviation.

    Args:
        X: An array-like object exposing ``.mean(0)`` and ``.std(0)`` and
            supporting element-wise subtraction and division.

    Returns:
        ``(X - X.mean(0))`` divided by the standard deviation (along axis 0)
        of that centered array. No guard is applied for a zero standard
        deviation.
    """
    centered = X - X.mean(0)
    # The spread is measured on the already-centered values, as in a
    # conventional standardization step.
    spread = centered.std(0)
    return centered / spread


# `norm` wrapped with `jax.jit`; the test below compares its output against
# the result of calling `norm` directly.
norm_jit = jit(norm)


def test_jax_sanity_check(allclose):
    """Check that ``norm`` and its ``jit``-wrapped form agree on random input.

    Args:
        allclose: Comparison callable injected by the test runner
            (presumably a pytest fixture -- confirm where it is defined).
    """
    # Seeds NumPy's global RNG. NOTE(review): the input below is drawn from
    # the JAX key, not from NumPy, so this seed does not influence ``X``.
    np.random.seed(7677)
    X = random.uniform(key, shape=(10000, 10))

    plain_result = norm(X)
    jitted_result = norm_jit(X)

    assert allclose(plain_result, jitted_result, atol=1e-6)
