import heapq
import numpy as np
import tensorflow as tf
from sklearn.mixture import GaussianMixture

from mayo.util import memoize_property, object_from_params, Percent
from mayo.override import util
from mayo.override.base import OverriderBase, Parameter
from mayo.override.quantize.base import QuantizedParameter
from mayo.log import log


# Seed NumPy's global random generator with a fixed value at import time, so
# that draws from `np.random` made by this module (e.g. the uniform sampling
# in `BimodalGaussian.predict1`) start from a known state.
# NOTE(review): this is a process-wide side effect of importing the module.
np.random.seed(seed=66666)


class Recentralizer(OverriderBase):
    """Recentralizes the distribution of pruned weights.

    Non-zero values are split into a "positive" and a "negative" group by
    two boolean masks.  Each group is normalized with its own mean and
    standard deviation, the sum of both groups is passed through the
    optional quantizer, and the normalization is then undone (optionally
    with an additional scale and offset per group).
    """
    # boolean masks selecting the members of the positive / negative group
    positives = Parameter('positives', None, None, 'bool')
    negatives = Parameter('negatives', None, None, 'bool')
    # statistics and corrections; along axis 0 each of them stacks the
    # (total, positive, negative) entries, see `_value_select`
    mean = QuantizedParameter('mean', None, None, 'float')
    std = QuantizedParameter('std', None, None, 'float')
    scale = QuantizedParameter('scale', None, None, 'float', trainable=True)
    offset = QuantizedParameter('offset', None, None, 'float', trainable=True)

    # modes: -1: auto, 0: disabled, 1: single, 2: dual
    mean_mode = Parameter('mean_mode', 2, [], 'int')
    std_mode = Parameter('std_mode', 0, [], 'int')
    scale_mode = Parameter('scale_mode', 0, [], 'int')
    offset_mode = Parameter('offset_mode', 0, [], 'int')

    def __init__(
            self, session, quantizer, parameter_quantizers=None,
            modes=None, regularization=0.0, epsilon=0.0,
            should_update=True, enable=True, huffman=True):
        """
        Arguments::
            session: forwarded to `OverriderBase` and to every quantizer
                instantiated here.
            quantizer: parameters resolved by `object_from_params` into the
                quantizer applied to the recentralized values; a falsy value
                disables quantization.
            parameter_quantizers (dict): a mapping from a key to quantizer
                parameters, each resolved by `object_from_params`.
            modes (dict): optional entries `channel`, `mean`, `std`,
                `offset` and `scale`, see `_setup_modes`.
            regularization (float): weight of the quantization error loss;
                no loss is added unless it is positive.
            epsilon (float): constant added to the standard deviations.
            should_update, enable: forwarded to `OverriderBase`.
            huffman (bool): whether `_info` reports the Huffman bitwidth.
        """
        super().__init__(session, should_update, enable)
        if quantizer:
            cls, params = object_from_params(quantizer)
            self.quantizer = cls(session, **params)
        else:
            self.quantizer = None
        self._setup_modes(modes)
        self.regularization = float(regularization)
        self.epsilon = epsilon
        quantizers = {}
        for key, each in (parameter_quantizers or {}).items():
            cls, params = object_from_params(each)
            quantizers[key] = cls(session, **params)
        self._parameter_quantizers = quantizers
        self._mixture = None
        self.huffman = huffman

    def _setup_modes(self, modes):
        """Read the mode configuration, falling back to the defaults."""
        modes = modes or {}
        self.channel_modes = modes.get('channel', [])
        self.mean_mode = modes.get('mean', 2)
        self.std_mode = modes.get('std', 0)
        self.offset_mode = modes.get('offset', 0)
        self.scale_mode = modes.get('scale', 0)

    @memoize_property
    def nonzeros(self):
        """Mask of the values that belong to either group."""
        return util.logical_or(self.positives, self.negatives)

    def assign_parameters(self):
        """Assign own parameters, then those of every internal quantizer."""
        super().assign_parameters()
        if self.quantizer:
            self.quantizer.assign_parameters()
        for quantizer in self._parameter_quantizers.values():
            quantizer.assign_parameters()

    @staticmethod
    def _value_select(mode, value, pdisable, ndisable):
        """Pick the (positive, negative) pair of a statistic for a mode.

        `value` is unstacked along axis 0 into (total, positive, negative).
        Mode 1 yields the total entry for both groups, mode 2 yields the
        entry of each group, and any other mode yields `pdisable` and
        `ndisable`.
        """
        def select(x0, x1, x2):
            is1, is2 = tf.equal(mode, 1), tf.equal(mode, 2)
            # the branch functions are evaluated while `tf.cond` builds the
            # graph, i.e. before the names are rebound
            x1 = tf.cond(is1, lambda: x1, lambda: x0)
            x2 = tf.cond(is2, lambda: x2, lambda: x0)
            return tf.cond(is1, lambda: x1, lambda: x2)
        total, positive, negative = tf.unstack(value, axis=0)
        positive = select(pdisable, total, positive)
        negative = select(ndisable, total, negative)
        return positive, negative

    def _recentralize(self, value, pmask, nmask):
        """Normalize each group with its own mean and std.

        Returns the masked (positive, negative) pair; the negative part is
        negated when the quantizer has a truthy `asymmetry` attribute.
        """
        pvalue = pmask * ((value - self.pmean) / (self.pstd + self.epsilon))
        nvalue = nmask * ((value - self.nmean) / (self.nstd + self.epsilon))
        # the quantizer may be None or may not define `asymmetry` at all,
        # both cases are treated as symmetric
        if getattr(self.quantizer, 'asymmetry', False):
            nvalue = -nvalue
        return pvalue, nvalue

    def _quantize(self, positives, negatives):
        """Quantize the sum of both groups, or pass it through unchanged
        when no quantizer is configured."""
        value = positives + negatives
        quantizer = self.quantizer
        if quantizer is None:
            return value
        scope = '{}/{}'.format(self._scope, self.__class__.__name__)
        return quantizer.apply(self.node, scope, self._original_getter, value)

    def _derecentralize(self, value, pmask, nmask):
        """Undo `_recentralize` after applying the scale and offset of
        each group, and merge both groups into one tensor."""
        pvalue = value * self.pscale + self.poffset
        # must mirror the sign convention used in `_recentralize`
        asymmetry = getattr(self.quantizer, 'asymmetry', False)
        nvalue = (-value if asymmetry else value) * self.nscale + self.noffset
        pvalue = pmask * ((self.pstd + self.epsilon) * pvalue + self.pmean)
        nvalue = nmask * ((self.nstd + self.epsilon) * nvalue + self.nmean)
        return pvalue + nvalue

    def _apply(self, value):
        """Build the recentralize, quantize and derecentralize pipeline for
        `value` and return the resulting tensor."""
        # dynamic parameter configuration
        mask_parameter = {
            'initial': tf.ones_initializer(tf.bool),
            'shape': value.shape,
        }
        def channel_wise(name, init):
            # statistics listed in `channel_modes` get one entry per element
            # of the last axis, the others a single entry; the leading 3
            # holds the (total, positive, negative) entries
            if name in self.channel_modes:
                shape = [3] + [1] * (value.shape.ndims - 1) + [value.shape[-1]]
            else:
                shape = [3]
            init_map = {
                0: tf.zeros_initializer(),
                1: tf.ones_initializer(),
            }
            return {'shape': shape, 'initial': init_map[init]}
        # NOTE(review): presumably consumed by the Parameter descriptors
        # when the variables below are first accessed; confirm in the base.
        self._parameter_config = {
            'positives': mask_parameter,
            'negatives': mask_parameter,
            'mean': channel_wise('mean', 0),
            'std': channel_wise('std', 1),
            'offset': channel_wise('offset', 0),
            'scale': channel_wise('scale', 1),
        }
        # initialize variables
        self.pmean, self.nmean = self._value_select(
            self.mean_mode, self.mean, 0.0, 0.0)
        self.pstd, self.nstd = self._value_select(
            self.std_mode, self.std, 1.0, 1.0)
        self.pscale, self.nscale = self._value_select(
            self.scale_mode, self.scale, 1.0, 1.0)
        self.poffset, self.noffset = self._value_select(
            self.offset_mode, self.offset, 0.0, 0.0)
        # forward pass
        pmask = util.cast(self.positives, float)
        nmask = util.cast(self.negatives, float)
        zmask = util.cast(self.nonzeros, float)
        positives, negatives = self._recentralized = \
            self._recentralize(value, pmask, nmask)
        quantized = self._quantized = \
            zmask * self._quantize(positives, negatives)
        quantized_value = zmask * self._derecentralize(quantized, pmask, nmask)
        # regularize
        self._quantization_loss_regularizer(value, quantized_value)
        return quantized_value

    def _quantization_loss_regularizer(self, value, quantized_value):
        """Add the weighted L1 quantization error to the regularization
        losses collection; does nothing unless the weight is positive."""
        if self.regularization <= 0.0:
            return
        loss = tf.reduce_sum(tf.abs(value - quantized_value))
        loss *= self.regularization
        loss_name = tf.GraphKeys.REGULARIZATION_LOSSES
        tf.add_to_collection(loss_name, loss)

    def _update_values(self, mean, std):
        """Store the stacked moments, reset offset and scale to 0 and 1,
        then update every internal quantizer.

        Arguments::
            mean, std: sequences ordered as (total, positive, negative).
        """
        # assign mean and std
        mean = np.stack(mean)
        std = np.stack(std)
        self.mean = mean
        self.std = std
        self.offset = np.zeros((3, ))
        self.scale = np.ones((3, ))
        # update internal quantizer
        if self.quantizer:
            self.quantizer.update()
        for each in self._parameter_quantizers.values():
            each.update()

    def _update(self):
        """Recompute the group masks and moments from the current value."""
        # update positives mask and mean values
        value = self.session.run(self.before)
        channel_axes = list(range(value.ndim - 1))
        axes = channel_axes if 'moments' in self.channel_modes else None
        # total moments excluding zeros
        tmean, var = util.moments(value[util.where(value != 0)], axes=axes)
        tstd = np.sqrt(var)
        # find positive moments, note we are using 0 instead of mean
        center = 0  # 0 or tmean
        positives = value > center
        self.positives = positives
        self.negatives = value < center
        pmean, pvar = util.moments(value[util.where(positives)], axes=axes)
        pstd = np.sqrt(pvar)
        # negative moments
        negatives = util.logical_and(util.logical_not(positives), value != 0)
        nmean, nvar = util.moments(value[util.where(negatives)], axes=axes)
        nstd = np.sqrt(nvar)
        # assign values
        self._update_values([tmean, pmean, nmean], [tstd, pstd, nstd])

    def _info(self):
        """Collect statistics: the quantizer info (if any), the fraction of
        positive, negative and zero values, the modes, and optionally the
        Huffman bitwidth of the overridden value."""
        if self.quantizer:
            info = self.quantizer.info()._asdict()
            info.pop('name')
        else:
            info = {}
        to_run = {
            'positives': self.positives,
            'negatives': self.negatives,
            # fetched as the non-zero mask, converted to a zero ratio below
            'zeros': self.nonzeros,
            'quantized': self.after,
            'modes': [
                self.mean_mode, self.std_mode,
                self.scale_mode, self.offset_mode],
            # 'scales': [self.pscale, self.nscale],
            # 'offsets': [self.poffset, self.noffset],
        }
        results = self.session.run(to_run)
        for k in ['positives', 'negatives']:
            results[k] = Percent(np.sum(results[k]) / results[k].size)
        nonzeros = results['zeros']
        results['zeros'] = Percent(1 - np.sum(nonzeros) / nonzeros.size)
        quantized = results.pop('quantized')
        if self.huffman:
            huffman_encoder = HuffmanCoding()
            huffman_encoder.encode(quantized)
            results['huffman_bitwidth'] = huffman_encoder.equ_bitwdith
            results['params_'] = int(np.size(quantized))
        info.update(results)
        return self._info_tuple(**info)

    def estimate(self, layer_info, info):
        """ Override this method to modify layer estimation statistics.  """
        return layer_info


