import warnings
from typing import Sequence, Union, Dict
from shutil import copyfile
import inspect
from collections import OrderedDict
import multiprocessing
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, ReduceLROnPlateau
from einops import rearrange
import torchmetrics
from pytorch_lightning import Trainer, seed_everything, loggers as pl_loggers
from pytorch_lightning.profilers import PyTorchProfiler
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, DeviceStatsMonitor, Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from omegaconf import OmegaConf
import os
import argparse

from prediff.datasets.nbody.nbody_mnist_torch_wrap import NBodyMovingMNISTLightningDataModule
from prediff.datasets.nbody.nbody_mnist import default_datasets_dir
from prediff.datasets.nbody.visualization import vis_nbody_energy
from prediff.utils.checkpoint import pl_load
from prediff.utils.optim import SequentialLR, warmup_lambda
from prediff.utils.layout import layout_to_in_out_slice
from prediff.taming.vae import AutoencoderKL
from prediff.diffusion.guidance.guidance_pl import GuidancePL
from prediff.diffusion.guidance.nbody.energy_predictor import NbodyGuidanceEnergy


pytorch_state_dict_name = "kc_nbody.pt"
# Experiment outputs and pretrained weights both live two directory levels above this file.
exps_dir, pretrained_dir = (
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", sub_dir))
    for sub_dir in ("experiments", "pretrained")
)

