from collections import defaultdict
import os
import json
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from .functional import vectorwise_quant, vectorwise_dequant, get_enable_fn_from_subconfig, \
    create_dynamic_map, create_pow_map, create_fp8_map, FP_EXPONENT_BIS_MAP
from . import training_stats
from .utils import get_metric_from_q_and_dq, get_metric_fn, get_rank
import functools


class Quantizer(object):
    """Quantize-then-dequantize ("q_and_dq") manager for training states.

    After :meth:`init`, keeps per-(parameter name, state name) bookkeeping in
    ``self.state`` and exposes callables that round-trip model parameters and
    optimizer states through ``vectorwise_quant`` / ``vectorwise_dequant``.
    Settings come from ``config.QUANT`` (one sub-section per state name,
    selected by ``_get_subconfig``).
    """

    def __init__(self):
        self.inited = False
        # dummy quantize fn, callable before init()
        self.q_and_dq_parameters_fn = lambda: None
        # set by init(); stays None when no optimizer is registered
        self.q_and_dq_optimizer_states_fn = None

    def is_inited(self):
        """Return True once :meth:`init` has completed."""
        return self.inited

    def set_config(self, config):
        """Store ``config`` to be used by a later :meth:`init` call."""
        self.config = config

    def init(self, model, optimizer, config=None):
        """Bind to ``model``/``optimizer`` and build per-state bookkeeping.

        Args:
            model: module whose ``named_parameters()`` define the tracked names.
                Every parameter also receives a ``p_name`` attribute (read by
                the DDP gradient hook).
            optimizer: optimizer whose states are handled by
                :meth:`q_and_dq_optimizer_fn`, or None.
            config: optional config; replaces one stored via :meth:`set_config`.
        """
        if config is not None:
            self.config = config
        # init state
        self.state_name_list = self.config.QUANT.INIT_STATES
        named_parameters = list(model.named_parameters())
        self.p_name_list = [name for name, _ in named_parameters]
        self.state = defaultdict(dict)
        for p_name in self.p_name_list:
            for state_name in self.state_name_list:
                self.state[p_name][state_name] = dict(qx=None, metadata=None, step=0, running_config=dict())

        # init qconfig_fn and the lazily-filled lookup-table cache
        self.qconfig_fn = self._init_qconfig_fn(self.config)
        self.qmap = {}

        # assign name to each parameter, then build the q_and_dq interfaces
        for name, p in named_parameters:
            p.p_name = name
        self.q_and_dq_parameters_fn = functools.partial(q_and_dq_parameters, model=model)
        if optimizer is not None:
            self.q_and_dq_optimizer_states_fn = functools.partial(q_and_dq_optimizer_states, model=model, optimizer=optimizer)
        else:
            self.q_and_dq_optimizer_states_fn = None
        # Distributed checkers / DDP gradient-hook registration are not wired up here.
        self.dist_quant_consistency_checker_fn = None
        self.dist_sample_diversity_checker = None
        # mark as initialized only after everything above succeeded
        self.inited = True

    def _init_qconfig_fn(self, config):
        """Build ``qconfig_fn(p_name, state_name, x) -> (enable, metadata)``.

        ``enable`` is False for (p_name, state_name) pairs not registered in
        :meth:`init`; otherwise it comes from ``get_enable_fn_from_subconfig``.
        ``metadata`` is the keyword-argument dict later passed to
        ``vectorwise_quant``.
        """

        def enable_fn(p_name, state_name, x):
            if not self._sanity_check(p_name, state_name):
                return False
            subconfig = _get_subconfig(config, state_name)
            fn = get_enable_fn_from_subconfig(subconfig=subconfig)
            return fn(p_name, state_name, x)

        def metadata_fn(p_name, state_name, x):
            subconfig = _get_subconfig(config, state_name)
            md = dict(
                b=subconfig.BITS,
                scale_type=subconfig.SCALE_TYPE.DEFAULT,
                quant_type=subconfig.QUANT_TYPE.DEFAULT,
                round_type=subconfig.ROUND_TYPE,
                transform=subconfig.TRANSFORM.DEFAULT,
                gp_sz=subconfig.GROUP_SIZE,
                signed=subconfig.SIGNED,
                truncated_mode=subconfig.TRUNCATED_MODE,
                truncated_factor=subconfig.TRUNCATED_FACTOR,
                truncated_global_factor=subconfig.TRUNCATED_GLOBAL_FACTOR,
                fp16_scale=subconfig.FP16_SCALE,
            )

            # Per-parameter overrides, consulted only when an option is not DEFAULT_ONLY.
            if not subconfig.SCALE_TYPE.DEFAULT_ONLY and p_name in subconfig.SCALE_TYPE:
                md['scale_type'] = subconfig.SCALE_TYPE[p_name]
            if not subconfig.QUANT_TYPE.DEFAULT_ONLY and p_name in subconfig.QUANT_TYPE:
                md['quant_type'] = subconfig.QUANT_TYPE[p_name]
            if not subconfig.TRANSFORM.DEFAULT_ONLY and p_name in subconfig.TRANSFORM:
                md['transform'] = subconfig.TRANSFORM[p_name]

            # Build the table for the quant type actually in effect (after overrides).
            # The whole cache is passed; presumably vectorwise_quant indexes it by
            # (b, signed) and quant_type -- verify against functional.py.
            self._ensure_qmap(subconfig, md['quant_type'], x.device)
            md['qmap'] = self.qmap
            md['shape'] = x.shape

            # group size experiments: GROUP_SIZE == -1 means numel // sum(shape)
            if md['scale_type'] == 'group' and md['gp_sz'] == -1:
                num_groups = max(sum(x.shape), 1)  # 0-dim tensors have an empty shape
                md['gp_sz'] = x.numel() // num_groups
            return md

        def qconfig_fn(p_name, state_name, x):
            return enable_fn(p_name, state_name, x), metadata_fn(p_name, state_name, x)
        return qconfig_fn

    def _ensure_qmap(self, subconfig, quant_type, device):
        """Lazily build and cache the lookup table for ``quant_type``.

        Tables live in ``self.qmap[(BITS, SIGNED)][quant_type]`` on ``device``.
        Quant types not listed below get no table; only their
        ``(BITS, SIGNED)`` entry is created.
        """
        bits, signed = subconfig.BITS, subconfig.SIGNED
        tables = self.qmap.setdefault((bits, signed), {})
        if quant_type in tables:
            return
        power_exponents = {'power-1': 1, 'power-2': 2, 'power-3': 3}
        if quant_type == 'nonlinear':
            table = create_dynamic_map(signed, bits - 1, subconfig.BIS)
        elif quant_type in power_exponents:
            table = create_pow_map(bits, signed, power_exponents[quant_type])
        elif quant_type == 'float-point':
            table = create_fp8_map(signed, FP_EXPONENT_BIS_MAP[bits], bits)
        else:
            return
        tables[quant_type] = table.to(device)

    def _sanity_check(self, p_name, state_name):
        """Return True if (p_name, state_name) was registered by :meth:`init`."""
        return p_name in self.state and state_name in self.state[p_name]

    def set_manual_qconfig_func(self, enable_func, metadata_func):
        """Wrap the current ``qconfig_fn`` with optional manual overrides.

        Args:
            enable_func: ``(p_name, state_name, x) -> bool`` replacing the enable
                decision, or None to keep the existing one.
            metadata_func: ``(p_name, state_name, x) -> dict`` replacing the
                metadata, or None to keep the existing one.
        """
        old_qconfig_fn = self.qconfig_fn

        def qconfig_fn(p_name, state_name, x):
            old_enable, old_metadata = old_qconfig_fn(p_name, state_name, x)
            enable = enable_func(p_name, state_name, x) if enable_func is not None else old_enable
            metadata = metadata_func(p_name, state_name, x) if metadata_func is not None else old_metadata
            return enable, metadata
        self.qconfig_fn = qconfig_fn

    def q_and_dq_optimizer_fn(self, optim=None):
        """Quantize-and-dequantize the optimizer states used by ``optim``.

        Args:
            optim: 'sgd', 'adam' or 'adamw'; defaults to
                ``config.TRAIN.OPTIMIZER_NAME``.

        Raises:
            NotImplementedError: for any other optimizer name.
            RuntimeError: if :meth:`init` registered no optimizer.
        """
        if optim is None:
            optim = self.config.TRAIN.OPTIMIZER_NAME
        # (quantizer state name, key inside optimizer.state[p])
        if optim == 'sgd':
            state_pairs = [('exp_avg', 'momentum_buffer')]
        elif optim in ['adam', 'adamw']:
            state_pairs = [('exp_avg', 'exp_avg'), ('exp_avg_sq', 'exp_avg_sq')]
        else:
            raise NotImplementedError
        if self.q_and_dq_optimizer_states_fn is None:
            raise RuntimeError("Quantizer.init() was called without an optimizer; cannot quantize optimizer states.")
        for state_name, optimizer_state_name in state_pairs:
            self.q_and_dq_optimizer_states_fn(optimizer_name=optim, state_name=state_name, optimizer_state_name=optimizer_state_name)

    def _adaptive_quant_criterion(self, p_name, state_name, x):
        """Return True when an enabled state is due for a config re-search.

        A re-search happens every ``config.QUANT.SELF_IMPROVE_STEPS`` quantize
        steps (including step 0); a non-positive value disables it.
        """
        state = self.state[p_name][state_name]
        enable, _ = self.qconfig_fn(p_name, state_name, x)
        if not enable:
            return False
        steps = self.config.QUANT.SELF_IMPROVE_STEPS
        return steps > 0 and state['step'] % steps == 0

    def _adaptive_quant_config_update(self, p_name, state_name, x, op):
        """Pick the best (scale_type, quant_type) candidate and store it in ``running_config``.

        Every pair from ``config.QUANT.SCALE_TYPE_CANDIDATES`` x
        ``config.QUANT.QUANT_TYPE_CANDIDATES`` is scored by
        ``get_metric_from_q_and_dq``. The lowest metric wins, and on ties the
        first pair wins. On rank 0 the chosen config is appended to
        ``<config.OUTPUT>/quant_type_dynamics.jsonl``.

        NOTE: assumes (p_name, state_name, x) is enabled.
        """
        state = self.state[p_name][state_name]
        subconfig = _get_subconfig(self.config, state_name)
        _, base_metadata = self.qconfig_fn(p_name, state_name, x)

        # search
        optimal_metric = None
        optimal_cfg = None
        for scale_type in self.config.QUANT.SCALE_TYPE_CANDIDATES:
            for quant_type in self.config.QUANT.QUANT_TYPE_CANDIDATES:
                cfg = dict(scale_type=scale_type, quant_type=quant_type)
                # a candidate may name a quant type whose table was never built
                self._ensure_qmap(subconfig, quant_type, x.device)
                metric = get_metric_from_q_and_dq(
                    x,
                    op,
                    average=self.config.QUANT.SELF_IMPROVE_AVG,
                    **dict(base_metadata, **cfg),
                )
                if optimal_cfg is None or metric < optimal_metric:
                    optimal_metric = metric
                    optimal_cfg = cfg

        if optimal_cfg is None:
            # no candidates configured: keep the current running_config
            return

        # update
        state['running_config'].update(optimal_cfg)
        if get_rank() == 0:
            with open(os.path.join(self.config.OUTPUT, 'quant_type_dynamics.jsonl'), 'at') as f:
                dumped = dict(p_name=p_name, state_name=state_name, step=state['step'])
                dumped.update(state['running_config'])
                f.write(json.dumps(dumped) + '\n')
                f.flush()

    @torch.no_grad()
    def quantize(self, p_name, state_name, x):
        """Quantize ``x`` and stash the result for a later :meth:`dequantize`.

        This is a no-op when quantization is disabled for (p_name, state_name).
        Otherwise it increments the state's ``step``, applies its
        ``running_config`` overrides, and stores ``qx``/``metadata``.
        """
        enable, metadata = self.qconfig_fn(p_name, state_name, x)
        if not enable:
            return
        state = self.state[p_name][state_name]
        state['step'] += 1
        metadata.update(state['running_config'])
        # running_config may switch to a quant type whose table is not built yet
        self._ensure_qmap(_get_subconfig(self.config, state_name), metadata['quant_type'], x.device)

        qx, metadata = vectorwise_quant(x, **metadata)
        state['qx'] = qx
        state['metadata'] = metadata

        # lpmm-stat: truncated_rate
        if self.config.QUANT.DEBUG.TRUNCATED_RATE_STAT_ITER and metadata.get('truncated_rate', None) is not None:
            training_stats.report(f"Quant/truncated_rate/{state_name}/{p_name}", metadata['truncated_rate'])

    @torch.no_grad()
    def dequantize(self, p_name, state_name):
        """Return the tensor stashed by :meth:`quantize`, dequantized, or None.

        Returns None when (p_name, state_name) is unregistered or nothing has
        been quantized since the last call. The stash is cleared after use.
        """
        if not self._sanity_check(p_name, state_name):
            return None
        state = self.state[p_name][state_name]
        metadata = state['metadata']
        if metadata is None:
            return None
        x = vectorwise_dequant(state['qx'], **metadata)
        state['metadata'] = state['qx'] = None
        return x


