import os
import sys

import jax.numpy as jnp

sys.path.append("../../")
from apebench import apebench  # noqa: E402

# Export one reference sample trajectory per listed scenario (all in 2D) as a
# .npy file named after the scenario.

# Destination directory for the exported rollouts. NOTE: this is resolved
# relative to the current working directory, not to this script's location.
OUTPUT_DIR = "ref_sample_rollouts"

# Keys into `apebench.scenarios.scenario_dict`; each is instantiated with
# `num_spatial_dims=2`.
SCENARIOS = [
    "phy_aniso_diff",
    "diff_burgers",  # Uses the two-channel version in 2D
    "diff_ks",
    "phy_kolm_flow",
    "phy_gs_type",
]

# Create the output directory up front so `jnp.save` does not fail with
# FileNotFoundError after the (potentially expensive) reference data for the
# first scenario has already been computed.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for scenario in SCENARIOS:
    scene = apebench.scenarios.scenario_dict[scenario](
        num_spatial_dims=2,
    )

    ref_trj = scene.get_ref_sample_data()

    jnp.save(
        os.path.join(OUTPUT_DIR, f"{scene.get_scenario_name()}.npy"),
        ref_trj,
    )