class BimodalGaussian(GaussianMixture):
    """A two-component Gaussian mixture fitted to a set of values.

    After construction the component with the larger mean is exposed as the
    "positive" one (`pmean`, `pstd`, weight `phi`) and the other one as the
    "negative" one (`nmean`, `nstd`, weight `1 - phi`).
    """
    def __init__(self, name, data, overflow_rate=0.0):
        """
        Arguments::
            name (str): used to derive the file name in `plot`.
            data (np.ndarray): the values to fit.
            overflow_rate (float): the fraction of largest magnitude values
                discarded before fitting, see `_overflow`.
        """
        data, bound = self._overflow(data, overflow_rate)
        pmean, _, nmean, _, phi = self._find_initial(data)
        means = np.array([[pmean], [nmean]])
        weights = np.array([phi, 1 - phi])
        super().__init__(2, means_init=means, weights_init=weights)
        self.name = name
        self.bound = bound
        self.data = data
        # index of the fitted component treated as positive, set by `_fit`
        self._positive_index = 0
        self._fit()

    @staticmethod
    def _overflow(data, orate):
        """Discard large magnitude values.

        Returns the kept values and the magnitude bound.  With a
        non-positive `orate` nothing is discarded and the bound is the
        largest magnitude; otherwise only values with a magnitude strictly
        below the bound are kept.
        """
        abs_data = np.abs(data)
        if orate <= 0:
            return data, np.max(abs_data)
        magnitudes = np.sort(abs_data)
        index = int((1 - orate) * data.size)
        # clamp the index so that extreme rates remain valid
        max_value = magnitudes[min(max(0, index), data.size - 1)]
        return data[abs_data < max_value], max_value

    @staticmethod
    def _find_initial(x):
        """Estimate initial moments by splitting `x` at zero.

        Returns (pmean, pstd, nmean, nstd, phi), where `phi` is the
        fraction of strictly positive values in `x`.
        """
        p = x > 0
        pv = x[np.where(p)]
        pmean, pstd = np.mean(pv), np.std(pv)
        n = np.logical_and(np.logical_not(p), x != 0)
        nv = x[np.where(n)]
        nmean, nstd = np.mean(nv), np.std(nv)
        phi = np.size(pv) / np.size(x)
        return pmean, pstd, nmean, nstd, phi

    def _fit(self):
        """Fit the mixture to `self.data` and record the moments.

        Raises::
            ValueError: if the fit did not converge.
        """
        self.tmean = np.mean(self.data)
        self.tstd = np.sqrt(np.var(self.data))
        self.fit(self.data.reshape(-1, 1))
        if not self.converged_:
            raise ValueError('Unable to find MLE for weight distribution.')
        means, vars, weights = self.means_, self.covariances_, self.weights_
        mean1, mean2 = means[:, 0]
        std1, std2 = np.sqrt(vars[:, 0, 0])
        phi = weights[0]
        positive_index = 0
        if mean1 < mean2:
            # the fit ordered the components the other way round, so the
            # positive component is the second one; remember its index
            # because `pdf1` must read the matching posterior column
            mean1, mean2 = mean2, mean1
            std1, std2 = std2, std1
            phi = 1 - phi
            positive_index = 1
            log.debug('Components flipped.')
        self._positive_index = positive_index
        self.pmean, self.nmean = mean1, mean2
        self.pstd, self.nstd = std1, std2
        self.phi = phi

    def pdf1(self, x):
        """Posterior probability that each element of `x` belongs to the
        positive component, in the shape of `x`."""
        posterior = super().predict_proba(x.reshape(-1, 1))
        return posterior[:, self._positive_index].reshape(x.shape)

    def predict1(self, x, sampling=False):
        """Boolean membership of the positive component.

        With `sampling` each element is drawn with its posterior
        probability, otherwise the posterior is thresholded at 0.5.
        """
        if sampling:
            return np.random.uniform(size=x.shape) < self.pdf1(x)
        return self.pdf1(x) >= 0.5

    def pdf(self, x):
        """Density of the fitted mixture at each element of `x`."""
        return np.exp(self.score_samples(x.reshape(-1, 1)))

    def plot(self):
        """Save a histogram of the data with the fitted densities to
        `plots/<name>.pdf`, and show it when running on macOS."""
        import os
        import platform
        macos = platform.system() == 'Darwin'
        if not macos:
            # select a non-interactive backend before importing pyplot
            import matplotlib
            matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        from scipy.stats import norm
        data = self.data
        plt.hist(data.flatten(), bins=500, density=True, color='c', alpha=0.5)
        q = np.linspace(np.min(data), np.max(data), 10000)
        plt.plot(q, self.pdf(q), '-', label='pdf')
        p1 = norm.pdf(q, self.pmean, self.pstd)
        p2 = norm.pdf(q, self.nmean, self.nstd)
        plt.plot(q, self.phi * p1, '--', label='in1')
        plt.plot(q, (1 - self.phi) * p2, '--', label='in2')
        # saving fails if the target directory is missing
        os.makedirs('plots', exist_ok=True)
        plt.savefig(f"plots/{self.name.replace('/', '_')}.pdf")
        if macos:
            plt.show()
        plt.gcf().clear()

    @property
    def wasserstein(self):
        """Squared 2-Wasserstein distance between both components, after
        standardizing them with the overall mean and std of the data."""
        # scale and shift mixture distribution for 2-wasserstein criteria
        pmean = (self.pmean - self.tmean) / self.tstd
        nmean = (self.nmean - self.tmean) / self.tstd
        pstd, nstd = self.pstd / self.tstd, self.nstd / self.tstd
        # 2-wasserstein distance
        # FIXME only works well if self.phi is close to 0.5
        wmean = (pmean - nmean) ** 2
        wvar = pstd ** 2 + nstd ** 2 - 2 * pstd * nstd
        return wmean + wvar

    @property
    def crossover(self):
        """Probability mass lying on the wrong side of the cross-over
        point, i.e. the point between both means where the posterior of the
        positive component is closest to 0.5."""
        # computes the probability of cross-over
        from scipy.stats import norm
        resolution = 1000
        x = np.linspace(self.nmean, self.pmean, resolution)
        i = np.argmin(np.abs(self.pdf1(x) - 0.5))
        xco = x[i]
        nco = norm.cdf(xco, self.nmean, self.nstd)
        pco = norm.cdf(xco, self.pmean, self.pstd)
        return self.phi * pco + (1 - self.phi) * (1 - nco)