# Process-wide singleton used by the module-level q_and_dq_* helpers and the
# DDP gradient hook; quantizer.init(...) must run before they do real work
# (q_and_dq_parameters silently returns, the others assert).
quantizer = Quantizer() # global quantizer


def q_and_dq_parameters(model):
    """Round-trip every parameter of ``model`` through the global quantizer.

    Each parameter's ``.data`` is quantized under the ``'param'`` state and,
    when the quantizer hands back a dequantized tensor, replaced by it.
    This is a no-op until the global quantizer has been initialized.
    """
    if quantizer.is_inited():
        for name, param in model.named_parameters():
            quantizer.quantize(name, 'param', param.data)
            restored = quantizer.dequantize(name, 'param')
            # None means quantization is disabled for this parameter
            if restored is not None:
                param.data = restored


def q_and_dq_optimizer_states(model, optimizer, optimizer_name, state_name, optimizer_state_name):
    """Quantize-and-dequantize one optimizer state of every parameter in place.

    For each parameter of ``model`` that has ``optimizer.state[p][optimizer_state_name]``:
    when the quantizer's adaptive criterion fires, the quantization config is
    re-searched first. The tensor is then quantized under ``state_name`` via the
    global ``quantizer``, and the dequantized result (if any) is written back
    into the optimizer state.

    Args:
        model: module whose ``named_parameters()`` provide the state names.
        optimizer: optimizer whose ``state`` dict is read and updated.
        optimizer_name: 'sgd', 'adam' or 'adamw'; selects the ``op`` passed to
            the adaptive config search.
        state_name: quantizer state name ('exp_avg' or 'exp_avg_sq').
        optimizer_state_name: key of the tensor inside ``optimizer.state[p]``.
    """
    assert quantizer.is_inited()

    def adam(exp_avg, exp_avg_sq, eps=1e-8):
        # exp_avg / sqrt(exp_avg_sq + eps)
        return exp_avg_sq.add(eps).rsqrt() * exp_avg

    def sgd(exp_avg):
        return exp_avg

    for p_name, p in model.named_parameters():
        state = optimizer.state[p]
        if optimizer_state_name not in state:
            continue
        x = state[optimizer_state_name]

        if quantizer._adaptive_quant_criterion(p_name, state_name, x):
            # NOTE(review): assumes get_metric_from_q_and_dq calls op with one
            # positional tensor -- confirm in utils.
            op = None
            if optimizer_name in ['adam', 'adamw']:
                if state_name == 'exp_avg':
                    # op(x) -> adam(x, exp_avg_sq=<current second moment>)
                    op = functools.partial(adam, exp_avg_sq=state['exp_avg_sq'])
                elif state_name == 'exp_avg_sq':
                    # op(x) -> adam(<current first moment>, x). exp_avg must be bound
                    # positionally: a keyword binding makes op(x) raise
                    # "got multiple values for argument 'exp_avg'".
                    op = functools.partial(adam, state['exp_avg'])
            elif optimizer_name == 'sgd':
                op = sgd
            quantizer._adaptive_quant_config_update(p_name, state_name, x, op)

        quantizer.quantize(p_name, state_name, x)

        # optional: do some statistics here

        state_hat = quantizer.dequantize(p_name, state_name)
        if state_hat is not None:
            state[optimizer_state_name] = state_hat


