import jax
import jax.numpy as jnp

from typing import NamedTuple


class KahanState(NamedTuple):
    """Immutable state of a compensated (Kahan) running sum.

    Being a ``NamedTuple``, instances are tuples ``(value, carry)`` and can be
    passed through JAX transformations as a pytree.
    """

    # Current running total.
    value: jax.Array
    # Compensation term; it is subtracted from the next addend before that
    # addend is accumulated into ``value``.
    carry: jax.Array


def init_kahan_state(value: jax.Array) -> KahanState:
    """Begin a compensated running sum whose initial total is ``value``.

    The compensation term starts at zero, built with ``jnp.zeros_like`` so it
    shares ``value``'s shape and dtype.
    """
    zero_carry = jnp.zeros_like(value)
    return KahanState(value, zero_carry)


def kahan_add(state: KahanState, value: jax.Array) -> KahanState:
    """Add ``value`` to a compensated (Kahan) running sum.

    Args:
        state: Current running total and its compensation term.
        value: Addend. It is combined with ``state`` using ordinary ``jnp``
            arithmetic, so broadcasting and dtype promotion follow JAX rules.

    Returns:
        A new ``KahanState`` holding the updated total and compensation; the
        input ``state`` is not modified.

    Non-finite values: once the total or an addend is infinite, the
    compensation expression ``(t - total) - y`` becomes ``inf - inf`` (NaN)
    or ``inf``. Left in place, that would turn the *next* total into NaN;
    for example, adding ``inf`` and then ``1.0`` would yield NaN instead of
    ``inf``. Non-finite compensation terms are therefore reset to zero, so
    the total follows plain IEEE addition for infinities and NaNs. A finite
    compensation term is kept unchanged, so ordinary sums give exactly the
    same results as textbook Kahan summation.
    """
    # Subtract the rounding error recorded by the previous step from the addend.
    y = value - state.carry
    t = state.value + y
    # (t - state.value) is the portion of y that actually landed in t; its
    # difference from y captures the low-order bits lost to rounding.
    carry = (t - state.value) - y
    # Keep the compensation finite so inf/overflow doesn't decay into NaN.
    carry = jnp.where(jnp.isfinite(carry), carry, jnp.zeros_like(carry))
    return KahanState(value=t, carry=carry)
