from util import *
from absl import app, flags
from wandb import wandb

FLAGS = flags.FLAGS

# The regression target is f*(x) = g(h(x)); --g picks g and --h picks h.
flags.DEFINE_enum(
    "g",
    "linear",
    ["linear", "sigmoid", "cubic", "relu"],
    "Outer link function g of the target f*(x) = g(h(x)).",
)
flags.DEFINE_enum(
    "h",
    "single_index",
    ["single_index", "rand_quad", "rand_proj_norm"],
    "Inner feature h of the target: a random unit-vector projection, a random "
    "traceless quadratic form, or a quadratic form built from a random "
    "(d // 2)-dimensional projection.",
)
flags.DEFINE_enum(
    "sigma2",
    "relu",
    ["relu", "linear", "square"],
    "Kernel used to compute the first-step feature.",
)
# --d and --n have no default; main() rejects a run where either is missing.
flags.DEFINE_integer("d", None, "Input dimension (required).")
flags.DEFINE_integer(
    "n",
    None,
    "Number of samples for the first step (required); it also caps the size "
    "of the regression set and the batch size.",
)
flags.DEFINE_integer("m1", 1000, "Number of random features fed to the ridge fit.")
flags.DEFINE_integer("seed", 0, "Random seed.")
flags.DEFINE_integer(
    "batch_size", 1024, "Number of samples drawn per batch in the first step."
)

# Outer link functions, looked up by the value of the --g flag.
g_dict = dict(
    linear=lambda z: z,
    sigmoid=lambda z: nn.sigmoid(z),
    cubic=lambda z: z**3,
    relu=lambda z: nn.relu(z),
)

# ReLU Kernel
def ReLU_K_fn(x, y):
    """Evaluate the ReLU kernel between two vectors.

    With ``theta`` the angle between ``x`` and ``y`` this returns
    ``|x| * |y| * (sin(theta) + (pi - theta) * cos(theta)) / pi``.

    Both arguments are used as 1-D vectors (``x @ x`` is treated as a squared
    norm). A zero vector makes the cosine a 0/0 division.
    """
    scale = jnp.sqrt((x @ x) * (y @ y))
    # Round-off can push the cosine slightly outside [-1, 1], where arccos
    # would return NaN, so it is clipped first.
    cosine = jnp.clip(x @ y / scale, -1, 1)
    angle = jnp.arccos(cosine)
    bracket = jnp.sin(angle) + (jnp.pi - angle) * jnp.cos(angle)
    return scale * bracket / jnp.pi

# Linear Kernel
def linear_K_fn(x, y):
    """Return the inner product of ``x`` and ``y`` divided by ``len(x)``."""
    inner = x @ y
    dim = len(x)
    return inner / dim

# Square Kernel
def square_K_fn(x, y):
    """Return ``(|x|^2 * |y|^2 + 2 * (x . y)^2) / len(x)`` for two vectors."""
    sq_norm_product = (x @ x) * (y @ y)
    sq_inner = (x @ y) ** 2
    return (sq_norm_product + 2 * sq_inner) / len(x)

# Kernel functions, looked up by the value of the --sigma2 flag.
kernel_dict = {
    "relu": ReLU_K_fn,
    "linear": linear_K_fn,
    "square": square_K_fn,
}