def q_and_dq_gradients_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """DDP communication hook that quantizes gradients before averaging them.

    Each gradient in ``bucket`` is quantized-and-dequantized under the
    ``'grad'`` state (keyed by the ``p_name`` attribute set in
    ``Quantizer.init``). The results are flattened and concatenated, divided
    by the group size, and all-reduced asynchronously.

    Returns:
        A future resolving to the averaged, flattened gradient tensor.
    """
    assert quantizer.is_inited()

    group = dist.group.WORLD if process_group is None else process_group

    params = bucket.parameters()
    grads = bucket.gradients()
    pieces = []
    for param, grad in zip(params, grads):
        quantizer.quantize(param.p_name, 'grad', grad)
        grad_hat = quantizer.dequantize(param.p_name, 'grad')
        # fall back to the raw gradient when quantization is disabled
        source = grad if grad_hat is None else grad_hat
        pieces.append(source.flatten())

    buffer = torch.cat(pieces, dim=0)
    buffer.div_(group.size())

    work = dist.all_reduce(buffer, group=group, async_op=True)
    return work.get_future().then(lambda fut: fut.value()[0])


@torch.no_grad()
def dist_quant_consistency_checker(model, optimizer, config):
    """Print, per quantized tensor, whether this rank's copy matches rank 0's.

    Checks parameters when ``config.QUANT.P.ENABLE`` is set and the optimizer's
    ``exp_avg`` states when ``config.QUANT.M.ENABLE`` is set. Tensors for which
    the quantizer is disabled are skipped. Each check is an ``all_gather``
    over the WORLD group, so all ranks are expected to call this together.
    """
    assert quantizer.is_inited()

    if config.QUANT.P.ENABLE:
        state_name = 'param'
        for p_name, p in model.named_parameters():
            _check_cross_rank_consistency(p_name, state_name, p.data)
    if config.QUANT.M.ENABLE:
        state_name = 'exp_avg'
        for p_name, p in model.named_parameters():
            state = optimizer.state[p]
            if state_name in state:
                _check_cross_rank_consistency(p_name, state_name, state[state_name])


