import exp_utils as PQ
import torch
import torch.nn as nn
from torch.nn.functional import relu, softplus
import numpy as np
import pickle
import pytorch_lightning as pl
from loguru import logger

from rl_utils.runner import merge_episode_stats, RunnerX, EpisodeReturn, ExtractLastInfo, RunnerWithModel
from copy import deepcopy
from safe import *
from rl_utils import SACTrainer
from safe.debugger import Debugger

import safe.envs
from safe.envs import make_env
from rl_utils import MLP
import rl_utils


# Place tensors/models on the CUDA device when one is visible, else on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class FLAGS(PQ.BaseFLAGS):
    # Experiment configuration, registered via PQ.init(FLAGS) in main().
    # Nested classes group related options; several entries reuse the FLAGS
    # declared by the corresponding component classes.
    _strict = False  # NOTE(review): meaning defined by PQ.BaseFLAGS (presumably tolerates unknown keys) — confirm

    class model(PQ.BaseFLAGS):
        # main() accepts 'learned', 'GatedTransitionModel' or 'oracle'.
        type = 'learned'
        n_ensemble = 5  # number of dynamics models placed in the EnsembleModel
        n_elites = 0  # not read in this file
        frozen = False  # if True, main() calls ensemble.requires_grad_(False)
        train = TransitionModel.FLAGS

    class ckpt(PQ.BaseFLAGS):
        # Checkpoint paths; an empty string means "do not load".
        L = ''  # barrier state dict (key 'L'); only loaded when h.type == 'learned'
        policy = ''  # policy state dict (key 'policy')
        models = ''  # dynamics-model state dict (key 'models')
        s = ''  # not read in this file
        x_vs_L = ''  # not read in this file

        buf = ''  # replay buffer to reuse instead of collecting fresh data
        safe_invariant = ''  # only referenced by commented-out code in main()
        model_trainers = ''  # not read in this file

    class fix(PQ.BaseFLAGS):
        # Components to freeze via requires_grad_(False) so optimizers leave them alone.
        L = False  # freeze the barrier
        policy = False  # freeze the policy
        normalizer = True  # not read in this file
        model = False  # freeze the dynamics ensemble

    class h(PQ.BaseFLAGS):
        # Barrier choice in main(): 'FakeL2', 'HandCraftBarrierSwing', anything else -> learned Barrier.
        type = 'learned'

    fixed_h = HandCraftBarrierSwing.FLAGS
    env = safe.envs.FLAGS
    SAC = SafeSACTrainer2.FLAGS
    lyapunov = Barrier.FLAGS
    opt_s = SLangevinOptimizer.FLAGS
    opt_L = LOptimizer.FLAGS
    crabs = CRABS.FLAGS
    uncertainty = EnsembleUncertainty.FLAGS

    n_iters = 500000  # outer iterations of the 'pretrain_L' task
    n_plot_iters = 10000  # not read in this file (FLAGS is also handed to Debugger)
    n_eval_iters = 1000  # not read in this file (FLAGS is also handed to Debugger)
    n_save_iters = 10000  # not read in this file (FLAGS is also handed to Debugger)
    n_pretrain_s_iters = 10000  # warm-up steps of the state optimizer before training
    # Selects the branch run by main().
    # NOTE(review): 'train' is not one of the tasks main() handles, so the
    # default ends in the "invalid task" assertion unless overridden.
    task = 'train'
    win_streak = 5  # not read in this file
    streak_threshold = 0.0  # not read in this file
    bf_safe_policy = False  # not read in this file


# Module-wide NumPy generator with a fixed seed so sampled states are reproducible.
rng = np.random.RandomState(seed=1)


def sample_state_space(n):
    """Draw ``n`` random two-column states using the module-level ``rng``.

    Column 0 is uniform in [-pi/2, pi/2) and column 1 is uniform in [-1, 1).
    The columns are drawn in that order and returned as an ``(n, 2)``
    float32 tensor.
    """
    column_shape = (n, 1)
    angle = rng.uniform(-1, 1, size=column_shape) * np.pi / 2
    velocity = rng.uniform(-1, 1, size=column_shape)
    stacked = np.concatenate((angle, velocity), axis=1)
    return torch.tensor(stacked, dtype=torch.float32)


class DetMLPPolicy(MLP, rl_utils.DetNetPolicy):
    """Deterministic policy: an ``MLP`` combined with ``rl_utils.DetNetPolicy``.

    Only referenced from commented-out code in ``main()``.
    """


class MLPQFn(MLP, rl_utils.NetQFn):
    """Q-function: an ``MLP`` combined with ``rl_utils.NetQFn``.

    ``main()`` builds two of these with input width ``dim_state + dim_action``
    and a single output for ``SafeSACTrainer2``.
    """


