import torch
import logging
import numpy as np
import torch.nn as nn

from typing import Callable, List

from .pruning_handle_base import BasePruningHandle


# Public API of this module: the names exported by a star-import.
__all__ = [
    "SPDYProfileSeeker",
    "create_sparsity_levels"
]

    
# Module-level logger (named after this module) used to report search progress.
_LOGGER = logging.getLogger(__name__)

    
class SPDYProfileSeeker:
    """Search for a per-layer sparsity profile under a total budget.

    Every candidate profile is produced by scaling the per-layer errors with a
    vector of random coefficients, solving a knapsack-like dynamic program
    over quantized budgets (``solve``), applying the resulting sparsity levels
    through the pruning handles (``set_weights``) and measuring the model loss
    (``get_loss``). ``search`` keeps the profile with the lowest loss.
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader_builder: Callable,
        loss_fn: nn.Module,
        pruning_handles: "List[BasePruningHandle]",
        # SPDY DP solver params
        num_buckets: int = 10000,
        num_rand_inits: int = 100,
        resample_perc: float = 0.1,
        patience: int = 100
    ):
        """Store the evaluation components and the solver hyper-parameters.

        Args:
            model: Model evaluated for every candidate profile.
            data_loader_builder: Zero-argument callable returning an iterable
                of ``(input_args, input_kwargs, targets)`` triples.
            loss_fn: Called as ``loss_fn(outputs, targets)``; the result must
                support ``.item()``.
            pruning_handles: One handle per prunable layer. Each must expose
                ``_sparsity_levels`` and ``set(sparsity)``.
            num_buckets: Number of buckets the target budget is split into.
            num_rand_inits: Number of random coefficient vectors evaluated
                before the local search starts.
            resample_perc: Fraction of the layers whose coefficients are
                resampled in the first local search stage.
            patience: Maximum number of consecutive non-improving candidates
                tried before a local search stage gives up.
        """
        # model, data loader builder and loss_fn are needed for profile evaluation
        self._model = model
        self._data_loader_builder = data_loader_builder
        self._loss_fn = loss_fn
        self._pruning_handles = pruning_handles
        # SPDY solver params
        self._num_buckets = num_buckets
        self._num_rand_inits = num_rand_inits
        self._resample_perc = resample_perc
        self._patience = patience
        # errs and budgets needed for SPDY step (filled in by `search`)
        self._errs = None
        self._budgets = None

    def get_costs(self, coefs: np.ndarray) -> np.ndarray:
        """Scale every layer's errors by that layer's coefficient.

        Args:
            coefs: 1D array with one coefficient per row of ``self._errs``.

        Returns:
            Array shaped like ``self._errs`` where row ``i`` equals
            ``self._errs[i] * coefs[i]``.
        """
        # broadcast the per-layer coefficient over the sparsity-level axis
        return self._errs * np.asarray(coefs)[:, None]

    @torch.no_grad()
    def set_weights(self, solution: np.ndarray):
        """Apply the sparsity level chosen for each layer via its handle.

        Args:
            solution: Sequence where entry ``i`` is the index into
                ``_sparsity_levels`` of the ``i``-th pruning handle.
        """
        for layer_id, pruning_handle in enumerate(self._pruning_handles):
            # translate the level index into the sparsity value and apply it
            sparsity = pruning_handle._sparsity_levels[solution[layer_id]]
            pruning_handle.set(sparsity)

    @torch.no_grad()
    def get_loss(self):
        """Return the loss of the current model averaged over a fresh loader.

        Each batch loss is weighted by ``len(targets)``.
        NOTE(review): this presumably expects ``loss_fn`` to return a batch
        mean - confirm against the loss functions actually passed in.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        total_loss = 0
        num_collected = 0
        # create loader and iterate
        loader = self._data_loader_builder()
        for input_args, input_kwargs, targets in loader:
            outputs = self._model(*input_args, **input_kwargs)
            total_loss += len(targets) * self._loss_fn(outputs, targets).item()
            num_collected += len(targets)
        return total_loss / num_collected

    def evaluate(self, coefs: np.ndarray) -> tuple:
        """Build the profile induced by ``coefs`` and measure its loss.

        Returns:
            ``(solution, loss)`` where ``solution`` is the list of chosen
            level indices and ``loss`` the value returned by ``get_loss``.
        """
        # generate costs
        costs = self.get_costs(coefs)
        # find solution to DP problem
        solution = self.solve(costs)
        # construct model
        self.set_weights(solution)
        return solution, self.get_loss()

    def solve(self, costs: np.ndarray) -> list:
        """Pick one sparsity level per layer minimizing the summed cost.

        Dynamic program over quantized budgets: ``Ds[m][b]`` is the lowest
        cost of the first ``m + 1`` layers whose quantized budgets sum to
        exactly ``b``, and ``Ps[m][b]`` is the level chosen for layer ``m``
        in that optimum. Totals above ``self._num_buckets`` are not tracked,
        so they can never be selected.

        Args:
            costs: 2D array ``(layers, sparsity levels)``, laid out like
                ``self._budgets`` (quantized, assumed non-negative).

        Returns:
            List with the chosen level index for every layer.

        Raises:
            ValueError: If no combination of levels fits into the budget.
        """
        num_layers, sp_levels = len(costs), costs.shape[1]
        num_buckets = self._num_buckets
        Ds = np.full((num_layers, num_buckets + 1), float('inf'))
        Ps = np.full((num_layers, num_buckets + 1), -1)

        # first layer: every level that fits on its own seeds the table
        for sparsity in range(sp_levels):
            budget = self._budgets[0][sparsity]
            if budget > num_buckets:
                # would index past the table; this level alone is over budget
                continue
            if costs[0][sparsity] < Ds[0][budget]:
                Ds[0][budget] = costs[0][sparsity]
                Ps[0][budget] = sparsity

        for module_id in range(1, num_layers):
            for sparsity in range(sp_levels):
                budget = self._budgets[module_id][sparsity]
                if budget > num_buckets:
                    continue
                # a previous total t becomes t + budget; an explicit stop index
                # (instead of `[:-budget]`) also covers budget == 0
                candidate = (
                    Ds[module_id - 1][:num_buckets + 1 - budget]
                    + costs[module_id][sparsity]
                )
                # views into the tables, so masked writes update Ds / Ps
                cur_costs = Ds[module_id][budget:]
                cur_levels = Ps[module_id][budget:]
                better = candidate < cur_costs
                if better.any():
                    cur_costs[better] = candidate[better]
                    cur_levels[better] = sparsity

        if not np.isfinite(Ds[-1]).any():
            # without this check the backtracking below would follow the -1
            # placeholders in Ps and return a meaningless profile
            raise ValueError(
                'No feasible sparsity profile: no combination of levels '
                'fits into the available buckets.'
            )

        # cheapest reachable total, then walk the choices back to the first layer
        budget = int(np.argmin(Ds[-1]))
        solution = []
        for module_id in range(num_layers - 1, -1, -1):
            level = int(Ps[module_id][budget])
            solution.append(level)
            budget -= int(self._budgets[module_id][level])
        solution.reverse()

        return solution

    @staticmethod
    def _resample_coefs(coefs: np.ndarray, num_resamplings: int) -> np.ndarray:
        """Return a copy of ``coefs`` with randomly chosen entries redrawn."""
        resampled = coefs.copy()
        for _ in range(num_resamplings):
            # np.random.randint excludes the upper bound, so pass the full
            # length to make the last layer eligible as well
            resampled[np.random.randint(0, len(resampled))] = np.random.uniform(0, 1)
        return resampled

    # TODO add sparsity multiplier
    def search(self, errs: dict, budgets: dict, target_budget: float) -> List[int]:
        """Run random initialization followed by local search.

        The pruning handles are left in the state of the last evaluated
        candidate, which is not necessarily the returned profile.

        Args:
            errs: Mapping whose values are per-layer arrays of errors, one
                entry per sparsity level. Values are stacked in iteration
                order, which must match the order of the pruning handles.
            budgets: Mapping with the same layout holding the budget consumed
                by every (layer, sparsity level) pair.
            target_budget: Total budget; it is divided into
                ``num_buckets`` equally sized buckets.

        Returns:
            The level indices (one per layer) with the lowest observed loss,
            or ``None`` if no candidate ever improved on ``inf``.

        Raises:
            ValueError: If ``target_budget`` is not positive.
        """
        if target_budget <= 0:
            raise ValueError(f'target_budget must be positive, got {target_budget}')

        # stack dict of errs to 2d array
        self._errs = np.stack(list(errs.values()))
        # get size of one bucket
        bucket_size = target_budget / self._num_buckets
        # stack dict of budgets to 2d array and quantize to buckets
        # (astype(int) truncates towards zero)
        self._budgets = (np.stack(list(budgets.values())) / bucket_size).astype(int)

        num_layers = len(self._pruning_handles)
        num_evaluations = 0

        _LOGGER.info('Finding init.')
        # init values
        best_coefs = None
        best_score = float('inf')
        best_solution = None
        for _ in range(self._num_rand_inits):
            coefs = np.random.uniform(0, 1, size=num_layers)
            solution, score = self.evaluate(coefs)
            num_evaluations += 1
            _LOGGER.info(f'Evaluation {num_evaluations} {score:.4f} (best {best_score:.4f})')
            if score < best_score:
                best_score = score
                best_coefs = coefs
                best_solution = solution

        _LOGGER.info('Running local search.')
        # start with the largest perturbation and shrink it stage by stage
        for resamplings in range(int(self._resample_perc * num_layers), 0, -1):
            _LOGGER.info(f'Trying {resamplings} resamplings ...')
            improved = True
            while improved:
                improved = False
                for _ in range(self._patience):
                    coefs = self._resample_coefs(best_coefs, resamplings)
                    solution, score = self.evaluate(coefs)
                    num_evaluations += 1
                    _LOGGER.info(f'Evaluation {num_evaluations} {score:.4f} (best {best_score:.4f})')
                    if score < best_score:
                        best_score = score
                        best_coefs = coefs
                        best_solution = solution
                        improved = True
                        # restart the patience counter from the new best
                        break

        _LOGGER.info('SPDY search completed.')

        return best_solution