def _check_cross_rank_consistency(p_name, state_name, tensor):
    """All-gather ``tensor`` over WORLD and print whether it equals rank 0's copy.

    Returns early (no collective) when quantization is disabled for the tensor.
    """
    enable, _ = quantizer.qconfig_fn(p_name, state_name, tensor)
    if not enable:
        return
    group_to_use = dist.group.WORLD
    tensor_list = [torch.zeros_like(tensor) for _ in range(group_to_use.size())]
    handle = dist.all_gather(tensor_list, tensor, async_op=True)
    # Wait ensures the operation is enqueued, but not necessarily complete.
    handle.wait()
    if not (tensor_list[0] == tensor).all().item():
        print(f"{p_name}-{state_name}: tensor[0:5]={tensor.flatten()[:5]} @rank {dist.get_rank()}, tensor[0:5]={tensor_list[0].flatten()[:5]} @rank 0.")
    else:
        print(f"{p_name}-{state_name} across devices are consistent")
        

@torch.no_grad()
def dist_sample_diversity_checker(batch):
    """Warn when this rank's batch is identical to another rank's batch.

    All-gathers ``batch`` over the WORLD group and compares this rank's copy
    with every other rank's copy. The slot at this rank's own index is skipped
    because it always holds this rank's batch; comparing against it would
    report a spurious "no diversity" warning.
    """
    group_to_use = dist.group.WORLD
    tensor = batch
    tensor_list = [torch.zeros_like(tensor) for _ in range(group_to_use.size())]
    handle = dist.all_gather(tensor_list, tensor, async_op=True)
    # Wait ensures the operation is enqueued, but not necessarily complete.
    handle.wait()
    my_rank = dist.get_rank()
    duplicated = any(
        torch.equal(other, tensor)
        for rank, other in enumerate(tensor_list)
        if rank != my_rank
    )
    if duplicated:
        print(f"Warning: samples across devices have no diversity @rank {my_rank}.")
    else:
        print("Sample diversity guaranteed.")


def _get_subconfig(config, state_name):
    if state_name == 'param':
        return config.QUANT.P
    elif state_name == 'grad':
        return config.QUANT.G
    elif state_name == 'exp_avg':
        return config.QUANT.M
    elif state_name == 'exp_avg_sq':
        return config.QUANT.SQM
    else:
        raise NotImplementedError