class NbodyGuidancePLModule(GuidancePL):
    r"""
    PyTorch Lightning module that trains the energy guidance model for N-body MNIST.

    The generic training / diffusion machinery lives in `GuidancePL`; this subclass supplies
    the default configs, N-body data handling (`get_input`), per-timestep evaluation metrics,
    logging / checkpoint setup and energy-curve visualization.
    """

    def __init__(self,
                 total_num_steps: int,
                 oc_file: str = None,
                 save_dir: str = None):
        r"""
        Parameters
        ----------
        total_num_steps:    int
            Total number of training steps; sizes the lr schedule and the logging interval.
        oc_file:    str
            Optional path to a YAML config that is merged on top of `get_base_config()`
            and copied to ``<save_dir>/cfg.yaml``.
        save_dir:   str
            Experiment directory, joined onto `exps_dir`. Must be provided.
        """
        self.total_num_steps = total_num_steps
        # Pass the path itself so OmegaConf opens *and closes* the file.
        oc_from_file = OmegaConf.load(oc_file) if oc_file is not None else None
        oc = self.get_base_config(oc_from_file=oc_from_file)
        self.save_hyperparameters(oc)
        self.oc = oc

        guide_obj_cfg = OmegaConf.to_object(oc.model.guide_obj)
        self.guide_obj = NbodyGuidanceEnergy(
            guide_type=guide_obj_cfg["guide_type"],
            out_len=guide_obj_cfg["out_len"],
            model_type=guide_obj_cfg["model_type"],
            model_args=guide_obj_cfg["model_args"],)

        vae_cfg = OmegaConf.to_object(oc.model.vae)
        first_stage_model = AutoencoderKL(
            down_block_types=vae_cfg["down_block_types"],
            in_channels=vae_cfg["in_channels"],
            block_out_channels=vae_cfg["block_out_channels"],
            act_fn=vae_cfg["act_fn"],
            latent_channels=vae_cfg["latent_channels"],
            up_block_types=vae_cfg["up_block_types"],
            norm_num_groups=vae_cfg["norm_num_groups"],
            layers_per_block=vae_cfg["layers_per_block"],
            out_channels=vae_cfg["out_channels"], )
        # `.get` keeps configs without a `pretrained_ckpt_path` entry working (previously a KeyError).
        pretrained_ckpt_path = vae_cfg.get("pretrained_ckpt_path", None)
        if pretrained_ckpt_path is not None:
            pretrained_ckpt_path = os.path.join(pretrained_dir, pretrained_ckpt_path)
        if pretrained_ckpt_path is not None and os.path.exists(pretrained_ckpt_path):
            state_dict = torch.load(pretrained_ckpt_path, map_location=torch.device("cpu"))
            first_stage_model.load_state_dict(state_dict=state_dict)
        else:
            warnings.warn("Pretrained weights for `AutoencoderKL` not set. Run for sanity check only.")

        diffusion_cfg = OmegaConf.to_object(oc.model.diffusion)
        super(NbodyGuidancePLModule, self).__init__(
            torch_nn_module=self.guide_obj.model,
            target_fn=self.guide_obj.model_objective,
            layout=oc.layout.layout,
            timesteps=diffusion_cfg["timesteps"],
            beta_schedule=diffusion_cfg["beta_schedule"],
            loss_type=self.oc.optim.loss_type,
            monitor=self.oc.optim.monitor,
            linear_start=diffusion_cfg["linear_start"],
            linear_end=diffusion_cfg["linear_end"],
            cosine_s=diffusion_cfg["cosine_s"],
            given_betas=diffusion_cfg["given_betas"],
            # latent diffusion
            first_stage_model=first_stage_model,
            cond_stage_model=diffusion_cfg["cond_stage_model"],
            num_timesteps_cond=diffusion_cfg["num_timesteps_cond"],
            cond_stage_trainable=diffusion_cfg["cond_stage_trainable"],
            cond_stage_forward=diffusion_cfg["cond_stage_forward"],
            scale_by_std=diffusion_cfg["scale_by_std"],
            scale_factor=diffusion_cfg["scale_factor"],)
        # lr_scheduler
        self.total_num_steps = total_num_steps
        # logging
        self.save_dir = save_dir
        self.logging_prefix = oc.logging.logging_prefix
        # visualization
        self.train_example_data_idx_list = list(oc.vis.train_example_data_idx_list)
        self.val_example_data_idx_list = list(oc.vis.val_example_data_idx_list)
        self.test_example_data_idx_list = list(oc.vis.test_example_data_idx_list)
        self.eval_example_only = oc.vis.eval_example_only

        # One MSE/MAE pair per evaluated diffusion timestep (see `_eval_timesteps`),
        # plus aggregate metrics over all evaluated timesteps.
        self.valid_mse_list = nn.ModuleList()
        self.valid_mae_list = nn.ModuleList()
        self.test_mse_list = nn.ModuleList()
        self.test_mae_list = nn.ModuleList()
        for _ in self._eval_timesteps():
            self.valid_mse_list.append(torchmetrics.MeanSquaredError())
            self.valid_mae_list.append(torchmetrics.MeanAbsoluteError())
            self.test_mse_list.append(torchmetrics.MeanSquaredError())
            self.test_mae_list.append(torchmetrics.MeanAbsoluteError())
        self.valid_mse = torchmetrics.MeanSquaredError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.test_mse = torchmetrics.MeanSquaredError()
        self.test_mae = torchmetrics.MeanAbsoluteError()
        self.configure_save(cfg_file_path=oc_file)

    def _eval_timesteps(self):
        r"""Diffusion timesteps evaluated in val/test: every ``vis.denoise_t_step``-th step, ending below `num_timesteps`."""
        return range(self.oc.vis.denoise_t_step - 1, self.num_timesteps, self.oc.vis.denoise_t_step)

    def configure_save(self, cfg_file_path=None):
        r"""
        Resolve ``self.save_dir`` under `exps_dir`, create it and its ``examples`` sub-directory,
        and copy `cfg_file_path` (if given) to ``<save_dir>/cfg.yaml`` unless it already is that file.

        Raises
        ------
        ValueError
            If ``self.save_dir`` is None.
        """
        if self.save_dir is None:
            raise ValueError("`save_dir` must be specified to configure the experiment directory.")
        self.save_dir = os.path.join(exps_dir, self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)
        if cfg_file_path is not None:
            cfg_file_target_path = os.path.join(self.save_dir, "cfg.yaml")
            if (not os.path.exists(cfg_file_target_path)) or \
                    (not os.path.samefile(cfg_file_path, cfg_file_target_path)):
                copyfile(cfg_file_path, cfg_file_target_path)
        self.example_save_dir = os.path.join(self.save_dir, "examples")
        os.makedirs(self.example_save_dir, exist_ok=True)

    def get_base_config(self, oc_from_file=None):
        r"""Assemble the default config sections and merge `oc_from_file` (if given) on top."""
        oc = OmegaConf.create()
        oc.layout = self.get_layout_config()
        oc.optim = self.get_optim_config()
        oc.logging = self.get_logging_config()
        oc.trainer = self.get_trainer_config()
        oc.vis = self.get_vis_config()
        oc.model = self.get_model_config()
        oc.dataset = self.get_dataset_config()
        if oc_from_file is not None:
            oc = OmegaConf.merge(oc, oc_from_file)
        return oc

    @staticmethod
    def get_layout_config():
        r"""Default sequence layout: 10 input and 10 output frames of 64x64x1 in NTHWC order."""
        cfg = OmegaConf.create()
        cfg.in_len = 10
        cfg.out_len = 10
        cfg.img_height = 64
        cfg.img_width = 64
        cfg.data_channels = 1
        cfg.layout = "NTHWC"
        return cfg

    @classmethod
    def get_model_config(cls):
        r"""Default configs of the diffusion schedule, the guidance model and the VAE first stage."""
        cfg = OmegaConf.create()
        layout_cfg = cls.get_layout_config()

        cfg.diffusion = OmegaConf.create()
        cfg.diffusion.timesteps = 1000
        cfg.diffusion.beta_schedule = "linear"
        cfg.diffusion.linear_start = 1e-4
        cfg.diffusion.linear_end = 2e-2
        cfg.diffusion.cosine_s = 8e-3
        cfg.diffusion.given_betas = None
        # latent diffusion
        cfg.diffusion.cond_stage_model = "__is_first_stage__"
        cfg.diffusion.num_timesteps_cond = None
        cfg.diffusion.cond_stage_trainable = False
        cfg.diffusion.cond_stage_forward = None
        cfg.diffusion.scale_by_std = False
        cfg.diffusion.scale_factor = 1.0

        cfg.guide_obj = OmegaConf.create()
        cfg.guide_obj.guide_type = "sum_energy"
        cfg.guide_obj.out_len = 10
        cfg.guide_obj.model_type = "cuboid"
        cfg.guide_obj.model_args = OmegaConf.create()
        cfg.guide_obj.model_args.input_shape = [10, 16, 16, 4]
        cfg.guide_obj.model_args.out_channels = 2
        cfg.guide_obj.model_args.base_units = 16
        cfg.guide_obj.model_args.block_units = None
        cfg.guide_obj.model_args.scale_alpha = 1.0
        cfg.guide_obj.model_args.depth = [1, 1]
        cfg.guide_obj.model_args.downsample = 2
        cfg.guide_obj.model_args.downsample_type = "patch_merge"
        cfg.guide_obj.model_args.block_attn_patterns = "axial"
        cfg.guide_obj.model_args.num_heads = 4
        cfg.guide_obj.model_args.attn_drop = 0.0
        cfg.guide_obj.model_args.proj_drop = 0.0
        cfg.guide_obj.model_args.ffn_drop = 0.0
        cfg.guide_obj.model_args.ffn_activation = "gelu"
        cfg.guide_obj.model_args.gated_ffn = False
        cfg.guide_obj.model_args.norm_layer = "layer_norm"
        cfg.guide_obj.model_args.use_inter_ffn = True
        cfg.guide_obj.model_args.hierarchical_pos_embed = False
        cfg.guide_obj.model_args.pos_embed_type = 't+h+w'
        cfg.guide_obj.model_args.padding_type = "zero"
        cfg.guide_obj.model_args.checkpoint_level = 0
        cfg.guide_obj.model_args.use_relative_pos = True
        cfg.guide_obj.model_args.self_attn_use_final_proj = True
        # global vectors
        cfg.guide_obj.model_args.num_global_vectors = 0
        cfg.guide_obj.model_args.use_global_vector_ffn = True
        cfg.guide_obj.model_args.use_global_self_attn = False
        cfg.guide_obj.model_args.separate_global_qkv = False
        cfg.guide_obj.model_args.global_dim_ratio = 1
        # initialization
        cfg.guide_obj.model_args.attn_linear_init_mode = "0"
        cfg.guide_obj.model_args.ffn_linear_init_mode = "0"
        cfg.guide_obj.model_args.ffn2_linear_init_mode = "2"
        cfg.guide_obj.model_args.attn_proj_linear_init_mode = "2"
        cfg.guide_obj.model_args.conv_init_mode = "0"
        cfg.guide_obj.model_args.down_linear_init_mode = "0"
        cfg.guide_obj.model_args.global_proj_linear_init_mode = "2"
        cfg.guide_obj.model_args.norm_init_mode = "0"
        # timestep embedding for diffusion
        cfg.guide_obj.model_args.time_embed_channels_mult = 4
        cfg.guide_obj.model_args.time_embed_use_scale_shift_norm = False
        cfg.guide_obj.model_args.time_embed_dropout = 0.0
        # readout
        cfg.guide_obj.model_args.pool = "attention"
        cfg.guide_obj.model_args.readout_seq = True
        cfg.guide_obj.model_args.out_len = 10

        cfg.vae = OmegaConf.create()
        cfg.vae.data_channels = layout_cfg.data_channels
        # from stable-diffusion-v1-5
        cfg.vae.down_block_types = ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']
        cfg.vae.in_channels = cfg.vae.data_channels
        cfg.vae.block_out_channels = [128, 256, 512, 512]
        cfg.vae.act_fn = 'silu'
        cfg.vae.latent_channels = 4
        cfg.vae.up_block_types = ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
        cfg.vae.norm_num_groups = 32
        cfg.vae.layers_per_block = 2
        cfg.vae.out_channels = cfg.vae.data_channels
        # File name relative to `pretrained_dir`; None (or a missing file) keeps the randomly initialized VAE.
        cfg.vae.pretrained_ckpt_path = None
        return cfg

    @classmethod
    def get_dataset_config(cls):
        r"""Default N-body MNIST dataset config; must match the ``nbody_dataset_cfg.yaml`` stored with the data."""
        cfg = OmegaConf.create()
        cfg.dataset_name = "nbody20k_digits3_len20_size64"
        cfg.num_train_samples = 20000
        cfg.num_val_samples = 1000
        cfg.num_test_samples = 1000
        cfg.digit_num = None
        cfg.img_size = 64
        cfg.raw_img_size = 128
        cfg.seq_len = 1
        cfg.raw_seq_len_multiplier = 5
        cfg.distractor_num = None
        cfg.distractor_size = 5
        cfg.max_velocity_scale = 2.0
        cfg.initial_velocity_range = [0.0, 2.0]
        cfg.random_acceleration_range = [0.0, 0.0]
        cfg.scale_variation_range = [1.0, 1.0]
        cfg.rotation_angle_range = [-0, 0]
        cfg.illumination_factor_range = [1.0, 1.0]
        cfg.period = 5
        cfg.global_rotation_prob = 0.5
        cfg.index_range = [0, 40000]
        cfg.mnist_data_path = None
        cfg.aug_mode = "0"
        cfg.ret_contiguous = False
        cfg.energy_norm_scale = 0.1
        # N-body params
        cfg.nbody_acc_mode = "r0"
        cfg.nbody_G = 0.035
        cfg.nbody_softening_distance = 0.01
        cfg.nbody_mass = None
        return cfg

    @staticmethod
    def get_optim_config():
        r"""Default optimization, lr-scheduling, early-stopping and checkpointing config."""
        cfg = OmegaConf.create()
        cfg.seed = None
        cfg.total_batch_size = 32
        cfg.micro_batch_size = 8
        cfg.float32_matmul_precision = "high"

        cfg.method = "adamw"
        cfg.lr = 1.0E-6
        cfg.wd = 1.0E-2
        cfg.betas = (0.9, 0.999)
        cfg.gradient_clip_val = 1.0
        cfg.max_epochs = 50
        cfg.loss_type = "l2"
        # scheduler
        cfg.warmup_percentage = 0.2
        cfg.lr_scheduler_mode = "cosine"  # One of 'none', 'cosine', 'plateau'.
        cfg.min_lr_ratio = 1.0E-3
        cfg.warmup_min_lr_ratio = 0.0
        cfg.plateau_patience = 5  # take effect when `lr_scheduler_mode` is "plateau". Number of epochs with no improvement after which learning rate will be reduced.
        # early stopping
        cfg.monitor = "valid_loss_epoch"
        cfg.early_stop = False
        cfg.early_stop_mode = "min"
        cfg.early_stop_patience = 5
        cfg.save_top_k = 1
        return cfg

    @staticmethod
    def get_logging_config():
        r"""Default logging / monitoring config."""
        cfg = OmegaConf.create()
        cfg.logging_prefix = "Nbody_Energy"
        cfg.monitor_lr = True
        cfg.monitor_device = False
        cfg.track_grad_norm = -1
        cfg.use_wandb = False
        cfg.profiler = None
        return cfg

    @staticmethod
    def get_trainer_config():
        r"""Default `pl.Trainer` related config (keys matching `Trainer` arguments are forwarded directly)."""
        cfg = OmegaConf.create()
        cfg.check_val_every_n_epoch = 1
        cfg.log_step_ratio = 0.001  # Log every 0.1% of the total number of training steps.
        cfg.precision = 32
        cfg.find_unused_parameters = True
        cfg.num_sanity_val_steps = 2
        return cfg

    @staticmethod
    def get_vis_config():
        r"""Default visualization config (example indices to plot and evaluation-timestep stride)."""
        cfg = OmegaConf.create()
        cfg.train_example_data_idx_list = []
        cfg.val_example_data_idx_list = []
        cfg.test_example_data_idx_list = []
        cfg.eval_example_only = False
        cfg.denoise_t_step = 100
        return cfg

    def configure_optimizers(self):
        r"""
        Build AdamW over `torch_nn_module` (plus the conditioner when it is trainable) and the
        lr scheduler selected by ``optim.lr_scheduler_mode``: 'none', 'cosine' (linear warmup
        followed by cosine annealing, stepped per iteration) or 'plateau' (stepped per validation).

        Raises
        ------
        NotImplementedError
            For an unknown optimizer method or lr scheduler mode.
        """
        optim_cfg = self.oc.optim
        params = list(self.torch_nn_module.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())

        if optim_cfg.method == "adamw":
            # Pass `optim.wd` explicitly so the configured weight decay is honoured.
            optimizer = torch.optim.AdamW(params,
                                          lr=optim_cfg.lr,
                                          betas=tuple(optim_cfg.betas),
                                          weight_decay=optim_cfg.wd)
        else:
            raise NotImplementedError(f"opimization method {optim_cfg.method} not supported.")

        warmup_iter = int(np.round(optim_cfg.warmup_percentage * self.total_num_steps))
        if optim_cfg.lr_scheduler_mode == 'none':
            return {'optimizer': optimizer}
        if optim_cfg.lr_scheduler_mode == 'cosine':
            warmup_scheduler = LambdaLR(optimizer,
                                        lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                                                min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
            # Keep T_max positive: CosineAnnealingLR divides by it (e.g. warmup_percentage == 1).
            cosine_scheduler = CosineAnnealingLR(optimizer,
                                                 T_max=max(1, self.total_num_steps - warmup_iter),
                                                 eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
            lr_scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
                                        milestones=[warmup_iter])
            lr_scheduler_config = {
                'scheduler': lr_scheduler,
                'interval': 'step',
                'frequency': 1,
            }
        elif optim_cfg.lr_scheduler_mode == 'plateau':
            lr_scheduler = ReduceLROnPlateau(
                optimizer=optimizer,
                mode='min',
                patience=optim_cfg.plateau_patience,
                factor=0.1,
                verbose=True,
            )
            lr_scheduler_config = {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "frequency": self.oc.trainer.check_val_every_n_epoch,
                "monitor": optim_cfg.monitor,
            }
        else:
            raise NotImplementedError(f"lr_scheduler_mode {optim_cfg.lr_scheduler_mode} not supported.")
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config}

    def set_trainer_kwargs(self, **kwargs):
        r"""
        Default kwargs used when initializing pl.Trainer.

        Caller-provided ``callbacks`` / ``logger`` lists are extended (on a copy) with the default
        checkpoint / lr-monitor / device-monitor / early-stopping callbacks and TensorBoard / CSV
        (/ W&B) loggers. Keys of ``oc.trainer`` that are `Trainer` arguments override the defaults,
        and any remaining `kwargs` override everything.
        """
        if self.oc.logging.profiler is None:
            profiler = None
        elif self.oc.logging.profiler == "pytorch":
            profiler = PyTorchProfiler(filename=f"{self.oc.logging.logging_prefix}_PyTorchProfiler.log")
        else:
            raise NotImplementedError
        checkpoint_callback = ModelCheckpoint(
            monitor=self.oc.optim.monitor,
            dirpath=os.path.join(self.save_dir, "checkpoints"),
            filename="{epoch:03d}",
            auto_insert_metric_name=False,
            save_top_k=self.oc.optim.save_top_k,
            save_last=True,
            mode="min",
        )
        callbacks = kwargs.pop("callbacks", [])
        assert isinstance(callbacks, list)
        for ele in callbacks:
            assert isinstance(ele, Callback)
        # Work on a copy so the caller's list is not mutated.
        callbacks = list(callbacks)
        callbacks.append(checkpoint_callback)
        if self.oc.logging.monitor_lr:
            callbacks.append(LearningRateMonitor(logging_interval='step'))
        if self.oc.logging.monitor_device:
            callbacks.append(DeviceStatsMonitor())
        if self.oc.optim.early_stop:
            # Monitor the same metric as checkpointing instead of a hard-coded key.
            callbacks.append(EarlyStopping(monitor=self.oc.optim.monitor,
                                           min_delta=0.0,
                                           patience=self.oc.optim.early_stop_patience,
                                           verbose=False,
                                           mode=self.oc.optim.early_stop_mode))

        logger = list(kwargs.pop("logger", []))
        tb_logger = pl_loggers.TensorBoardLogger(save_dir=self.save_dir)
        csv_logger = pl_loggers.CSVLogger(save_dir=self.save_dir)
        logger += [tb_logger, csv_logger]
        if self.oc.logging.use_wandb:
            wandb_logger = pl_loggers.WandbLogger(project=self.oc.logging.logging_prefix,
                                                  save_dir=self.save_dir)
            logger.append(wandb_logger)

        log_every_n_steps = max(1, int(self.oc.trainer.log_step_ratio * self.total_num_steps))
        trainer_init_keys = inspect.signature(Trainer).parameters.keys()
        ret = dict(
            callbacks=callbacks,
            # log
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            track_grad_norm=self.oc.logging.track_grad_norm,
            profiler=profiler,
            # save
            default_root_dir=self.save_dir,
            # ddp
            accelerator="gpu",
            strategy=DDPStrategy(find_unused_parameters=self.oc.trainer.find_unused_parameters),
            # optimization
            max_epochs=self.oc.optim.max_epochs,
            check_val_every_n_epoch=self.oc.trainer.check_val_every_n_epoch,
            gradient_clip_val=self.oc.optim.gradient_clip_val,
            # NVIDIA amp
            precision=self.oc.trainer.precision,
            # misc
            num_sanity_val_steps=self.oc.trainer.num_sanity_val_steps,
            inference_mode=False,
        )
        oc_trainer_kwargs = OmegaConf.to_object(self.oc.trainer)
        oc_trainer_kwargs = {key: val for key, val in oc_trainer_kwargs.items() if key in trainer_init_keys}
        ret.update(oc_trainer_kwargs)
        ret.update(kwargs)
        return ret

    @classmethod
    def get_total_num_steps(
            cls,
            num_samples: int,
            total_batch_size: int,
            epoch: int = None):
        r"""
        Parameters
        ----------
        num_samples:    int
            The number of training samples. `num_samples / total_batch_size` is the number of
            optimizer steps per epoch.
        total_batch_size:   int
            `total_batch_size == micro_batch_size * world_size * grad_accum`
        epoch:  int
            Number of epochs; defaults to ``get_optim_config().max_epochs``.

        Returns
        -------
        total_num_steps:    int
            ``int(epoch * num_samples / total_batch_size)``
        """
        if epoch is None:
            epoch = cls.get_optim_config().max_epochs
        return int(epoch * num_samples / total_batch_size)

    @staticmethod
    def get_nbody_datamodule(dataset_oc,
                             load_dir: str = None,
                             micro_batch_size: int = 1,
                             num_workers: int = 8):
        r"""
        Build the N-body MNIST datamodule from pre-generated data.

        The ``nbody_dataset_cfg.yaml`` stored next to the data is checked against `dataset_oc`
        (except ``aug_mode`` and ``ret_contiguous``, which may differ).

        Raises
        ------
        ValueError
            If ``<load_dir>/<dataset_name>`` does not exist.
        AssertionError
            If a stored dataset config value differs from `dataset_oc`.
        """
        if load_dir is None:
            load_dir = os.path.join(default_datasets_dir, "nbody")
        data_dir = os.path.join(load_dir, dataset_oc["dataset_name"])
        if not os.path.exists(data_dir):
            raise ValueError(f"dataset in {data_dir} not exists!")
        load_dataset_cfg_path = os.path.join(data_dir, "nbody_dataset_cfg.yaml")
        # Pass the path so OmegaConf closes the file after reading.
        load_dataset_cfg = OmegaConf.to_object(OmegaConf.load(load_dataset_cfg_path).dataset)
        for key, val in load_dataset_cfg.items():
            if key in ["aug_mode", "ret_contiguous"]:
                continue  # exclude keys that can be different
            assert val == dataset_oc[key], \
                f"dataset config {key} mismatches! " \
                f"{dataset_oc[key]} specified, but {val} loaded."
        dm = NBodyMovingMNISTLightningDataModule(
            data_dir=data_dir,
            force_regenerate=False,
            num_train_samples=dataset_oc["num_train_samples"],
            num_val_samples=dataset_oc["num_val_samples"],
            num_test_samples=dataset_oc["num_test_samples"],
            digit_num=dataset_oc["digit_num"],
            img_size=dataset_oc["img_size"],
            raw_img_size=dataset_oc["raw_img_size"],
            seq_len=dataset_oc["seq_len"],
            raw_seq_len_multiplier=dataset_oc["raw_seq_len_multiplier"],
            distractor_num=dataset_oc["distractor_num"],
            distractor_size=dataset_oc["distractor_size"],
            max_velocity_scale=dataset_oc["max_velocity_scale"],
            initial_velocity_range=dataset_oc["initial_velocity_range"],
            random_acceleration_range=dataset_oc["random_acceleration_range"],
            scale_variation_range=dataset_oc["scale_variation_range"],
            rotation_angle_range=dataset_oc["rotation_angle_range"],
            illumination_factor_range=dataset_oc["illumination_factor_range"],
            period=dataset_oc["period"],
            global_rotation_prob=dataset_oc["global_rotation_prob"],
            index_range=dataset_oc["index_range"],
            mnist_data_path=dataset_oc["mnist_data_path"],
            aug_mode=dataset_oc["aug_mode"],
            ret_contiguous=dataset_oc["ret_contiguous"],
            ret_aux=True,
            energy_norm_scale=dataset_oc["energy_norm_scale"],
            # N-Body params
            nbody_acc_mode=dataset_oc["nbody_acc_mode"],
            nbody_G=dataset_oc["nbody_G"],
            nbody_softening_distance=dataset_oc["nbody_softening_distance"],
            nbody_mass=dataset_oc["nbody_mass"],
            # datamodule_only
            batch_size=micro_batch_size,
            num_workers=num_workers, )
        return dm

    def _cache_in_out_slice(self):
        r"""Compute and cache the input/output slices for ``oc.layout`` (shared by `in_slice` / `out_slice`)."""
        in_slice, out_slice = layout_to_in_out_slice(
            layout=self.oc.layout.layout,
            in_len=self.oc.layout.in_len,
            out_len=self.oc.layout.out_len)
        self._in_slice = in_slice
        self._out_slice = out_slice

    @property
    def in_slice(self):
        r"""Index selecting the input frames of a raw sequence (lazily computed and cached)."""
        if not hasattr(self, "_in_slice"):
            self._cache_in_out_slice()
        return self._in_slice

    @property
    def out_slice(self):
        r"""Index selecting the target frames of a raw sequence (lazily computed and cached)."""
        if not hasattr(self, "_out_slice"):
            self._cache_in_out_slice()
        return self._out_slice

    @property
    def energy_slice(self):
        r"""
        Index selecting the last ``model.guide_obj.out_len`` steps along the T axis of the energy tensor.

        The T axis position is taken from ``oc.layout.layout``; this assumes the 3-dim energy tensor
        (see `_get_input_nbody`) has T at the same position as in the layout — TODO confirm for
        layouts other than NTHWC.
        """
        if not hasattr(self, "_energy_slice"):
            _energy_slice = [slice(None, None), ] * 3
            t_axis = self.oc.layout.layout.find("T")
            _energy_slice[t_axis] = slice(-self.oc.model.guide_obj.out_len, None)
            self._energy_slice = _energy_slice
        return self._energy_slice

    @torch.no_grad()
    def get_input(self, batch, **kwargs):
        r"""
        dataset dependent
        re-implement it for each specific dataset

        Parameters
        ----------
        batch:  Any
            raw data batch from specific dataloader
        kwargs:
            ``return_verbose`` is forwarded but currently unused by `_get_input_nbody`.

        Returns
        -------
        out:    Sequence[torch.Tensor, Dict[str, Any], Dict[str, Any]]
            out[0] should be a torch.Tensor which is the target to generate
            out[1] should be a dict consists of several key-value pairs for conditioning
            out[2] is a dict with key "energy" holding the target energies
        """
        return self._get_input_nbody(batch=batch, return_verbose=kwargs.get("return_verbose", False))

    @torch.no_grad()
    def _get_input_nbody(self, batch, return_verbose=False):
        r"""
        Split an N-body batch ``(seq, KE, PE)`` into target, conditioning and energy.

        Returns
        -------
        out_seq:    torch.Tensor
            shape = (b, t_out, h, w, c)
        {"y": in_seq}:  Dict[str, torch.Tensor]
            in_seq.shape = (b, t_in, h, w, c)
        {"energy": energy}:  Dict[str, torch.Tensor]
            energy.shape = (b, t_out, 2); last dim is (KE, PE)
        """
        seq, KE, PE = batch
        in_seq = seq[self.in_slice]
        out_seq = seq[self.out_slice]
        energy = torch.stack([KE, PE], dim=-1)[self.energy_slice]
        return out_seq, {"y": in_seq}, {"energy": energy}

    @torch.no_grad()
    def _save_energy_vis(self, batch, verbose_dict, data_idx: int, mode: str):
        r"""
        Plot target vs. predicted energy curves of the first sample in `batch`.

        ``verbose_dict`` must contain "pred" (predicted energies) and "t" (diffusion steps);
        the predictions are written into the last ``model.guide_obj.out_len`` steps of an
        array shaped like the target, and the plot is delegated to `save_vis_step_end`.
        """
        _, _, energy_dict = self.get_input(batch)
        target = energy_dict["energy"][0].detach().float().cpu().numpy()
        pred = np.zeros_like(target)
        pred[-self.oc.model.guide_obj.out_len:, :] = verbose_dict["pred"][0].detach().float().cpu().numpy()
        self.save_vis_step_end(
            data_idx=data_idx,
            target_seq=target,
            pred_seq=pred,
            mode=mode,
            suffix=f"_t{int(verbose_dict['t'][0].item())}")

    def training_step(self, batch, batch_idx):
        r"""Compute and log the training loss; plot requested training examples on rank 0."""
        loss, loss_dict, verbose_dict = self(batch, return_verbose=True)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
        micro_batch_size = batch[0].shape[0]
        data_idx = int(batch_idx * micro_batch_size)
        # `save_vis_step_end` only writes requested examples, so skip the host copies otherwise.
        if self.current_epoch % self.oc.trainer.check_val_every_n_epoch == 0 \
                and self.local_rank == 0 \
                and data_idx in self.train_example_data_idx_list:
            self._save_energy_vis(batch=batch, verbose_dict=verbose_dict, data_idx=data_idx, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        r"""Evaluate the guidance loss at each timestep of `_eval_timesteps` and update val metrics."""
        micro_batch_size = batch[0].shape[0]
        device = batch[0].device
        data_idx = int(batch_idx * micro_batch_size)
        for i, t in enumerate(self._eval_timesteps()):
            t_batch = torch.ones((micro_batch_size, ), dtype=torch.long, device=device) * t
            loss, loss_dict, verbose_dict = self(batch, t=t_batch, return_verbose=True)
            self.log("val_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
            loss_t_dict = {f"val/{key}_t{t}": val for key, val in loss_dict.items()}
            self.log_dict(loss_t_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
            loss_dict = {f"val/{key}": val for key, val in loss_dict.items()}
            self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
            # Plots are only written for requested examples; `eval_example_only` never
            # restricted the metrics below, so checking the example list alone is equivalent.
            if data_idx in self.val_example_data_idx_list:
                self._save_energy_vis(batch=batch, verbose_dict=verbose_dict, data_idx=data_idx, mode="val")
            self.valid_mse_list[i](verbose_dict["pred"], verbose_dict["target"])
            self.valid_mae_list[i](verbose_dict["pred"], verbose_dict["target"])
            self.valid_mse(verbose_dict["pred"], verbose_dict["target"])
            self.valid_mae(verbose_dict["pred"], verbose_dict["target"])

    def validation_epoch_end(self, outputs):
        r"""
        Log aggregate and per-timestep validation MSE/MAE and reset the metrics.

        ``valid_loss_epoch`` (the default checkpoint / early-stopping monitor) is the aggregate MSE;
        per-timestep losses are logged under distinct ``valid_loss_epoch_t{t}`` keys so they do not
        get averaged into the monitored value.
        """
        valid_mse = self.valid_mse.compute()
        valid_mae = self.valid_mae.compute()
        valid_loss = valid_mse

        self.log("valid_loss_epoch", valid_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("valid_mse_epoch", valid_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("valid_mae_epoch", valid_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.valid_mse.reset()
        self.valid_mae.reset()
        for i, t in enumerate(self._eval_timesteps()):
            valid_mse = self.valid_mse_list[i].compute()
            valid_mae = self.valid_mae_list[i].compute()
            valid_loss = valid_mse

            self.log(f"valid_loss_epoch_t{t}", valid_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"valid_mse_epoch_t{t}", valid_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"valid_mae_epoch_t{t}", valid_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
            self.valid_mse_list[i].reset()
            self.valid_mae_list[i].reset()

    def test_step(self, batch, batch_idx):
        r"""Evaluate the guidance loss at each timestep of `_eval_timesteps` and update test metrics."""
        micro_batch_size = batch[0].shape[0]
        device = batch[0].device
        data_idx = int(batch_idx * micro_batch_size)
        for i, t in enumerate(self._eval_timesteps()):
            t_batch = torch.ones((micro_batch_size,), dtype=torch.long, device=device) * t
            loss, loss_dict, verbose_dict = self(batch, t=t_batch, return_verbose=True)
            self.log("test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
            # Test metrics use the `test/` prefix (they were mistakenly logged under `val/`).
            loss_t_dict = {f"test/{key}_t{t}": val for key, val in loss_dict.items()}
            self.log_dict(loss_t_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
            loss_dict = {f"test/{key}": val for key, val in loss_dict.items()}
            self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
            # Plots are only written for requested examples; see `validation_step`.
            if data_idx in self.test_example_data_idx_list:
                self._save_energy_vis(batch=batch, verbose_dict=verbose_dict, data_idx=data_idx, mode="test")
            self.test_mse_list[i](verbose_dict["pred"], verbose_dict["target"])
            self.test_mae_list[i](verbose_dict["pred"], verbose_dict["target"])
            self.test_mse(verbose_dict["pred"], verbose_dict["target"])
            self.test_mae(verbose_dict["pred"], verbose_dict["target"])

    def test_epoch_end(self, outputs):
        r"""
        Log aggregate and per-timestep test MSE/MAE and reset the metrics.

        Per-timestep losses use distinct ``test_loss_epoch_t{t}`` keys so the aggregate
        ``test_loss_epoch`` is not averaged with them.
        """
        test_mse = self.test_mse.compute()
        test_mae = self.test_mae.compute()
        test_loss = test_mse

        self.log("test_loss_epoch", test_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_mse_epoch", test_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_mae_epoch", test_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.test_mse.reset()
        self.test_mae.reset()
        for i, t in enumerate(self._eval_timesteps()):
            test_mse = self.test_mse_list[i].compute()
            test_mae = self.test_mae_list[i].compute()
            test_loss = test_mse

            self.log(f"test_loss_epoch_t{t}", test_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"test_mse_epoch_t{t}", test_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"test_mae_epoch_t{t}", test_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
            self.test_mse_list[i].reset()
            self.test_mae_list[i].reset()

    @torch.no_grad()
    def save_vis_step_end(
            self,
            data_idx: int,
            target_seq: np.ndarray,
            pred_seq: np.ndarray,
            mode: str = "train",
            prefix: str = "",
            suffix: str = "", ):
        r"""
        Save a target-vs-prediction energy plot if `data_idx` is a requested example for `mode`.

        Parameters
        ----------
        data_idx:   int
            Index of the first sample of the batch; compared against the example list of `mode`.
        target_seq, pred_seq: np.ndarray
            shape = (T, 2); column 0 is KE, column 1 is PE.
        mode:   str
            One of 'train', 'val', 'test'.
        prefix, suffix: str
            Added around the generated PNG file name.

        Raises
        ------
        ValueError
            If `mode` is not one of 'train', 'val', 'test'.
        """
        if mode == "train":
            example_data_idx_list = self.train_example_data_idx_list
        elif mode == "val":
            example_data_idx_list = self.val_example_data_idx_list
        elif mode == "test":
            example_data_idx_list = self.test_example_data_idx_list
        else:
            raise ValueError(f"Wrong mode {mode}! Must be in ['train', 'val', 'test'].")
        if data_idx in example_data_idx_list:
            png_save_name = f"{prefix}{mode}_epoch_{self.current_epoch}_data_{data_idx}{suffix}.png"
            vis_nbody_energy(
                save_path=os.path.join(self.example_save_dir, png_save_name),
                KE=target_seq[:, 0],
                PE=target_seq[:, 1],
                KE_color='red',
                PE_color='blue',
                sum_color='black',
                marker="s",
                marker_size=8,
                pred_KE=pred_seq[:, 0],
                pred_PE=pred_seq[:, 1],
                pred_KE_color='red',
                pred_PE_color='blue',
                pred_sum_color='black',
                pred_marker="+",
                pred_marker_size=8, )

def get_parser():
    """Build the command-line parser for N-body guidance training / testing.

    Returns
    -------
    argparse.ArgumentParser
        Parser defining ``--save``, ``--nodes``, ``--gpus``, ``--cfg``,
        ``--test`` and ``--ckpt_name``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--save', type=str, default='tmp_nbody_guide')
    # DDP layout: both flags are integers defaulting to a single process.
    ddp_flags = (
        ('--nodes', "Number of nodes in DDP training."),
        ('--gpus', "Number of GPUS per node in DDP training."),
    )
    for flag, help_text in ddp_flags:
        arg_parser.add_argument(flag, type=int, default=1, help=help_text)
    arg_parser.add_argument('--cfg', type=str, default=None)
    arg_parser.add_argument('--test', action='store_true')
    arg_parser.add_argument('--ckpt_name', type=str, default=None,
                            help='The model checkpoint trained on N-body MNIST.')
    return arg_parser

def _extract_prefixed_state_dict(pl_state_dict, prefix):
    """Collect the entries of a Lightning ``state_dict`` that belong to one submodule.

    Parameters
    ----------
    pl_state_dict
        Mapping from parameter/buffer names to values, e.g. ``ckpt["state_dict"]``.
    prefix: str
        Key prefix identifying the submodule, e.g. ``"torch_nn_module."``.

    Returns
    -------
    OrderedDict
        Entries whose key starts with ``prefix``, in their original order, with only
        the leading ``prefix`` removed. All other entries are dropped.
    """
    # Slice off the leading prefix rather than using ``str.replace``, which would
    # also strip any later occurrence of the prefix inside a nested key.
    return OrderedDict(
        (key[len(prefix):], val)
        for key, val in pl_state_dict.items()
        if key.startswith(prefix)
    )


def main():
    """Command-line entry point for training / testing the N-body guidance module.

    With ``--test``: evaluate on the test split, loading
    ``<save_dir>/checkpoints/<ckpt_name>`` when ``--ckpt_name`` is given.

    Otherwise:
    1. Fit, resuming from ``--ckpt_name`` if that checkpoint exists.
    2. Export the ``torch_nn_module.`` sub-state-dict of the best checkpoint to
       ``<save_dir>/checkpoints/kc_nbody.pt`` (global rank 0 only).
    3. Test the best checkpoint.

    Raises
    ------
    ValueError
        If ``total_batch_size`` is smaller than ``micro_batch_size * nodes * gpus``,
        which would give zero gradient-accumulation steps.
    """
    parser = get_parser()
    args = parser.parse_args()
    num_devices = args.nodes * args.gpus
    if args.cfg is not None:
        # OmegaConf.load accepts a path and closes the file itself.
        oc_from_file = OmegaConf.load(args.cfg)
        dataset_cfg = OmegaConf.to_object(oc_from_file.dataset)
        total_batch_size = oc_from_file.optim.total_batch_size
        micro_batch_size = oc_from_file.optim.micro_batch_size
        max_epochs = oc_from_file.optim.max_epochs
        seed = oc_from_file.optim.seed
        float32_matmul_precision = oc_from_file.optim.float32_matmul_precision
    else:
        dataset_cfg = OmegaConf.to_object(NbodyGuidancePLModule.get_dataset_config())
        micro_batch_size = 1
        total_batch_size = int(micro_batch_size * num_devices)
        max_epochs = None
        seed = 0
        float32_matmul_precision = "high"

    # Samples consumed by one forward/backward pass across all devices; gradients
    # are accumulated over enough passes to reach total_batch_size.
    samples_per_pass = micro_batch_size * num_devices
    accumulate_grad_batches = total_batch_size // samples_per_pass
    if accumulate_grad_batches < 1:
        raise ValueError(
            f"total_batch_size ({total_batch_size}) must be at least "
            f"micro_batch_size * nodes * gpus ({samples_per_pass}).")
    if total_batch_size % samples_per_pass != 0:
        warnings.warn(
            f"total_batch_size ({total_batch_size}) is not divisible by "
            f"micro_batch_size * nodes * gpus ({samples_per_pass}); the effective "
            f"batch size will be {accumulate_grad_batches * samples_per_pass}.")

    torch.set_float32_matmul_precision(float32_matmul_precision)
    seed_everything(seed, workers=True)
    dm = NbodyGuidancePLModule.get_nbody_datamodule(
        dataset_oc=dataset_cfg,
        micro_batch_size=micro_batch_size,
        num_workers=8, )
    dm.prepare_data()
    dm.setup()
    total_num_steps = NbodyGuidancePLModule.get_total_num_steps(
        epoch=max_epochs,
        num_samples=dm.num_train_samples,
        total_batch_size=total_batch_size,
    )
    pl_module = NbodyGuidancePLModule(
        total_num_steps=total_num_steps,
        save_dir=args.save,
        oc_file=args.cfg)
    trainer_kwargs = pl_module.set_trainer_kwargs(
        devices=args.gpus,
        num_nodes=args.nodes,
        accumulate_grad_batches=accumulate_grad_batches,
    )
    trainer = Trainer(**trainer_kwargs)

    ckpt_dir = os.path.join(pl_module.save_dir, "checkpoints")
    ckpt_path = None if args.ckpt_name is None else os.path.join(ckpt_dir, args.ckpt_name)
    if args.test:
        trainer.test(model=pl_module,
                     datamodule=dm,
                     ckpt_path=ckpt_path)
        return

    if ckpt_path is not None and not os.path.exists(ckpt_path):
        warnings.warn(f"ckpt {ckpt_path} not exists! Start training from epoch 0.")
        ckpt_path = None
    trainer.fit(model=pl_module,
                datamodule=dm,
                ckpt_path=ckpt_path)
    # Save the state_dict of the knowledge control network. Only global rank 0
    # writes, so DDP processes do not all write the same file concurrently.
    if trainer.is_global_zero:
        pl_ckpt = pl_load(path_or_url=trainer.checkpoint_callback.best_model_path,
                          map_location=torch.device("cpu"))
        state_dict = _extract_prefixed_state_dict(pl_ckpt["state_dict"],
                                                  prefix="torch_nn_module.")
        torch.save(state_dict, os.path.join(ckpt_dir, pytorch_state_dict_name))
    trainer.test(ckpt_path="best",
                 datamodule=dm)

if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module has no side effects here.
    main()