class TanhGaussianMLPPolicy(rl_utils.policy.TanhGaussianPolicy, MLP, rl_utils.NetPolicy):
    """Stochastic policy: ``rl_utils`` TanhGaussianPolicy on an ``MLP`` backbone.

    ``main()`` builds it with an output width of ``dim_action * 2``.
    """


def bf_optimize(policy, s, U_pi, n_iters=100000, log_every=1000, interactive=False):
    """Brute-force minimise ``relu(U_pi(s)).mean()`` over ``policy``'s parameters.

    Runs ``n_iters`` Adam steps (default Adam hyper-parameters) on
    ``policy.parameters()``. ``s`` is treated as a fixed batch of states; the
    caller's tensor is not modified.

    Args:
        policy: module whose parameters are optimised. ``U_pi`` must depend on
            them, otherwise the loss has no gradient.
        s: batch of states passed to ``U_pi``.
        U_pi: callable mapping ``s`` to per-state values.
        n_iters: number of optimisation steps.
        log_every: print mean/max of the objective every ``log_every`` steps;
            ``0`` disables printing.
        interactive: if True, enter ``breakpoint()`` before optimising and at
            every logging step. The old code always did this, which hangs
            non-interactive runs.

    Returns:
        ``relu(U_pi(s)).mean()`` after the last step, as a Python float.
    """
    if interactive:
        breakpoint()
    opt = torch.optim.Adam(policy.parameters())

    # Detach from the caller's graph; keep requires_grad in case U_pi
    # differentiates with respect to its inputs.
    s = s.detach().requires_grad_()
    for i in range(n_iters):
        # No all-ones mask: it was a no-op, and it broadcast (n, 1) outputs to (n, n).
        obj = relu(U_pi(s))
        if log_every and i % log_every == 0:
            print(obj.mean(), obj.max())
            if interactive:
                breakpoint()

        loss = obj.mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    return relu(U_pi(s)).mean().item()


def bump(t, mod):
    """Return the smallest multiple of ``mod`` strictly greater than ``t``.

    For positive ``mod``: ``bump(7, 5) == 10`` and ``bump(10, 5) == 15``.
    ``main()`` uses it to advance ``global_step`` to a fresh round number.
    """
    quotient, _remainder = divmod(t, mod)
    return mod * (quotient + 1)


