from typing import Callable, Dict, List, Tuple

from absl import flags
import asdl
import torch
import torch.distributed.nn as dist_nn
import wandb

from algorithmic_efficiency import spec
from algorithmic_efficiency.pytorch_utils import pytorch_setup
from reference_algorithms.target_setting_algorithms.data_selection import \
    data_selection  # pylint: disable=unused-import
from reference_algorithms.target_setting_algorithms.get_batch_size import \
    get_batch_size
from reference_algorithms.target_setting_algorithms.pytorch_nadamw import \
    init_optimizer_state as init_nadamw
from reference_algorithms.target_setting_algorithms.pytorch_nesterov import \
    init_optimizer_state as init_nesterov

# absl flag container; this module reads `FLAGS.workload` to pick the base
# optimizer and the batch size.
FLAGS = flags.FLAGS
# Only the first two values returned by `pytorch_setup()` are used here.
USE_PYTORCH_DDP, RANK = pytorch_setup()[:2]


def get_base_optimizer_init_fn(workload_name: str) -> Callable:
  """Select the base-optimizer initializer for a workload.

  Args:
    workload_name: Name of the workload; 'ogbg' maps to the Nesterov
      initializer and 'imagenet_vit' maps to the NAdamW initializer.

  Returns:
    The `init_optimizer_state` function of the matching base optimizer.

  Raises:
    ValueError: If `workload_name` is neither 'ogbg' nor 'imagenet_vit'.
  """
  # Reject unsupported workloads up front so the happy path stays flat.
  if workload_name not in ('ogbg', 'imagenet_vit'):
    raise ValueError('Only ogbg and imagenet_vit workloads are supported here.')
  return init_nesterov if workload_name == 'ogbg' else init_nadamw


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Build the base optimizer state and attach a K-FAC gradient maker to it.

  Args:
    workload: Forwarded unchanged to the base optimizer initializer.
    model_params: Forwarded to the base optimizer initializer and handed to
      `asdl.KfacGradientMaker` as the model to precondition.
    model_state: Forwarded unchanged to the base optimizer initializer.
    hyperparameters: Forwarded to the base optimizer initializer. It must also
      expose `damping`, `preconditioner_upd_interval`,
      `curvature_upd_interval` and `kfac_linear`.
    rng: Forwarded unchanged to the base optimizer initializer.

  Returns:
    The state produced by the base optimizer initializer, with the K-FAC
    gradient maker stored under the 'gradient_maker' key.
  """
  workload_name = FLAGS.workload

  # First-order base optimizer, chosen by workload.
  make_base_state = get_base_optimizer_init_fn(workload_name)
  state = make_base_state(
      workload, model_params, model_state, hyperparameters, rng)

  # K-FAC preconditioning settings.
  preconditioning_config = asdl.PreconditioningConfig(
      data_size=get_batch_size(workload_name),
      damping=hyperparameters.damping,
      preconditioner_upd_interval=hyperparameters.preconditioner_upd_interval,
      curvature_upd_interval=hyperparameters.curvature_upd_interval)

  # NOTE(review): the keyword `kfac_linaer` looks like a misspelling of
  # `kfac_linear`; confirm against the installed asdl `KfacGradientMaker`
  # signature before renaming it.
  kfac_grad_maker = asdl.KfacGradientMaker(
      model_params,
      preconditioning_config,
      kfac_linaer=hyperparameters.kfac_linear)
  state['gradient_maker'] = kfac_grad_maker

  return state


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state).

  Performs one training step: computes gradients through the K-FAC gradient
  maker stored in `optimizer_state`, optionally clips them, then steps the
  base optimizer and its scheduler.

  Args:
    workload: Provides `model_fn` and `loss_fn`.
    current_param_container: The model being trained; passed to
      `workload.model_fn` as `params` and returned as the updated params.
    current_params_types: Unused.
    model_state: Passed through to `workload.model_fn`.
    hyperparameters: Optional attributes `label_smoothing` (defaults to 0.0)
      and `grad_clip` (missing or None disables clipping) are read here.
    batch: Input batch; must contain 'targets' and may contain 'weights'.
    loss_type: Unused.
    optimizer_state: Mapping with the keys 'optimizer', 'scheduler' and
      'gradient_maker'.
    eval_results: Unused.
    global_step: Index of the current step; step 0 triggers config logging.
    rng: Passed through to `workload.model_fn`.

  Returns:
    A tuple `(optimizer_state, current_param_container, None)`.
  """
  del current_params_types
  del loss_type
  del eval_results

  current_model = current_param_container
  grad_maker = optimizer_state['gradient_maker']

  # Log the gradient maker's config once, from the rank-0 process only.
  if global_step == 0 and RANK == 0:
    from dataclasses import asdict  # pylint: disable=import-outside-toplevel
    wandb.config.update(asdict(grad_maker.config))

  optimizer_state['optimizer'].zero_grad()

  # Does not change between calls of the closure, so resolve it once.
  label_smoothing = getattr(hyperparameters, 'label_smoothing', 0.0)

  def _loss_fn(input_batch):
    """Run the forward pass and return (logits, mean loss over valid examples)."""
    logits, _ = workload.model_fn(
        params=current_model,
        augmented_and_preprocessed_input_batch=input_batch,
        model_state=model_state,
        mode=spec.ForwardPassMode.TRAIN,
        rng=rng,
        update_batch_norm=True)
    loss_dict = workload.loss_fn(
        input_batch['targets'],
        logits,
        input_batch.get('weights'),
        label_smoothing=label_smoothing)
    summed_loss = loss_dict['summed']
    n_valid_examples = loss_dict['n_valid_examples']
    if USE_PYTORCH_DDP:
      # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
      summed_loss = dist_nn.all_reduce(summed_loss)
      n_valid_examples = dist_nn.all_reduce(n_valid_examples)
    return logits, summed_loss / n_valid_examples

  # Precondition the gradient with K-FAC. Index 1 selects the loss from the
  # (logits, loss) pair returned by `_loss_fn`.
  dummy_y = grad_maker.setup_model_call(_loss_fn, batch)
  grad_maker.setup_loss_repr(dummy_y[1])
  grad_maker.forward_and_backward()

  # A missing or None `grad_clip` means "no clipping"; passing None as
  # `max_norm` to clip_grad_norm_ would raise instead.
  grad_clip = getattr(hyperparameters, 'grad_clip', None)
  if grad_clip is not None:
    torch.nn.utils.clip_grad_norm_(
        current_model.parameters(), max_norm=grad_clip)
  optimizer_state['optimizer'].step()
  optimizer_state['scheduler'].step()

  # NOTE(review): the model state returned by `workload.model_fn` is discarded
  # and None is returned in its place; confirm this is intended.
  return optimizer_state, current_param_container, None
