import heapq
import tensorflow as tf
import numpy as np

from mayo.override.base import OverriderBase, Parameter


class HuffmanEncoder(OverriderBase):
    """Pass-through overrider that reports the Huffman-coded size of a value.

    ``_apply`` leaves its input unchanged. ``_info`` evaluates
    ``self.before`` in the session, builds a Huffman code over the unique
    elements of the result, and reports the frequency-weighted average code
    length (``equivalent_width``), the implied total bit count and the
    element count.
    """
    def __init__(self, session, should_update=True):
        super().__init__(session, should_update)
        self.encoder = HuffmanCoding()

    def _apply(self, value):
        # Identity: this overrider only measures, it does not transform.
        return value

    def _info(self):
        value = self.session.run(self.before)
        # Use a fresh coder for every call so that codes built for a
        # previously inspected value cannot leak into these statistics.
        self.encoder = HuffmanCoding()
        self.encoder.encode(value)
        equivalent_width = self.encoder.equ_bitwdith
        numel = int(np.size(value))
        total_bits = equivalent_width * numel
        return self._info_tuple(
            equivalent_width=equivalent_width,
            total_bits=total_bits,
            count_=numel)

    @classmethod
    def finalize_info(cls, table):
        """Append an 'overall' footer row with the element-weighted width.

        The average is ``sum(width_i * count_i) / sum(count_i)`` over the
        rows of ``table``; it is reported as 0 when the total count is 0
        (e.g. an empty table) instead of raising ``ZeroDivisionError``.

        Returns:
            The footer row that was appended.
        """
        bitwidth = table.get_column('equivalent_width')
        count = table.get_column('count_')
        total_count = sum(count)
        if total_count:
            weighted = sum(d * c for d, c in zip(bitwidth, count))
            avg_bitwidth = weighted / total_count
        else:
            avg_bitwidth = 0
        footer = ['overall: ', avg_bitwidth, None, None]
        table.add_row(footer)
        return footer


class HuffmanTreeNode(object):
    """A node of a Huffman tree.

    Leaves carry a symbol in ``value``. Internal nodes built while merging
    are created with ``value=None`` and get both children assigned. Nodes
    order and compare by ``freq`` only, which lets them live directly in a
    ``heapq`` heap.
    """
    def __init__(self, value, freq):
        self.value = value
        self.freq = freq
        self.left = None
        self.right = None

    def children(self):
        """Return the ``(left, right)`` child pair; either may be None."""
        return (self.left, self.right)

    def __lt__(self, other):
        return self.freq < other.freq

    def __eq__(self, other):
        # Only other tree nodes can compare equal; anything else (None
        # included) is simply unequal rather than an error.
        if not isinstance(other, HuffmanTreeNode):
            return False
        return self.freq == other.freq

    # Defining __eq__ disables the inherited __hash__; state it explicitly
    # since equality is by (mutable) frequency, not identity.
    __hash__ = None


class HuffmanCoding(object):
    """
    Build a Huffman code over the unique elements of an array and report
    its average code lengths.

    After ``encode(value)``:
        codes           -- {unique value: bit string}
        reverse_mapping -- {bit string: unique value}
        frequncy        -- {unique value: occurrence count}
        equ_bitwdith    -- frequency-weighted average code length
        sta_bitwdith    -- unweighted average code length over unique values
    (The misspelled attribute names are kept for compatibility with callers.)

    References:
        https://github.com/bhrigu123/huffman-coding/blob/master/huffman.py
        http://bhrigu.me/blog/2017/01/17/huffman-coding-python-implementation/
    """
    def __init__(self):
        self._reset()

    def _reset(self):
        """Clear all state so each ``encode`` call starts from scratch."""
        self.huffman_tree = []
        self.codes = {}
        self.equ_bitwdith = -1
        self.sta_bitwdith = -1
        self.reverse_mapping = {}
        self.frequncy = {}

    def encode(self, value):
        """Build Huffman codes for ``value`` and compute the bitwidths.

        State from any previous call is discarded first, so codes for values
        that no longer occur cannot corrupt the averages. An empty ``value``
        yields no codes and bitwidths of 0.0.
        """
        self._reset()
        frequency = self._make_frequency_dict(value)
        self.frequncy = frequency
        self._make_tree(frequency)
        self._merge_nodes()
        self._make_codes()
        self.equ_bitwdith = self.equivalent_bitwidth()
        self.sta_bitwdith = self.static_bitwidth()

    def _make_frequency_dict(self, x):
        """Find the frequencies of the unique values in ``x``.

        Returns:
            dict -- {unique value: number of occurrences}
        """
        unique, counts = np.unique(x, return_counts=True)
        return dict(zip(unique, counts))

    def _make_tree(self, freq):
        """Push one leaf node per unique value onto the heap."""
        for k, v in freq.items():
            heapq.heappush(self.huffman_tree, HuffmanTreeNode(k, v))

    def _merge_nodes(self):
        """Repeatedly merge the two least frequent nodes until one remains."""
        while len(self.huffman_tree) > 1:
            node1 = heapq.heappop(self.huffman_tree)
            node2 = heapq.heappop(self.huffman_tree)
            merged = HuffmanTreeNode(None, node1.freq + node2.freq)
            merged.left = node1
            merged.right = node2
            heapq.heappush(self.huffman_tree, merged)

    def _make_codes(self):
        """Assign bit strings to every leaf of the merged tree."""
        if not self.huffman_tree:
            # Empty input: there are no symbols to encode.
            return
        root = heapq.heappop(self.huffman_tree)
        if root.left is None and root.right is None:
            # Only one unique value: the root is a leaf and would otherwise
            # get the empty codeword, which cannot be emitted. Use one bit.
            self.codes[root.value] = "0"
            self.reverse_mapping["0"] = root.value
            return
        self._make_codes_helper(root, "")

    def _make_codes_helper(self, root, current_code):
        """Walk the subtree at ``root``, prefixing codes with ``current_code``.

        Left edges append "0", right edges "1". An explicit stack is used
        instead of recursion; leaves are still visited left-to-right.
        """
        stack = [(root, current_code)]
        while stack:
            node, code = stack.pop()
            if node is None:
                continue
            if node.left is None and node.right is None:
                self.codes[node.value] = code
                self.reverse_mapping[code] = node.value
                continue
            # Push right first so the left child is processed first.
            stack.append((node.right, code + "1"))
            stack.append((node.left, code + "0"))

    def static_bitwidth(self):
        """Return the unweighted mean code length over unique values.

        Returns 0.0 when there are no codes (empty input).
        """
        if not self.codes:
            return 0.0
        total_bitwidth = sum(len(code) for code in self.codes.values())
        return total_bitwidth / len(self.codes)

    def equivalent_bitwidth(self):
        """Return the mean code length weighted by value frequency.

        This is the average number of bits per encoded element. Returns
        0.0 when there are no elements (empty input).
        """
        num_ele = sum(self.frequncy[k] for k in self.codes)
        if not num_ele:
            return 0.0
        total_bitwidth = sum(
            len(code) * self.frequncy[k] for k, code in self.codes.items())
        return total_bitwidth / num_ele