def create_sparsity_levels(
    level_inter_func: str,
    min_sparsity_level: float,
    max_sparsity_level: float,
    num_sparsity_levels: int,
) -> np.ndarray:
    """Create a grid of sparsity levels between two bounds.

    Args:
        level_inter_func: Interpolation scheme, one of:
            ``'exp'``    - the density ``1 - sparsity`` is spaced
                           geometrically between the two bounds;
            ``'linear'`` - sparsities are evenly spaced;
            ``'cubic'``  - a uniform grid on ``[0, 1]`` is passed through a
                           cube root before being rescaled to the bounds.
        min_sparsity_level: First level of the grid.
        max_sparsity_level: Last level of the grid.
        num_sparsity_levels: Number of levels to generate.

    Returns:
        1D array with ``num_sparsity_levels`` entries.

    Raises:
        ValueError: If ``level_inter_func`` is not a supported scheme.
    """
    if level_inter_func == 'exp':
        log_min = np.log2(1 - min_sparsity_level)
        log_max = np.log2(1 - max_sparsity_level)
        return 1 - np.logspace(log_min, log_max, num=num_sparsity_levels, base=2)
    if level_inter_func == 'linear':
        return np.linspace(min_sparsity_level, max_sparsity_level, num_sparsity_levels)
    if level_inter_func == 'cubic':
        span = max_sparsity_level - min_sparsity_level
        return min_sparsity_level + span * np.linspace(0, 1, num_sparsity_levels) ** (1 / 3)
    # previously an unknown name fell through to an UnboundLocalError
    raise ValueError(
        f"Unknown level_inter_func {level_inter_func!r}; "
        "expected one of 'exp', 'linear', 'cubic'."
    )