def ridge(x, y, reg=0, rel_reg=None):
    """Solve ridge regression through the SVD of ``x``.

    For a penalty ``reg`` the result is ``V @ diag(s / (s**2 + reg)) @ U.T @ y``
    where ``x = U @ diag(s) @ V.T``, i.e. the minimiser of
    ``|x @ w - y|^2 + reg * |w|^2``.

    Args:
        x: 2-D design matrix.
        y: Targets. The per-singular-value gains are multiplied elementwise
            into ``U.T @ y``, so this presumably expects a 1-D vector with one
            entry per row of ``x`` -- TODO confirm for other shapes.
        reg: Penalty, either a scalar or an array of penalties. Ignored when
            ``rel_reg`` is given.
        rel_reg: Optional penalty expressed relative to the largest singular
            value; the penalty actually used is ``rel_reg * s[0]``.

    Returns:
        The solution for a scalar penalty; for an array of penalties, one
        solution per entry along the leading axis (computed with ``vmap``).
    """
    n_rows, n_cols = x.shape
    cutoff = jnp.finfo(x.dtype).eps * max(n_cols, n_rows)
    left, sing, right_t = jla.svd(x, full_matrices=False)
    projected = jnp.matmul(left.conj().T, y, precision=lax.Precision.HIGHEST)
    top = sing[0]

    def solve(penalty):
        # Keep a direction only if its gain s / (s^2 + penalty) is at most
        # 1 / cutoff times the gain of the leading direction. With
        # penalty == 0 this is the familiar s >= cutoff * s[0] truncation.
        keep = top * (sing * sing + penalty) >= cutoff * sing * (top * top + penalty)
        # Dropped entries get a dummy value of 1 so the division below stays
        # finite before they are zeroed out.
        guarded = jnp.where(keep, sing, 1)
        gains = jnp.where(keep, guarded / (guarded**2 + penalty), 0)
        return jnp.matmul(
            right_t.conj().T, gains * projected, precision=lax.Precision.HIGHEST
        )

    penalties = reg if rel_reg is None else rel_reg * top
    if jnp.ndim(penalties) == 0:
        return solve(penalties)
    return vmap(solve)(penalties)


