from apo_precond.preconditioner import *
from collections import defaultdict, abc as container_abcs
from copy import deepcopy
from itertools import chain
from apo_precond.utils import *

import torch.optim as optim
import torch


class ApoPrecondOptimizer(optim.Optimizer):
    """Optimizer that applies a meta-learned (APO) gradient preconditioner.

    Every model parameter ``p`` gets its own ``Preconditioner`` (from
    ``apo_precond.preconditioner``) plus two inner ``torch.optim`` optimizers:

    * ``state[p]["optimizer"]`` -- the meta optimizer that trains the
      preconditioner's parameters in :meth:`meta_step`;
    * ``state[p]["initial_optimizer"]`` -- a conventional optimizer that
      :meth:`step` uses until ``p`` has taken ``warmup_step`` meta steps (and,
      with ``graft=True``, to supply the step magnitude afterwards).

    After warm-up, :meth:`step` updates ``p`` with the preconditioned gradient,
    optional weight decay and (Nesterov) momentum.
    """

    def __init__(self, model, lr=0.001, precond_lr=0.9, momentum=0, weight_decay=0, nesterov=False,
                 meta_objective="", meta_optimizer="adam", meta_lr=1e-3, parameterization="ekfac_psd", scale=1.,
                 warmup_step=0, lamb_wsp=0, lamb_fsp=0, fsp_fnc="kl", initial_optimizer="sgdm",
                 graft=False, debug_mode=False):
        """Build the optimizer and all per-parameter state for ``model``.

        Args:
            model: module whose ``parameters()`` are optimized. It is also run
                forward inside the meta-objective.
            lr: step size before warm-up ends (also the initial optimizer's lr).
            precond_lr: step size of the preconditioned update after warm-up;
                ``0`` means "reuse ``lr``".
            momentum: momentum coefficient of the preconditioned update.
            weight_decay: L2 coefficient, added to the update and to the
                meta-objective.
            nesterov: use Nesterov momentum in the preconditioned update (also
                forwarded to the "sgdm" initial optimizer).
            meta_objective: "", "m", "w" or "mw". Selects whether the look-ahead
                step inside the meta-objective includes momentum ("m") and/or
                weight decay ("w").
            meta_optimizer: "sgd", "adam" or "rmsprop". This optimizer trains the
                preconditioner parameters.
            meta_lr: learning rate of the meta optimizer.
            parameterization, scale, debug_mode: forwarded to ``Preconditioner``.
            warmup_step: number of meta steps a parameter must take before
                :meth:`step` switches from the initial optimizer to the
                preconditioned update.
            lamb_wsp: weight of the weight-space proximity term.
            lamb_fsp: weight of the function-space proximity term.
            fsp_fnc: "kl" or "euc". Selects the function-space proximity measure.
            initial_optimizer: "sgd", "sgdm" (SGD with momentum 0.9), "adam" or
                "adagrad".
            graft: rescale each preconditioned step to the norm of the step the
                initial optimizer would have taken.

        Raises:
            ValueError: on a negative lr / meta_lr / momentum / weight_decay /
                lamb_wsp / lamb_fsp, an unknown ``meta_objective``, or (while
                building per-parameter state) an unknown meta or initial
                optimizer name.
        """
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if meta_lr < 0.0:
            raise ValueError("Invalid meta learning rate: {}".format(meta_lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if lamb_wsp < 0.0:
            raise ValueError("Invalid weight-space proximity value: {}".format(lamb_wsp))
        if lamb_fsp < 0.0:
            raise ValueError("Invalid function-space proximity value: {}".format(lamb_fsp))
        defaults = dict(lr=lr, precond_lr=precond_lr, momentum=momentum, weight_decay=weight_decay,
                        nesterov=nesterov, warmup_step=warmup_step)
        self.model = model

        if meta_objective not in ["", "m", "w", "mw"]:
            raise ValueError("Invalid meta_objective type: {}".format(meta_objective))
        self.meta_objective = meta_objective
        self.meta_optimizer = meta_optimizer
        self.meta_lr = meta_lr
        self.parameterization = parameterization
        self.scale = scale

        self.warmup_step = warmup_step
        self.lamb_wsp, self.lamb_fsp, self.fsp_fnc = lamb_wsp, lamb_fsp, fsp_fnc
        self.initial_optimizer = initial_optimizer
        self.graft = graft
        self.debug_mode = debug_mode

        super().__init__(model.parameters(), defaults)
        self._init_state()
        self.global_step = 0

    # ------------------------------------------------------------------ #
    # State construction helpers
    # ------------------------------------------------------------------ #
    def _iter_param_states(self):
        """Yield ``(param, state)`` for every parameter of every param group."""
        for group in self.param_groups:
            for p in group["params"]:
                yield p, self.state[p]

    def _build_meta_optimizer(self, meta_params):
        """Create the optimizer that trains one preconditioner's parameters."""
        if self.meta_optimizer == "sgd":
            return optim.SGD(meta_params, lr=self.meta_lr)
        if self.meta_optimizer == "adam":
            return optim.Adam(meta_params, lr=self.meta_lr)
        if self.meta_optimizer == "rmsprop":
            return optim.RMSprop(meta_params, lr=self.meta_lr)
        raise ValueError("Invalid meta optimizer type: {}".format(self.meta_optimizer))

    def _build_initial_optimizer(self, p, group):
        """Create the warm-up optimizer for ``p`` from ``group``'s hyper-parameters."""
        lr = group["lr"]
        wd = group["weight_decay"]
        nesterov = group["nesterov"]

        if self.initial_optimizer == "sgd":
            # Plain SGD has zero momentum, and torch.optim.SGD rejects
            # nesterov=True without momentum, so Nesterov is not forwarded here.
            return optim.SGD([p], lr=lr, momentum=0., weight_decay=wd, nesterov=False)
        if self.initial_optimizer == "sgdm":
            return optim.SGD([p], lr=lr, momentum=0.9, weight_decay=wd, nesterov=nesterov)
        if self.initial_optimizer == "adam":
            return optim.Adam([p], lr=lr, weight_decay=wd)
        if self.initial_optimizer == "adagrad":
            return optim.Adagrad([p], lr=lr, weight_decay=wd)
        raise ValueError("Invalid initial optimizer type: {}".format(self.initial_optimizer))

    def _init_state(self):
        """Populate ``self.state`` with the per-parameter APO state.

        Each parameter gets a zero momentum buffer, a ``scratch`` copy of its
        values, a ``Preconditioner``, meta/regular step counters, a meta
        optimizer and an initial (warm-up) optimizer.
        """
        # Kept in the state so it travels with state_dict(). NOTE(review):
        # state_dict() therefore contains the model object itself, and
        # load_state_dict() replaces state["model"] with a deep copy without
        # touching self.model.
        self.state["model"] = self.model
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # *_like / clone already keep p's device; passing
                # device=p.get_device() breaks on CPU, where it returns -1.
                state["momentum"] = torch.zeros_like(p.data)
                state["scratch"] = p.data.clone()
                state["preconditioner"] = Preconditioner(p, self.parameterization, self.scale, self.debug_mode)
                state["meta_step"] = 0
                state["step"] = 0
                state["optimizer"] = self._build_meta_optimizer(state["preconditioner"].parameters())
                state["initial_optimizer"] = self._build_initial_optimizer(p, group)

    def _get_meta_parameters(self):
        """Return a flat list of every preconditioner parameter."""
        return [meta_param
                for _, state in self._iter_param_states()
                for meta_param in state["preconditioner"].parameters()]

    def _get_initial_optimizers(self):
        """Return the per-parameter initial optimizers, in parameter order."""
        return [state["initial_optimizer"] for _, state in self._iter_param_states()]

    def _zero_inner_grads(self):
        """Zero the gradients tracked by every inner (initial and meta) optimizer."""
        for _, state in self._iter_param_states():
            state["initial_optimizer"].zero_grad()
            state["optimizer"].zero_grad()

    # ------------------------------------------------------------------ #
    # Gradient bookkeeping
    # ------------------------------------------------------------------ #
    def _gather_grads_dict(self):
        """Map each model parameter name to its current ``.grad`` (possibly None)."""
        return {name: param.grad for name, param in self.model.named_parameters()}

    def _organize_grads_dict(self, grads):
        """Map model parameter names, in ``named_parameters()`` order, to ``grads``.

        ``grads`` must be indexable and hold at least one entry per parameter.
        Otherwise IndexError is raised.
        """
        return {name: grads[i] for i, (name, _) in enumerate(self.model.named_parameters())}

    # ------------------------------------------------------------------ #
    # Look-ahead parameter injection
    # ------------------------------------------------------------------ #
    def inject_meta_parameters(self, grads_dict, bnn=False):
        """Temporarily replace the model's parameters with a one-step look-ahead.

        For each parameter, the detached gradient from ``grads_dict`` is passed
        through that parameter's preconditioner. The result is optionally
        combined with weight decay ("w" in ``meta_objective``) and the current
        momentum buffer ("m"). The model attribute is then set (via
        ``set_attr``) to ``param.detach() - modified_grad``. The look-ahead
        values depend on the preconditioner output, which is what lets
        :meth:`meta_step` back-propagate into the preconditioner parameters.

        NOTE(review): this relies on ``make_functional`` (apo_precond.utils)
        returning a re-iterable sequence of ``(name, Parameter)`` pairs (it is
        iterated twice) whose Parameters are the keys of ``self.state``.
        Confirm against utils.

        Args:
            grads_dict: parameter name (as in ``named_parameters()``) -> gradient.
            bnn: unused; kept for API compatibility.

        Returns:
            ``(original_named_parameters, precondition_grads_lst,
            updated_parameters_lst)``. Pass the first element to
            :meth:`inject_parameters` to restore the model.
        """
        original_named_parameters = make_functional(self.model)
        detached_named_parameters_dict = {
            name: param.clone().detach() for name, param in original_named_parameters
        }
        wd = self.param_groups[0]["weight_decay"]
        momentum = self.param_groups[0]["momentum"]

        updated_parameters_lst = []
        precondition_grads_lst = []
        for name, param in original_named_parameters:
            state = self.state[param]
            preconditioned_grad = state["preconditioner"].precondition_gradient(grads_dict[name].detach())
            modified_grad = preconditioned_grad

            if "w" in self.meta_objective:
                modified_grad = preconditioned_grad + wd * detached_named_parameters_dict[name]
            if "m" in self.meta_objective:
                modified_grad = state["momentum"] * momentum + modified_grad

            updated_param = detached_named_parameters_dict[name] - modified_grad
            set_attr(self.model, name.split("."), updated_param)
            updated_parameters_lst.append(updated_param)
            precondition_grads_lst.append(preconditioned_grad)
        return original_named_parameters, precondition_grads_lst, updated_parameters_lst

    def inject_parameters(self, params):
        """Write ``(name, tensor)`` pairs back onto the model via ``set_attr``."""
        for name, param in params:
            set_attr(self.model, name.split("."), param)

    def meta_loss(self, inputs, targets, criterion, grads):
        """Return ``criterion`` evaluated at the look-ahead parameters.

        The model's original parameters are restored before returning, even
        if the forward pass raises.

        Args:
            inputs: positional model inputs (unpacked with ``*inputs``).
            targets: second argument to ``criterion``.
            criterion: ``criterion(outputs, targets) -> loss``.
            grads: per-parameter gradients in ``named_parameters()`` order.
        """
        grads_dict = self._organize_grads_dict(grads)
        original_parameters, _, _ = self.inject_meta_parameters(grads_dict)
        try:
            meta_outputs = self.model(*inputs)
            return criterion(meta_outputs, targets)
        finally:
            self.inject_parameters(original_parameters)

    def set_meta_optim_lr(self, lr):
        """Set the learning rate of every per-parameter meta optimizer."""
        for _, state in self._iter_param_states():
            for param_group in state["optimizer"].param_groups:
                param_group["lr"] = lr

    def get_learning_rate(self):
        """Return the step size in effect for the first parameter.

        The first parameter of the first non-empty group is used. The result is
        that group's ``precond_lr`` (or ``lr`` when ``precond_lr == 0``) once
        the parameter's ``meta_step`` has reached ``warmup_step``, and ``lr``
        before that. Returns ``None`` when there are no parameters. Note that
        :meth:`step` decides this per parameter.
        """
        for group in self.param_groups:
            if not group["params"]:
                continue
            lr = group["lr"]
            precond_lr = lr if group["precond_lr"] == 0 else group["precond_lr"]
            meta_step = self.state[group["params"][0]]["meta_step"]
            return precond_lr if meta_step >= group["warmup_step"] else lr
        return None

    # ------------------------------------------------------------------ #
    # Meta (preconditioner) updates
    # ------------------------------------------------------------------ #
    @staticmethod
    def _to_float(value):
        """Convert a scalar loss term to a Python float (floats pass through)."""
        return value if isinstance(value, float) else value.item()

    def _fsp_proximity(self, fsp_inputs, fsp_outputs):
        """Compare the current model's outputs on ``fsp_inputs`` with ``fsp_outputs``.

        The model is evaluated at whatever parameters are currently injected.
        The reference ``fsp_outputs`` is detached.
        """
        meta_fsp_outputs = self.model(*fsp_inputs)
        reference = fsp_outputs.clone().detach()
        if self.fsp_fnc == "kl":
            return kd_prox(meta_fsp_outputs, reference)
        if self.fsp_fnc == "euc":
            return euc_prox(meta_fsp_outputs, reference)
        raise ValueError("Invalid function-space proximity function {}".format(self.fsp_fnc))

    def _run_meta_update(self, inputs, targets, criterion, fsp_outputs, fsp_inputs, grads):
        """Shared body of :meth:`meta_step` / :meth:`meta_step_eval`.

        Builds the meta-objective at the look-ahead parameters, back-propagates
        it into the preconditioners, and steps every meta optimizer. The
        original parameters are always restored.
        """
        if grads is None:
            grads_dict = self._gather_grads_dict()
        else:
            grads_dict = self._organize_grads_dict(grads)

        original_parameters_lst, precondition_grads_lst, updated_parameters_lst \
            = self.inject_meta_parameters(grads_dict)
        try:
            meta_outputs = self.model(*inputs)
            pure_meta_loss = criterion(meta_outputs, targets)

            # L2 penalty evaluated at the look-ahead parameters.
            wd = self.param_groups[0]["weight_decay"]
            meta_loss = pure_meta_loss + 0.5 * wd * sum(torch.sum(t ** 2.) for t in updated_parameters_lst)

            # Weight-space proximity: squared norm of the preconditioned grads.
            # start=0. keeps a float when there are no parameters.
            prox_wsp = sum((torch.sum(p_grad ** 2.) for p_grad in precondition_grads_lst), 0.)
            meta_loss = meta_loss + self.lamb_wsp * prox_wsp

            # Function-space proximity, only when reference data is given.
            if fsp_inputs is None or fsp_outputs is None:
                prox_fsp = 0.
            else:
                prox_fsp = self._fsp_proximity(fsp_inputs, fsp_outputs)
                meta_loss = meta_loss + self.lamb_fsp * prox_fsp

            self._zero_inner_grads()
            meta_loss.backward()
            for _, state in self._iter_param_states():
                state["optimizer"].step()
                state["meta_step"] += 1
        finally:
            self.inject_parameters(original_parameters_lst)

        # Drop every gradient produced here (model and inner optimizers) so
        # the caller starts step() from fresh gradients.
        self.zero_grad()
        self._zero_inner_grads()

        return {
            "meta_loss": pure_meta_loss.item(),
            "prox_wsp": self._to_float(prox_wsp),
            "prox_fsp": self._to_float(prox_fsp),
        }

    def meta_step(self, inputs, targets, criterion, fsp_outputs=None, fsp_inputs=None, grads=None):
        """Take one meta-optimizer step on every preconditioner.

        The meta-objective is ``criterion`` at the look-ahead parameters (see
        :meth:`inject_meta_parameters`), plus ``0.5 * weight_decay * ||theta'||^2``,
        plus ``lamb_wsp`` times the squared norm of the preconditioned
        gradients, plus (if ``fsp_inputs`` and ``fsp_outputs`` are both given)
        ``lamb_fsp`` times the chosen function-space proximity. All gradients
        are cleared afterwards, including the model's.

        Args:
            inputs: positional model inputs (unpacked with ``*inputs``).
            targets: second argument to ``criterion``.
            criterion: ``criterion(outputs, targets) -> loss``.
            fsp_outputs: reference outputs for the function-space term.
            fsp_inputs: positional model inputs for the function-space term.
            grads: gradients in ``named_parameters()`` order. When None, the
                model's current ``.grad`` values are used.

        Returns:
            dict with float entries ``"meta_loss"`` (the criterion value
            without the extra terms), ``"prox_wsp"`` and ``"prox_fsp"``.
        """
        return self._run_meta_update(inputs, targets, criterion, fsp_outputs, fsp_inputs, grads)

    def meta_step_eval(self, inputs, targets, criterion, fsp_outputs=None, fsp_inputs=None, grads=None):
        """Same as :meth:`meta_step`, but with the model in eval mode.

        The model's previous train/eval mode is restored afterwards, even on
        error.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            return self._run_meta_update(inputs, targets, criterion, fsp_outputs, fsp_inputs, grads)
        finally:
            self.model.train(was_training)

    # ------------------------------------------------------------------ #
    # Parameter update
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def step(self, closure=None):
        """Update every parameter that has a gradient.

        Before warm-up (``state["meta_step"] < warmup_step``), the parameter is
        updated by its initial optimizer. The preconditioned momentum buffer is
        still accumulated during this phase. After warm-up, with
        ``g = P(grad) + weight_decay * p``::

            buf = momentum * buf + g
            d   = g + momentum * buf   if nesterov else buf
            p  -= precond_lr * d

        With ``graft=True``, the step is rescaled to the norm of the step the
        initial optimizer would have taken.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The closure's loss, or None.

        Raises:
            RuntimeError: for sparse gradients.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            precond_lr = group["precond_lr"]
            if precond_lr == 0:
                # If precond_lr is not provided, simply use the current lr
                precond_lr = lr
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            warmup_step = group["warmup_step"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Sparse tensor not supported.")

                state = self.state[p]
                initial_optimizer = state["initial_optimizer"]
                preconditioner = state["preconditioner"]
                buf = state["momentum"]

                state["scratch"].data.copy_(p.data)

                # Preconditioned gradient plus weight decay. Out of place, so
                # p.grad is never modified even if the preconditioner returns
                # its input tensor (the initial optimizer reads p.grad).
                apo_grad = preconditioner.precondition_gradient(grad)
                if weight_decay != 0.0:
                    apo_grad = apo_grad.add(p.data, alpha=weight_decay)

                if state["meta_step"] < warmup_step:
                    # Warm-up: keep the APO momentum warm, but let the initial
                    # optimizer move the parameter.
                    buf.mul_(momentum).add_(apo_grad)
                    initial_optimizer.step()
                    state["step"] += 1
                    continue

                if self.graft:
                    # Measure the initial optimizer's step, then revert it.
                    initial_optimizer.step()
                    state["m_norm"] = torch.linalg.norm(p.data - state["scratch"].data)
                    p.data.copy_(state["scratch"].data)

                buf.mul_(momentum).add_(apo_grad)
                if nesterov:
                    # Computed out of place so the momentum buffer itself is
                    # not modified.
                    update = apo_grad.add(buf, alpha=momentum)
                else:
                    update = buf

                if self.graft:
                    # Direction of precond_lr * update, magnitude m_norm. The
                    # product below allocates a new tensor, leaving buf intact.
                    direction_norm = torch.linalg.norm(precond_lr * update)
                    rescale_factor = precond_lr * state["m_norm"] / (direction_norm + 1e-16)
                    p.data.sub_(update * rescale_factor)
                else:
                    p.data.add_(update, alpha=-precond_lr)

                state["step"] += 1
        self.global_step += 1
        return loss

    # ------------------------------------------------------------------ #
    # Serialization
    # ------------------------------------------------------------------ #
    def state_dict(self):
        """Return the optimizer state, with parameters replaced by order indices.

        Non-tensor state keys (e.g. ``"model"``) are kept as-is. The inner
        optimizers and preconditioners are stored as live objects. The result
        also carries ``global_step``.
        """
        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update({id(p): i for i, p in enumerate(group["params"], start_index)
                                   if id(p) not in param_mappings})
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            "state": packed_state,
            "param_groups": param_groups,
            "global_step": self.global_step,
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`.

        Per-parameter entries are re-keyed onto this optimizer's live
        parameters. Their tensors are cast to each parameter's device (and
        dtype, for floating-point parameters). Each inner ``initial_optimizer``
        is re-pointed at the live parameter, and its saved per-parameter state
        is carried over when there is any.

        Raises:
            ValueError: if the number or sizes of param groups do not match.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict["param_groups"]

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Saved order index -> live parameter.
        id_map = dict(zip(chain.from_iterable(g["params"] for g in saved_groups),
                          chain.from_iterable(g["params"] for g in groups)))

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point():
                    value = value.to(param.dtype)
                return value.to(param.device)
            if isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            if isinstance(value, (str, bytes)):
                # Iterable, but rebuilding it element-wise would corrupt it.
                return value
            if isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k not in id_map:
                state[k] = v
                continue
            param = id_map[k]
            state[param] = cast(param, v)

            # The deep-copied initial optimizer still refers to the deep-copied
            # parameter. Re-key its state onto the live one. The state may be
            # empty if that optimizer never stepped (e.g. warmup_step == 0).
            initial_optimizer = state[param]["initial_optimizer"]
            saved_inner_state = next(iter(initial_optimizer.state.values()), None)
            initial_optimizer.state = defaultdict(dict)
            if saved_inner_state is not None:
                initial_optimizer.state[param] = cast(param, saved_inner_state)
            initial_optimizer.param_groups[0]["params"] = [param]

        # Update parameter groups, setting their "params" value
        def update_group(group, new_group):
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})
        self.global_step = state_dict["global_step"]