class MLERecentralizer(Recentralizer):
    """
    Recentralizer whose groups come from a maximum-likelihood fit of a
    bimodal Gaussian mixture to the non-zero weights.
    Computes w' = f^-1 . q . f (w), where f normalizes both Gaussians
    and q quantizes the values.

    Arguments::
        quantizer (mayo.override.quantizer.QuantizerBase):
            the quantizer used.
        parameter_quantizers (dict):
            quantizers for the parameters
            `mean`, `std`, `scale` and `offset`.
        modes (dict):
            modes for `mean`, `std`, `scale` and `offset`; a mode is
            0, 1 or 2, the number of parameters used, or `auto` to let
            the fitted mixture decide.
        regularization (float):
            the weight of the quantization error loss.
        epsilon (float):
            [not tested] a small constant to prevent division by zero.
        sample (bool):
            whether group membership is sampled from the mixture
            or taken as the most probable component.
        wasserstein_separation (float):
            the wasserstein distance that must be exceeded to treat the
            distribution as bimodal; takes precedence over
            `crossover_separation`.
        crossover_separation (float):
            the crossover probability of the bimodal Gaussian that must not
            be reached to treat the distribution as bimodal.
        discard_overflow_rate (float):
            the fraction of large magnitude values to discard, which helps
            when heavy-tail distributions fit poorly.  a negative value
            (the default) discards 1 / (2 ** quantizer_width + 1) values.
        bypass_unseparable (bool):
            if unimodal, bypass it with modes 0000.
        transfer_bit_to_quantizer_when_unseparable (bool):
            if unimodal, hand the unused recentralization bit
            to the quantizer.
        plot (bool):
            plot the weight histogram and the maximum-likelihood
            distribution.
    """
    def __init__(
            self, session, quantizer, parameter_quantizers=None,
            modes=None, regularization=0.0, epsilon=0.0, sample=True,
            wasserstein_separation=None, crossover_separation=None,
            discard_overflow_rate=-1, bypass_unseparable=True,
            transfer_bit_to_quantizer_when_unseparable=True,
            plot=False, should_update=True, enable=True, huffman=True):
        super().__init__(
            session, quantizer, parameter_quantizers, modes,
            regularization, epsilon, should_update, enable, huffman)
        # without any criterion, fall back to a wasserstein threshold of 0
        no_criterion = (
            wasserstein_separation is None and crossover_separation is None)
        self.wsep = 0 if no_criterion else wasserstein_separation
        self.csep = crossover_separation
        self.sample = sample
        self.bypass = bypass_unseparable
        self.transfer = transfer_bit_to_quantizer_when_unseparable
        self.overflow_rate = discard_overflow_rate
        self.plot = plot
        self._quantizer_width = quantizer.get('width')

    def _setup_modes(self, modes):
        """Read the modes, mapping `auto` to -1; rejects channel modes."""
        modes = modes or {}
        if modes.get('channel'):
            message = \
                'Channel-wise operations on {} is currently not supported.'
            raise NotImplementedError(
                message.format(self.__class__.__name__))
        self.channel_modes = []

        def resolve(key, fallback):
            chosen = modes.get(key)
            if chosen is None:
                chosen = fallback
            return -1 if chosen == 'auto' else chosen

        self.mean_mode = resolve('mean', 'auto')
        self.std_mode = resolve('std', 'auto')
        self.offset_mode = resolve('offset', 0)
        self.scale_mode = resolve('scale', 0)

    def _assign_modes(self, mixture):
        """Replace every auto (-1) mode by the number of parameters the
        fitted mixture calls for, and return the resulting
        (mean, std, scale, offset) modes."""
        if self.wsep is not None:
            metric = mixture.wasserstein
            separable = metric > self.wsep
            log.debug(
                'Wasserstein metric {}, threshold {}, separate = {}.'
                .format(metric, self.wsep, separable))
        elif self.csep is not None:
            metric = mixture.crossover
            separable = metric < self.csep
            log.debug(
                'Crossover metric {}, threshold {}, separate = {}.'
                .format(metric, self.csep, separable))
        else:
            separable = True
        if separable:
            auto_mode = 2
        elif self.bypass:
            auto_mode = 0
        else:
            auto_mode = 1
        names = ('mean_mode', 'std_mode', 'scale_mode', 'offset_mode')
        current = self.session.run([getattr(self, each) for each in names])
        resolved = tuple(
            auto_mode if mode == -1 else mode for mode in current)
        for each, mode in zip(names, resolved):
            setattr(self, each, mode)
        log.debug(
            '{} uses modes {{mean: {}, std: {}, scale: {}, offset: {}}}.'
            .format(self, *resolved))
        return resolved

    def _update(self):
        """Fit the mixture to the non-zero values, then refresh the modes,
        the group masks and the moments."""
        # current value and its non-zero entries
        value = self.session.run(self.before)
        nonzeros = value != 0
        fitted_values = value[util.where(nonzeros)]
        # overall mean and std
        tmean, tvar = util.moments(fitted_values)
        tstd = np.sqrt(tvar)
        # bimodal Gaussian distribution of non-zero values
        rate = self.overflow_rate
        rate = 1 / (2 ** self._quantizer_width + 1) if rate < 0 else rate
        self._mixture = BimodalGaussian(self.name, fitted_values, rate)
        mixture = self._mixture
        pmean, pstd = mixture.pmean, mixture.pstd
        nmean, nstd = mixture.nmean, mixture.nstd
        phi = mixture.phi
        template = (
            'Found bimodal Gaussian mixture with '
            '{0} * N({0}, {0}) + {0} * N({0}, {0}).'.format('{:.3g}'))
        log.debug(template.format(phi, pmean, pstd, 1 - phi, nmean, nstd))
        modes = self._assign_modes(mixture)
        if all(mode in [0, 1] for mode in modes):
            # no mode uses two parameters, so the bit is unused
            extra_bit = 1 if self.transfer else 0
            self.quantizer.width = self._quantizer_width + extra_bit
            if extra_bit:
                log.debug(
                    'Transferred unused recentralization bit to quantizer.')
        # sampling to divide them into two groups
        in_positive = np.logical_and(
            mixture.predict1(value, self.sample), nonzeros)
        in_negative = np.logical_and(np.logical_not(in_positive), nonzeros)
        self.positives = in_positive
        self.negatives = in_negative
        self._update_values([tmean, pmean, nmean], [tstd, pstd, nstd])
        if self.plot:
            mixture.plot()

    def _info(self):
        """Extend the parent statistics with the wasserstein distance of
        the fitted mixture, when one is available."""
        fields = super()._info()._asdict()
        fields.pop('name')
        if self._mixture:
            fields['wdist'] = self._mixture.wasserstein
        return self._info_tuple(**fields)

    @classmethod
    def finalize_info(cls, table):
        """Ask the table for a footer mean over the Huffman bitwidth."""
        table.footer_mean('huffman_bitwidth', 'params_')