def main(args):
    """Run one experiment configured by FLAGS and log the result to wandb.

    Steps:
      1. Build the target ``f*(x) = g(h(x))`` from --g / --h and standardise
         it with a mean and standard deviation estimated on 32768 sphere
         samples.
      2. Define the first-step feature ``h_fn(z)``: the kernel chosen by
         --sigma2 between ``z`` and freshly sampled inputs, contracted with
         their targets and reduced over ``n // batch_size`` batches by
         ``laxmean``.
      3. Fit ridge regression on ``m1`` random ReLU features of that scalar
         feature, selecting the feature scale and the weight decay on a
         validation set.
      4. Evaluate on the fixed test set and log loss, chosen hyper-parameters,
         the baseline ``mean(test_y**2)`` and an (h_star, h1) table.

    Args:
        args: Positional command-line arguments forwarded by ``app.run``;
            unused, all configuration comes from FLAGS.

    Raises:
        app.UsageError: If --d or --n is missing or not positive.
        AssertionError: If the final test loss is NaN.
    """
    del args  # Unused.

    d, n, m1, batch_size = FLAGS.d, FLAGS.n, FLAGS.m1, FLAGS.batch_size
    # Validate before wandb.init so a misconfigured launch does not leave an
    # empty run behind.
    if d is None or n is None:
        raise app.UsageError("--d and --n must both be provided.")
    if d <= 0 or n <= 0:
        raise app.UsageError("--d and --n must be positive integers.")

    config = FLAGS.flag_values_dict()
    wandb.init(project="3_layer_neurips_plots", config=config)
    rng = RNG(FLAGS.seed)
    K_fn = kernel_dict[FLAGS.sigma2]

    @vmap
    def to_sphere(x):
        # Rescale each row to Euclidean norm sqrt(d).
        return jnp.sqrt(d) * x / jla.norm(x)

    # Never draw more samples per batch than the total budget n.
    batch_size = min(n, batch_size)

    g_star = g_dict[FLAGS.g]
    if FLAGS.h == "single_index":
        beta = rng.normal((d,))
        beta = beta / jla.norm(beta)
        h_star = lambda x: x @ beta
    elif FLAGS.h == "rand_quad":
        A = rng.normal((d, d))
        # Remove the trace, then normalise to unit Frobenius norm.
        A = A - jnp.eye(d) * jnp.trace(A) / d
        A = A / jla.norm(A, "fro")
        h_star = lambda x: x @ A @ x
    elif FLAGS.h == "rand_proj_norm":
        U = rng.orthogonal(d)[: d // 2, :]
        A = 2 * U.T @ U - jnp.eye(d)
        A = A - jnp.eye(d) * jnp.trace(A) / d
        A = A / jla.norm(A, "fro")
        h_star = lambda x: x @ A @ x

    sigma1 = nn.relu

    def normalize(f):
        # Standardise f using statistics estimated on 32768 sphere samples.
        x = to_sphere(rng.normal((32768, d)))
        y = vmap(f)(x)
        return lambda x: (f(x) - y.mean()) / y.std()

    fstar = normalize(lambda x: g_star(h_star(x)))

    # Fix Test Set
    test_x = to_sphere(rng.normal((32768, d)))
    test_y = vmap(fstar)(test_x)

    # First Step
    n_batch = n // batch_size
    data_key = rng.next()

    @jit
    def h_fn(z):
        def inner(key):
            x = to_sphere(random.normal(key, (batch_size, d)))
            y = vmap(fstar)(x)
            return vmap(lambda x: K_fn(z, x))(x) @ y

        # The same data_key is split on every call, so every z is evaluated
        # against the same batches. laxmean comes from util; presumably it
        # averages `inner` over the keys -- TODO confirm.
        return laxmean(inner, random.split(data_key, n_batch), backend="lax")

    # Resample
    x = to_sphere(rng.normal((min(n, 32768), d)))
    y = vmap(fstar)(x)

    a = rng.rademacher((m1,), dtype=float)
    b = rng.normal((m1,))
    h1 = vmap(h_fn)(x)

    def features(h, eta1):
        # Random ReLU features of the scalar h; units with b <= 0 are masked.
        return sigma1(eta1 * a * h + b) * (b > 0)

    # Validation inputs are projected onto the sphere like the training, test
    # and normalisation inputs, so model selection uses the same input
    # distribution as the final evaluation.
    val_x = to_sphere(rng.normal((32768, d)))
    val_y = vmap(fstar)(val_x)
    h_val = vmap(h_fn)(val_x)

    muls = jnp.geomspace(0.01, 2, 10)
    wds = jnp.geomspace(1e-10, 1e10, 1000)
    h1_std = h1.std()
    loss_list = []
    for mul in muls:
        eta1 = mul / h1_std
        phi = vmap(lambda h: features(h, eta1))(h1)

        # One ridge solution per weight decay, stacked along the leading axis.
        a_wds = ridge(phi, y, rel_reg=wds)

        outputs = vmap(lambda h: a_wds @ features(h, eta1))(h_val)
        losses = jnp.mean((outputs - val_y[:, None]) ** 2, 0)
        loss = losses.min()
        wd_idx = losses.argmin()
        loss_list.append((loss, mul, eta1, wds[wd_idx], a_wds[wd_idx]))
    _, mul, eta1, wd, a_final = min(loss_list, key=lambda entry: entry[0])

    # Test Loss
    h_test = vmap(h_fn)(test_x)
    outputs = vmap(lambda h: features(h, eta1) @ a_final)(h_test)

    # One bulk conversion instead of 2 * 32768 per-element .item() calls.
    table = wandb.Table(
        data=jnp.stack([vmap(h_star)(test_x), h_test], axis=1).tolist(),
        columns=["h_star", "h1"],
    )
    loss = jnp.mean((outputs - test_y) ** 2)
    print(loss)
    assert not jnp.isnan(loss)
    out_dict = {
        "loss": loss.item(),
        "wd": wd.item(),
        "mul": mul.item(),
        "baseline": jnp.mean(test_y**2).item(),
        "features": table,
    }
    wandb.log(out_dict)


# Script entry point: hand main to absl's app.run.
if __name__ == "__main__":
    app.run(main)