def main():
    """Build the experiment from ``FLAGS`` and run the branch named by ``FLAGS.task``.

    Shared setup runs once: logging, environment, normalizer, state box,
    replay buffers, policy, dynamics ensemble, barrier, CRABS objects,
    optimizers and debugger. Then exactly one task branch runs.

    Handled tasks:
      - ``plot_policy_safe_region``
      - ``check_L``
      - ``debug``
      - ``retreat_policy``
      - ``pretrain_model``
      - ``pretrain_L``
      - ``safe-init``
      - ``unified-new-algo``
      - ``train-fixed-h``

    Any other value fails the final assertion.

    Several diagnostic branches (``check_L``, ``debug``, ``pretrain_L``) call
    ``breakpoint()`` and therefore expect an interactive session.
    """
    import logging
    PQ.init(FLAGS)
    logging.getLogger('lightning').setLevel(0)
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(save_dir=PQ.log_dir, name=PQ.log_dir.name)
    wandb_logger.experiment.config.update(FLAGS.to_dict())

    PQ.log.info(f"wandb url = {wandb_logger.experiment.url}")

    env = make_env()
    s0 = torch.tensor(env.reset(), device=device, dtype=torch.float32)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizer = Normalizer(dim_state, clip=1000).to(device)

    # define and update initial state box
    state_box = StateBox([dim_state], device)
    state_box.reset(s0)

    if FLAGS.ckpt.buf:
        # Buffers are written two ways in this script:
        #   - torch.save -> 'buf.pt' (training tasks)
        #   - pickle.dump -> 'buf.pkl' ('pretrain_model' / 'safe-init')
        # torch.load rejects plain pickle files, so fall back to pickle.
        try:
            buf_real = torch.load(FLAGS.ckpt.buf)
        except (RuntimeError, pickle.UnpicklingError):
            with open(FLAGS.ckpt.buf, 'rb') as f:
                buf_real = pickle.load(f)
        logger.warning(f"load model buffer from {FLAGS.ckpt.buf}")
    else:
        buf_real = rl_utils.TorchReplayBuffer(env, max_buf_size=1000_000)

    buf_dev = rl_utils.TorchReplayBuffer(env, max_buf_size=10_000)
    # policy = DetMLPPolicy([dim_state, 64, 64, dim_action], auto_squeeze=False, output_activation=nn.Tanh).to(device)
    # mean_policy = policy
    policy = TanhGaussianMLPPolicy([dim_state, 64, 64, dim_action * 2]).to(device)
    # unsafe_policy = TanhGaussianMLPPolicy([dim_state, 64, 64, dim_action * 2]).to(device)
    mean_policy = rl_utils.policy.MeanPolicy(policy)

    if FLAGS.env.id == 'MySafexp-PointGoal1-v1':
        logger.warning("use DomainModel")
        # Takes (and ignores) the ensemble index: every factory is called as make_model(i) below.
        make_model = lambda i: DomainModel(env, env.hazards_pos, env.vases_pos, env.goal_pos)
    elif FLAGS.model.type == 'GatedTransitionModel':
        make_model = lambda i: \
            GatedTransitionModel(dim_state, normalizer, [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2],
                                 name=f'model-{i}')
    else:
        make_model = lambda i:  \
            TransitionModel(dim_state, normalizer, [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2],
                            name=f'model-{i}')
        # make_model = lambda i: StableDynamics(
        #     TransitionModel(dim_state, normalizer, [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2]),
        #     dim_state, 0.01, buf=buf_real, buf_dev=None, name=f'model_{i}').to(device)
    ensemble = EnsembleModel([make_model(i) for i in range(FLAGS.model.n_ensemble)])
    # model_trainers = [ModelTrainer(model, buf_real, buf_dev, device=device, name=f'model/{i}')
    #                   for i, model in enumerate(ensemble.models)]
    # Request a GPU only when one exists, matching the CPU fallback of `device`.
    # Asking for gpus=1 on a CPU-only machine makes Lightning refuse to start.
    use_gpu = device.type == 'cuda'
    model_trainer = pl.Trainer(
        max_epochs=0, gpus=1 if use_gpu else 0, auto_select_gpus=use_gpu, default_root_dir=PQ.log_dir,
        progress_bar_refresh_rate=0, checkpoint_callback=False, logger=wandb_logger)

    horizon = env.spec.max_episode_steps
    make_stats = [lambda: ExtractLastInfo('episode.unsafe'), lambda: EpisodeReturn()]
    runners = {
        'explore': RunnerX(make_env, 1, make_stats, device=device),
        'evaluate': RunnerX(make_env, 1, make_stats, device=device),
        'test': RunnerX(make_env, 1, make_stats, device=device),
    }

    if FLAGS.h.type == 'FakeL2':
        barrier = FakeL2().to(device)
    elif FLAGS.h.type == 'HandCraftBarrierSwing':
        barrier = HandCraftBarrierSwing().to(device)
    else:
        barrier = Barrier(nn.Sequential(normalizer, MLP([dim_state, 256, 256, 1])), env.barrier_fn, s0).to(device)

    buf_out = Buffer(11_000, [dim_state], device=device)
    if FLAGS.model.type == 'oracle':
        model = OracleModel(make_env)
        U = OracleUncertainty(model, barrier)
    elif FLAGS.model.type in ['learned', 'GatedTransitionModel']:
        model = ensemble
        U = EnsembleUncertainty(ensemble, barrier)
        if FLAGS.model.frozen:
            ensemble.requires_grad_(False)
            logger.warning(f"models are frozen!")
    else:
        assert 0
    crabs = CRABS(barrier, U, mean_policy, env.barrier_fn, normalizer)
    # obj_eval = ObjEvaluator(crabs)

    # set requires_grad to False so that SafeAlgo won't optimize them.
    if FLAGS.fix.L:
        barrier.requires_grad_(False)
    if FLAGS.fix.policy:
        policy.requires_grad_(False)
    if FLAGS.fix.model:
        ensemble.requires_grad_(False)

    # policy optimization
    fake_model_fns = {
        'transition': model,
        'reset': lambda: s0,
        # 'done': lambda s, a, sp: ~env.is_state_safe(sp),
        'done': lambda s, a, sp: torch.zeros(len(s), device=s.device, dtype=torch.bool),
        'reward': lambda s, a, sp: torch.where(env.is_state_safe(sp), env.reward_fn(s, a, sp),
                                               torch.tensor(-10000., device=s.device)),
    }
    model_runner = RunnerWithModel(fake_model_fns, horizon, dim_state, [EpisodeReturn], n=1, device=device)
    # model_eval_runner = ModelRunner(fake_model_fns, horizon, dim_state, batch_size=1, device=device)
    # buf_fake = rl_utils.TorchReplayBuffer(env, max_buf_size=1000_000)

    if FLAGS.ckpt.policy != '':  # must be done before define policy_optimizer (policy target init)
        policy.load_state_dict(torch.load(FLAGS.ckpt.policy, map_location=device)['policy'])
        logger.info(f"Load policy from {FLAGS.ckpt.policy}")
    if FLAGS.ckpt.L != '' and FLAGS.h.type == 'learned':  # must be done before L_target is init
        barrier.load_state_dict(torch.load(FLAGS.ckpt.L, map_location=device)['L'])
        logger.info(f"Load L from {FLAGS.ckpt.L}")
        # assert FLAGS.should_optimize_policy
    if FLAGS.ckpt.models != '':
        model.load_state_dict(torch.load(FLAGS.ckpt.models, map_location=device)['models'])
        logger.info(f"Load model from {FLAGS.ckpt.models}")

    # Frozen snapshot of the current safety certificate; exploration policies are filtered through it.
    crabs_ref = deepcopy(crabs)
    expl_policy = ExplorationPolicy(policy, crabs_ref).to(device)
    expl_model_policy = ExplorationPolicy(rl_utils.policy.AddGaussianNoise(policy, 0.0, 2.0), crabs_ref).to(device)
    expl_rand_policy = ExplorationPolicy(rl_utils.policy.UniformPolicy(dim_action), crabs_ref).to(device)

    # if FLAGS.ckpt.safe_invariant != '':
    #     crabs_ref.load_state_dict(torch.load(FLAGS.ckpt.safe_invariant, map_location=device)['safe_invariant'])
    #     logger.info(f"Load SafeInvariant from {FLAGS.ckpt.safe_invariant}")

    hardD = lambda s: torch.where(barrier(s) <= 0, crabs.U(s), -barrier(s) - 100)
    softD = lambda s: crabs.U(s) - 100 * relu(barrier(s))

    policy_optimizer = SafeSACTrainer2(policy, [
            MLPQFn([dim_state + dim_action, 256, 256, 1]),
            MLPQFn([dim_state + dim_action, 256, 256, 1]),
        ], U,
        sampler=buf_real.sample,
        device=device,
        target_entropy=-dim_action,
    )

    # policy_optimizer = SACTrainer(policy, [
    #         MLPQFn([dim_state + dim_action, 256, 256, 1]),
    #         MLPQFn([dim_state + dim_action, 256, 256, 1]),
    #     ],
    #     sampler=buf_real.sample,
    #     device=device,
    #     target_entropy=-dim_action,
    # )

    state_box.find_box(crabs_ref)

    h_opt = LOptimizer(dim_state, crabs, nn.ModuleList([barrier, policy]).parameters(), crabs_ref.L).to(device)
    s_opt_langevin = SLangevinOptimizer(crabs, state_box).to(device)
    s_opt_sample = SSampleOptimizer(crabs, state_box).to(device)
    s_opt_grad = SGradOptimizer(crabs, state_box).to(device)
    # s_opt_block_grad = SBlockGradOptimizer(crabs, state_box, reset_block=1).to(device)
    s_opt = s_opt_langevin
    # s_opt = s_opt_block_grad

    fns = {'L': barrier, 'U': U, 'hardD': hardD, 'softD': softD, 'logBarrier': lambda x: env.env_barrier_fn(x).log()}
    # assert L(s0) <= 1
    # logger.info(f"L(s0) = {L(s0).item():.6f}")
    logger.debug(f"[normalizer]: mean = {normalizer.mean.cpu().numpy()}, std = {normalizer.std.cpu().numpy()}")

    debugger = Debugger(env, policy, mean_policy, barrier, model, runners['evaluate'], horizon, s0, s_opt, s_opt_grad,
                        s_opt_sample, h_opt, fns, buf_out, crabs, policy_optimizer.qfns, FLAGS)

    if FLAGS.task == 'plot_policy_safe_region':
        plot_policy_safe_region(env.trans_fn, mean_policy, device)

    elif FLAGS.task == 'check_L':
        # Interactive diagnostic: pauses in the debugger every 5000 steps.
        debugger.evaluate(0, policy=True, mean_policy=True, virt_safe=True, s_grad=True, video=True)
        logger.info("pretrain s...")
        for i in range(100_000):
            if i % 5000 == 0 and i > 0:
                breakpoint()
            if i % 1_000 == 0:
                s_opt.evaluate(step=i)
            s_opt.step()

    elif FLAGS.task == 'debug':
        # Interactive diagnostic of the learned dynamics; relies on breakpoint().
        debugger.evaluate(0, video=True, mean_policy=True, plot=True)
        debugger.evaluate(1, video=True, mean_policy=True, plot=True)
        breakpoint()
        h = 1000
        state = s0
        model = ensemble.models[0]
        ensemble.to(device)
        for model in ensemble.models:
            model.test_stability(s0, mean_policy, horizon * 10)
        breakpoint()

        all_states = []
        for i in range(FLAGS.model.n_ensemble):
            # ensemble.elites = [4]
            model_runner.fns['transition'] = lambda s, a, i=i: ensemble.models[i](s, a)
            buf_fake = rl_utils.TorchReplayBuffer(env, max_buf_size=h * 10, device=device)
            model_runner.horizon = h * 10
            model_runner.reset()
            ep_infos_fake = model_runner.run(mean_policy, h * 10, buffer=buf_fake)
            all_states.append(buf_fake.state.to(device))
        breakpoint()

        # index = torch.nonzero(Vs > 1000)[0, 0].item() - 1
        # s = all_states[0][index]
        # a = mean_policy(s)
        # model(s, a)

        ep_infos_real = runners['evaluate'].run(mean_policy, h, buffer=buf_real)
        states = buf_real.state.to(device)
        actions = buf_real.action.to(device)
        next_states = buf_real.next_state.to(device)
        distribution = ensemble.models[0](states, actions, det=False)
        log_prob = distribution.log_prob(next_states)
        mse = ((distribution.mean - next_states) / normalizer.std).pow(2)
        # logger.info(f"max x = {states[:, 0].max()}")
        breakpoint()
        print('???')

    elif FLAGS.task == 'retreat_policy':
        batch_size = 256
        buf_tmp = rl_utils.TorchReplayBuffer(env, max_buf_size=1_000_000, device=device)
        M = 10_000
        all_states = torch.randn(M, dim_state, device=device) * normalizer.std + normalizer.mean

        state = s0

        actions = nn.Parameter(torch.zeros(batch_size, dim_action, device=device), requires_grad=True)
        optim = torch.optim.Adam([actions])
        #
        # for i in range(10000):
        #     action = mean_policy(state)
        #     state = ensemble.models[0](state, action).detach()
        #     if i % 100 == 0:
        #         print(state.norm().item(), (state - s0).norm().item(), state)

        for t in range(10000):
            if t % 100 == 0:
                print(f"# {t}: state norm = {state.norm()}")
            # Re-sample a batch of candidate actions, then optimise them so the
            # first model's prediction stays close to the current state.
            nn.init.uniform_(actions, -1, 1)
            for a in range(1000):
                predictions = ensemble.models[0](state.repeat(batch_size, 1).detach(), actions)
                loss = (predictions - state).pow(2).mean(dim=-1)
                if t % 100 == 0 and a % 100 == 0:
                    print(a, loss.min().item(), loss.mean().item(), loss.max().item())
                optim.zero_grad()
                loss.mean().backward()
                optim.step()

            action = actions[loss.argmin()]
            old_state = state
            state = ensemble.models[0](state, action).detach()
            print((old_state - state).norm())

        # for t in range(100000):
        #     states = all_states[torch.randint(M, size=[batch_size])]
        #     loss = nn.functional.mse_loss(ensemble.models[0](states, mean_policy(states)), s0)
        #     if t % 1000 == 0:
        #         debugger.evaluate(t, virt_safe=True)
        #         state = s0 + torch.randn_like(s0) * 0.1
        #         for i in range(100):
        #             with torch.no_grad():
        #                 all_states[np.random.randint(M)] = state
        #             state = ensemble.models[0](state, mean_policy(state))
        #         print("loss", loss.item())

            # optim.zero_grad()
            # loss.backward()
            # nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
            # optim.step()

    elif FLAGS.task == 'pretrain_model':
        # expl_policy = UniformPolicy(dim_action)
        dev_infos = RunnerX(make_env, 10, device=device).run(
            rl_utils.policy.AddGaussianNoise(policy, 0, 2), horizon * 10, buffer=buf_dev)

        if not FLAGS.ckpt.buf:
            RunnerX(make_env, 10, device=device).run(rl_utils.policy.UniformPolicy(dim_action), horizon * 10, buffer=buf_dev)
            runner = RunnerX(make_env, 1, stats=[EpisodeReturn, lambda: ExtractLastInfo('episode.unsafe')], device=device)

            for noise in np.linspace(0, 1.0, 500):
                print(noise)
                runner.reset()
                runner.run(rl_utils.policy.AddGaussianNoise(policy, 0, noise), 1 * horizon, buf_real)
            runner.reset()
            runner.run(rl_utils.policy.UniformPolicy(dim_action), 200 * horizon, buf_real)

        print('dev', merge_episode_stats(dev_infos))
        normalizer.fit(buf_real.state)

        model_trainer.max_epochs += 20
        model_trainer.fit(ensemble, train_dataloader=buf_real.sampling_data_loader(1_000, 256),
                          val_dataloaders=buf_real.sampling_data_loader(1, 10_000))
        ensemble.to(device)
        for model in ensemble.models:
            model.test_stability(s0, mean_policy, horizon * 10)
        # train_models(model_trainers, n_steps=50000)

        if not FLAGS.ckpt.buf:
            with open(PQ.log_dir / 'buf.pkl', 'wb') as f:
                pickle.dump(buf_real, f)
        torch.save({
            'models': ensemble.state_dict(),
        }, PQ.log_dir / 'final.pt')

    elif FLAGS.task == 'pretrain_L':

        # don't tune state box
        state_box.reset(s0)
        debugger.evaluate(0, video=True, virt_safe=True)
        # NOTE(review): unconditional breakpoint; this task needs an interactive session.
        breakpoint()
        logger.info("pretrain s...")
        for i in range(FLAGS.n_pretrain_s_iters):
            if i % 1_000 == 0:
                # breakpoint()
                s_opt.evaluate(step=i)
            s_opt.step()

        # logger.critical("zeroing the policy!")
        # for p in policy.parameters():
        #     nn.init.zeros_(p)
        # Only the barrier is optimised in this task.
        h_opt.opt_params.param_groups[0]['params'] = list(barrier.parameters())
        h_opt.L_ref = None
        for t in range(FLAGS.n_iters):
            if t % 1_000 == 0 and t > 0:
                logger.info(f"# iter {t}")
            debugger.evaluate(t, s=t % 1_000 == 0, video=t == 0, virt_safe=t % 10_000 == 0, save=t % 50_000 == 0,
                              plot=t % 10_000 == 0, s_grad=t % 50_000 == 0 and t > 0)
            for i in range(FLAGS.opt_s.n_steps):
                d = s_opt.step()['optimal']
                PQ.meters[f'opt_progress/{i}'] += d
            h_opt.step(s_opt.s)

            if h_opt.since_last_update > 2000 and s_opt.since_last_reset > 5000:
                state_box.reset(s0)
                state_box.find_box(crabs)
                s_opt.reinit()

    elif FLAGS.task == 'safe-init':
        # expl_policy = UniformPolicy(dim_action)
        dev_infos = RunnerX(make_env, 20, device=device).run(
            rl_utils.policy.AddGaussianNoise(policy, 0, 3), horizon * 50, buffer=buf_dev)
        # dev_infos = RunnerX(make_env, 20, device=device).run(
        #     rl_utils.policy.UniformPolicy(dim_action), horizon * 50, buffer=buf_dev)

        if FLAGS.env.id == 'SlimPointGoal1-v0':
            print(np.asarray(env.hazards_pos), env.goal_pos)

        if not FLAGS.ckpt.buf:
            RunnerX(make_env, 10, device=device).run(rl_utils.policy.UniformPolicy(dim_action), horizon * 10, buffer=buf_dev)
            runner = RunnerX(make_env, 1, stats=[EpisodeReturn, lambda: ExtractLastInfo('episode.unsafe')], device=device)

            for noise in np.linspace(0.2, 2, 500):
                runner.reset()
                buf_tmp = rl_utils.TorchReplayBuffer(env, max_buf_size=1_000)
                ep_infos = runner.run(rl_utils.policy.AddGaussianNoise(policy, 0, noise), 1 * horizon, buf_tmp)
                print(noise, ep_infos)
                # Keep only episodes that stayed safe.
                if not ep_infos[0]['episode.unsafe']:
                    buf_real.add_transitions({
                        'state': buf_tmp.state,
                        'action': buf_tmp.action,
                        'next_state': buf_tmp.next_state,
                        'reward': buf_tmp.reward,
                        'done': buf_tmp.done,
                        'timeout': buf_tmp.timeout,
                    })
            # ep_infos = runner.run(rl_utils.policy.UniformPolicy(dim_action), 100 * horizon, buf_real)
            # print('rand', ep_infos)
            with open(PQ.log_dir / 'buf.pkl', 'wb') as f:
                pickle.dump(buf_real, f)

        print('dev', merge_episode_stats(dev_infos))
        normalizer.fit(buf_real.state)

        # train_models(model_trainers, n_steps=50000)
        model_trainer.max_epochs += 20
        model_trainer.fit(model, train_dataloader=buf_real.sampling_data_loader(1_000, 256),
                          val_dataloaders=buf_dev.sampling_data_loader(1, 50_000))
        # pytorch lightning transferred the model to cpu; move it back before
        # test_stability (s0 lives on `device`), as the other tasks do.
        ensemble.to(device)
        for model in ensemble.models:
            model.test_stability(s0, mean_policy, horizon * 10)

        torch.save({
            'models': ensemble.state_dict(),
        }, PQ.log_dir / 'final.pt')

    elif FLAGS.task == 'unified-new-algo':
        logger.info("pretrain s...")
        for i in range(FLAGS.n_pretrain_s_iters):
            if i % 1000 == 0:
                s_opt.evaluate(step=i)
            s_opt.step()

        debugger.evaluate(0, video=True, plot=True)

        if FLAGS.ckpt.buf == '':
            # collect 10_000 samples
            buf_tmp = eval_and_explore(mean_policy, expl_policy, runners, 10_000, buf_real, crabs_ref)
            nlls = ensemble.get_nlls(buf_tmp['state'], buf_tmp['action'], buf_tmp['next_state'])
            PQ.log.debug(f'init expl as val: {nlls}')
            torch.save(buf_tmp, PQ.log_dir / 'init-buf.pt')

        global_step = 0

        model_trainer.max_epochs += 5
        model_trainer.fit(model, train_dataloader=buf_real.sampling_data_loader(1_000, 256))
        model.to(device)  # pytorch lightning transferred the model to cpu

        # warming up Q
        policy_optimizer.can_update_policy = False
        for _ in range(10_001):
            policy_optimizer.step()
        policy_optimizer.can_update_policy = True

        # h_opt runs about once every `freq` policy steps (see the y_step loop),
        # so a larger freq means fewer safety updates.
        freq = 0.5
        y_step = 0

        for epoch in range(100):
            eval_and_explore(mean_policy, expl_model_policy, runners, horizon * 2, buf_real, crabs_ref)
            eval_and_explore(mean_policy, expl_rand_policy, runners, horizon * 2, buf_real, crabs_ref)

            model_trainer.max_epochs += 1
            model_trainer.fit(ensemble, train_dataloader=buf_real.sampling_data_loader(1_000, 256))
            ensemble.to(device)  # pytorch lightning transferred the model to cpu
            # see pytorch_lightning/trainer/training_loop.py:220

            for _ in range(2_000):
                s_opt.step()

            # train policy
            logger.info(f"Epoch {epoch}: train policy, safety req freq = {freq:.3f}")
            policy_optimizer.can_update_policy = True
            global_step = bump(global_step, 1000_000_000)
            # debugger.evaluate(global_step, video=True, plot=True)
            h_opt.opt_params.param_groups[0]['params'] = list(policy.parameters())   # + list(L.parameters())
            debugger.evaluate(epoch * 2000, mean_policy=True)
            for t in range(2_000):
                global_step = bump(global_step, 1)
                if t % 1000 == 0:
                    buf_tmp = eval_and_explore(mean_policy, expl_policy, runners, horizon * 1, buf_real, crabs_ref)

                    debugger.last_expl = buf_tmp['state'].cpu().detach().numpy()
                    torch.save(buf_tmp, PQ.log_dir / f'buf-epoch_{epoch}-t_{t}.pt')
                    torch.save(buf_real, PQ.log_dir / f'buf.pt')
                if t % 1000 == 0:
                    debugger.evaluate(global_step, s=True, video=True, expl=True, virt_safe=True, plot=True)

                if len(buf_real) > 1000:
                    policy_optimizer.step()   # optimize unsafe policy

                for i in range(FLAGS.opt_s.n_steps):  # opt s
                    s_opt.step()
                # if t % 5 == 0:
                y_step += 1
                while y_step >= freq:
                    h_opt.step(s_opt.s)
                    y_step -= freq
            global_step = bump(global_step, 1)
            debugger.evaluate(global_step, video=True, plot=True)

            # check if virt_safe
            global_step = bump(global_step, 1_000_000)
            debugger.evaluate(global_step, virt_safe=True)

            # train L
            logger.info(f"Epoch {epoch}: train L!")
            h_opt.opt_params.param_groups[0]['params'] = list(barrier.parameters())  # + list(policy.parameters())
            global_step = bump(global_step, 1_000_000)
            total_updates = 0

            for _ in range(2000):
                s_opt.step()

            h_status = 'training'
            h_opt.since_last_update = 0
            for t in range(20_000):
                global_step = bump(global_step, 1)
                if t % 1_000 == 0:
                    logger.info(f"# iter {t}")
                debugger.evaluate(global_step, s=t % 1_000 == 0 and t != 0,
                                  virt_safe=t % 10_000 == 0, plot=t % 10_000 == 0, s_grad=t % 20_000 == 0 and t > 0)
                for i in range(FLAGS.opt_s.n_steps):
                    s_opt.step()
                result = h_opt.step(s_opt.s)
                if result['mask'].sum() > 0.0:
                    total_updates += 1
                    h_status = 'training'

                if h_status == 'training' and h_opt.since_last_update >= 1000:
                    logger.info("resetting SGLD, entering observation period")
                    state_box.find_box(crabs_ref)
                    s_opt.reinit()
                    h_status = 'observation-period'

                # if t == 1_000 and total_updates / t > 0.99:
                #     barrier.load_state_dict(crabs_ref.barrier.state_dict())
                #     logger.warning("early stop... unlikely to find a barrier")
                #     break

                if h_status == 'observation-period' and h_opt.since_last_update == 5_000:
                    crabs_ref.load_state_dict(crabs.state_dict())
                    logger.info(f"win streak at {t} => find a new invariant => update ref. ")
                    if t == 4_999:  # policy is too conservative, reduce safe constraint
                        freq = np.clip(2 * freq, 0.1, 10)
                        logger.info(f"reduce frequency to {freq:.3f}")
                    else:
                        freq = np.clip(freq * 1.2, 0.1, 10)  # to encourage use less policy safety objective
                    break
            else:
                # No new invariant within the budget: roll the barrier back to the reference.
                barrier.load_state_dict(crabs_ref.barrier.state_dict())
                freq = np.clip(freq / 2, 0.1, 10)
                logger.warning(f"can't find barrier, increase freq to {freq}")
            global_step = bump(global_step, 1)
            # NOTE(review): logs at step `epoch`, not `global_step`; confirm this is intended.
            debugger.evaluate(epoch, s=True, virt_safe=True, save=True, plot=True, s_grad=True)
    elif FLAGS.task == 'train-fixed-h':
        logger.info("pretrain s...")
        for i in range(FLAGS.n_pretrain_s_iters):
            if i % 1000 == 0:
                s_opt.evaluate(step=i)
            s_opt.step()

        debugger.evaluate(0, video=True, plot=True)

        if FLAGS.ckpt.buf == '':
            # collect 10_000 samples
            buf_tmp = eval_and_explore(mean_policy, expl_policy, runners, 10_000, buf_real, crabs_ref)
            nlls = ensemble.get_nlls(buf_tmp['state'], buf_tmp['action'], buf_tmp['next_state'])
            PQ.log.debug(f'init expl as val: {nlls}')
            torch.save(buf_tmp, PQ.log_dir / 'init-buf.pt')

        global_step = 0

        model_trainer.max_epochs += 5
        model_trainer.fit(model, train_dataloader=buf_real.sampling_data_loader(1_000, 256))
        model.to(device)  # pytorch lightning transferred the model to cpu

        # warming up Q
        policy_optimizer.can_update_policy = False
        for _ in range(10_001):
            policy_optimizer.step()
        policy_optimizer.can_update_policy = True

        freq = 0.5
        y_step = 0

        for epoch in range(100):
            eval_and_explore(mean_policy, expl_model_policy, runners, horizon * 10, buf_real, crabs_ref)
            eval_and_explore(mean_policy, expl_rand_policy, runners, horizon * 10, buf_real, crabs_ref)

            model_trainer.max_epochs += 1
            model_trainer.fit(ensemble, train_dataloader=buf_real.sampling_data_loader(1_000, 256))
            ensemble.to(device)  # pytorch lightning transferred the model to cpu
            # see pytorch_lightning/trainer/training_loop.py:220

            for _ in range(2_000):
                s_opt.step()

            # train policy
            logger.info(f"Epoch {epoch}: train policy, safety req freq = {freq:.3f}")
            policy_optimizer.can_update_policy = True
            global_step = bump(global_step, 1000_000_000)
            # debugger.evaluate(global_step, video=True, plot=True)
            h_opt.opt_params.param_groups[0]['params'] = list(policy.parameters())   # + list(L.parameters())
            debugger.evaluate(epoch * 2000, mean_policy=True)
            for t in range(2_000):
                global_step = bump(global_step, 1)
                if t % 1000 == 0:
                    buf_tmp = eval_and_explore(mean_policy, expl_policy, runners, horizon * 5, buf_real, crabs_ref)

                    debugger.last_expl = buf_tmp['state'].cpu().detach().numpy()
                    torch.save(buf_tmp, PQ.log_dir / f'buf-epoch_{epoch}-t_{t}.pt')
                    torch.save(buf_real, PQ.log_dir / f'buf.pt')
                if t % 1000 == 0:
                    debugger.evaluate(global_step, s=True, video=True, expl=True, virt_safe=True, plot=True)

                if len(buf_real) > 1000:
                    policy_optimizer.step()   # optimize unsafe policy

                for i in range(FLAGS.opt_s.n_steps):  # opt s
                    s_opt.step()
                # if t % 5 == 0:
                y_step += 1
                while y_step >= freq:
                    h_opt.step(s_opt.s)
                    y_step -= freq
            global_step = bump(global_step, 1)
            debugger.evaluate(global_step, video=True, plot=True)

            # check if virt_safe
            global_step = bump(global_step, 1_000_000)
            debugger.evaluate(global_step, virt_safe=True)

            global_step = bump(global_step, 1)
            # NOTE(review): logs at step `epoch`, not `global_step`; confirm this is intended.
            debugger.evaluate(epoch, s=True, virt_safe=True, save=True, plot=True, s_grad=True)
    else:
        assert 0, f"invalid task `{FLAGS.task}`"


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