class HuffmanTreeNode(object):
    """A node of a Huffman tree.

    Leaves carry a symbol in `value`; internal nodes are created with a
    `value` of None and get their `left` and `right` children assigned
    afterwards.  Nodes are ordered and compared by `freq` only, which is
    what allows them to be stored in a `heapq` priority queue.

    Defining `__eq__` without `__hash__` leaves instances unhashable.
    """
    def __init__(self, value, freq):
        self.value = value
        self.freq = freq
        self.left = None
        self.right = None

    def children(self):
        """Return the (left, right) pair of child nodes."""
        return (self.left, self.right)

    def __lt__(self, other):
        return self.freq < other.freq

    def __eq__(self, other):
        # anything that is not a node, including None, is never equal
        if not isinstance(other, HuffmanTreeNode):
            return False
        return self.freq == other.freq


class HuffmanCoding(object):
    """
    Builds a Huffman code for the unique values of an array and reports
    the resulting average code lengths.

    After `encode`, `codes` maps each unique value to its bit string,
    `reverse_mapping` maps bit strings back to values, `equ_bitwdith` is
    the frequency-weighted average code length and `sta_bitwdith` is the
    unweighted average code length.  Both are -1 before the first call.

    References;
        https://github.com/bhrigu123/huffman-coding/blob/master/huffman.py
        http://bhrigu.me/blog/2017/01/17/huffman-coding-python-implementation/
    """
    def __init__(self):
        self.huffman_tree = []
        self.codes = {}
        self.equ_bitwdith = -1
        self.sta_bitwdith = -1
        self.reverse_mapping = {}
        self.frequncy = {}

    def encode(self, value):
        """Build the code for `value` and refresh both bitwidths."""
        # start from a clean state, otherwise codes of an earlier call
        # would be looked up in the new frequency table
        self.huffman_tree = []
        self.codes = {}
        self.reverse_mapping = {}
        frequency = self._make_frequency_dict(value)
        self.frequncy = frequency
        self._make_tree(frequency)
        self._merge_nodes()
        self._make_codes()
        self.equ_bitwdith = self.equivalent_bitwidth()
        self.sta_bitwdith = self.static_bitwidth()

    def _make_frequency_dict(self, x):
        """Find the frequencies of the unique values in `x`.

        Returns:
            [dict] -- [(unique values, frequencies)]
        """
        unique, counts = np.unique(x, return_counts=True)
        return dict(zip(unique, counts))

    def _make_tree(self, freq):
        """Push one leaf per unique value onto the priority queue."""
        for symbol, count in freq.items():
            heapq.heappush(self.huffman_tree, HuffmanTreeNode(symbol, count))

    def _merge_nodes(self):
        """Repeatedly merge the two least frequent nodes until at most one
        node, the root, remains."""
        while len(self.huffman_tree) > 1:
            node1 = heapq.heappop(self.huffman_tree)
            node2 = heapq.heappop(self.huffman_tree)
            merged = HuffmanTreeNode(None, node1.freq + node2.freq)
            merged.left = node1
            merged.right = node2
            heapq.heappush(self.huffman_tree, merged)

    def _make_codes(self):
        """Assign a bit string to every leaf below the root."""
        if not self.huffman_tree:
            # empty input, there is nothing to encode
            return
        root = heapq.heappop(self.huffman_tree)
        # NOTE: a single unique value yields a root leaf with an empty code
        self._make_codes_helper(root, "")

    def _make_codes_helper(self, root, current_code):
        if root is None:
            return
        if root.value is not None:
            # leaf, record the path taken from the root
            self.codes[root.value] = current_code
            self.reverse_mapping[current_code] = root.value
            return
        self._make_codes_helper(root.left, current_code + "0")
        self._make_codes_helper(root.right, current_code + "1")

    def static_bitwidth(self):
        """Unweighted average code length, 0.0 if there are no codes."""
        if not self.codes:
            return 0.0
        total_bitwidth = sum(len(code) for code in self.codes.values())
        return total_bitwidth / len(self.codes)

    def equivalent_bitwidth(self):
        """Frequency-weighted average code length, 0.0 if there are no
        codes."""
        num_ele = 0
        total_bitwidth = 0
        for symbol, code in self.codes.items():
            count = self.frequncy[symbol]
            num_ele += count
            total_bitwidth += len(code) * count
        if not num_ele:
            return 0.0
        return total_bitwidth / num_ele
