import numpy as np
import math
import math
import queue
import torch
from tqdm import tqdm


def init_k_ary_tree(id_list, k, seed=3407):
    """Randomly assign every item a distinct leaf code of a k-ary tree.

    Codes use the implicit heap layout: the root is code 0 and the children of
    code ``c`` are ``k * c + 1`` ... ``k * c + k``.

    The leaves live on layer ``tree_height``. Nodes on the layer directly
    above the leaves are first given one item each, in input order. Every
    remaining item is attached to a random parent that still has fewer than
    ``k`` children. Finally, each item takes a random free child slot of its
    parent as its leaf code.

    Args:
        id_list: item ids (list, tuple or numpy array).
        k: branching factor, must be >= 2.
        seed: passed to ``np.random.seed``. Note that this reseeds NumPy's
            *global* RNG as a side effect.

    Returns:
        ``(item_ids, leaf_codes)`` as numpy arrays sorted by ascending leaf
        code; ``item_ids[i]`` owns ``leaf_codes[i]``.

    Raises:
        ValueError: if ``k < 2`` or fewer than two ids are given.
    """
    if k < 2:
        raise ValueError("k must be at least 2, got {}".format(k))
    if isinstance(id_list, (list, tuple)):
        id_list = np.array(id_list)
    item_num = len(id_list)
    if item_num < 2:
        raise ValueError("need at least two ids to build a tree, got {}".format(item_num))
    np.random.seed(seed)

    # NOTE: the epsilon is added *after* ceil, so it has no effect; when
    # math.log rounds slightly above an integer (e.g. math.log(125, 5)) the
    # tree gets one extra layer. KAryTree.__init__ asserts its height with the
    # same formula, so the two must be changed together.
    tree_height = int(math.ceil(math.log(item_num, k)) + 0.0000001)
    min_parent_code = (k ** (tree_height - 1) - 1) // (k - 1)
    max_parent_code = (k ** tree_height - 1) // (k - 1) - 1

    parent_code_array = np.arange(min_parent_code, max_parent_code + 1)
    parent_num = len(parent_code_array)
    parent_codes = np.full((item_num,), -1, dtype=np.int64)
    # Give each parent one child first so no parent on this layer stays empty.
    parent_codes[:parent_num] = parent_code_array
    parent_code_freq = np.ones_like(parent_code_array)
    # Parents that can still take a child. The array only changes when a parent
    # becomes full, so it is rebuilt only then. np.random.choice therefore sees
    # exactly the same candidate arrays (and RNG stream) as a per-item rebuild.
    available = parent_code_array[parent_code_freq < k]
    for i in tqdm(range(parent_num, item_num), leave=False, desc='parent codes'):
        random_code = np.random.choice(available)
        parent_codes[i] = random_code
        slot = random_code - min_parent_code
        parent_code_freq[slot] += 1
        if parent_code_freq[slot] >= k:
            available = parent_code_array[parent_code_freq < k]

    max_leaf_code = (k ** (tree_height + 1) - 1) // (k - 1) - 1
    min_leaf_code = (k ** tree_height - 1) // (k - 1)
    leaf_codes = np.zeros(item_num, dtype=np.int64)
    used_leaf_codes = set()

    for i, code in tqdm(enumerate(parent_codes), leave=False, desc='leaf codes', total=len(parent_codes)):
        # Free child slots of this item's parent; there is always at least one
        # because each parent received at most k items above.
        candidate_list = [
            code * k + j + 1
            for j in range(k)
            if code * k + j + 1 not in used_leaf_codes
        ]
        random_leaf_code = np.random.choice(candidate_list)
        leaf_codes[i] = random_leaf_code
        used_leaf_codes.add(random_leaf_code)
    assert len(set(leaf_codes)) == item_num
    assert (leaf_codes <= max_leaf_code).all() and (leaf_codes >= min_leaf_code).all()
    argindex = np.argsort(leaf_codes)
    leaf_codes = leaf_codes[argindex]
    item_ids = id_list[argindex]
    return item_ids, leaf_codes


class TreeNode(object):
    """A tree node carrying a local ``id`` and a tree ``code``.

    Instances order lexicographically by ``(id, code)``: ``id`` first, with
    ``code`` breaking ties.
    """

    def __init__(self, id=0, code=0, isleaf=False):
        self.id = id  # local id
        self.code = code
        self.isleaf = isleaf
        self.corresponding_item_id = None

    def __le__(self, other):
        return (self.id, self.code) <= (other.id, other.code)

    def __lt__(self, other):
        return (self.id, self.code) < (other.id, other.code)


class KAryTree(object):
    """Node/code index for a k-ary tree whose leaves are items.

    Codes use the implicit heap layout: the root is code 0 and the children of
    code ``c`` are ``k * c + 1`` ... ``k * c + k``. Only leaf codes and their
    ancestors are materialized as nodes. Node ids are assigned in ascending
    code order.
    """

    def __init__(self, item_ids, leaf_codes, k, tree_id=0, device="cuda"):
        """Build all lookup tables for the tree.

        Args:
            item_ids: item ids, parallel to ``leaf_codes``. They are used as
                row indices of ``item_id_node_ancestor_id``, so presumably they
                are ``0 .. len(item_ids) - 1`` (TODO confirm with callers).
            leaf_codes: leaf code of each item. All leaves are expected on
                the same (deepest) layer, which the assertions below check.
            k: branching factor.
            tree_id: identifier used only in the log message.
            device: torch device for the lookup tensors (except
                ``node_sampled_indicator``, which stays on the CPU).

        Raises:
            ValueError: if no non-negative leaf code is given.
            AssertionError: if the leaves do not form the expected layout.
        """
        print("construct tree, leaf node num is {}".format(len(item_ids)))
        self.k = k
        self.tree_id = tree_id
        self.max_layer_id = 0
        self.node_num = 0
        self.layer_node_num_list = []  # node number at each layer
        self.node_code_node_id = dict()
        self.node_code_node_id_array = None
        self.item_id_node_ancestor_id = None
        self.node_id_layer_id = None
        self.node_sampled_indicator = None
        # code -> number of leaves in the subtree rooted at that code
        self.maximum_assigned_item_num = dict()

        # item id <-> leaf code lookups (the item-to-leaf assignment)
        self.item_id_leaf_code = {
            int(item_id): int(code) for item_id, code in zip(item_ids, leaf_codes)
        }
        self.leaf_code_item_id = {
            int(code): int(item_id) for item_id, code in zip(item_ids, leaf_codes)
        }

        # Collect every leaf code together with all of its ancestors.
        node_code_set = set()
        for code in leaf_codes:
            cur_code = int(code)
            while cur_code >= 0 and cur_code not in node_code_set:
                node_code_set.add(cur_code)
                cur_code = self._parent_code(cur_code)
        if not node_code_set:
            raise ValueError("KAryTree needs at least one non-negative leaf code")
        node_codes = sorted(node_code_set)  # node id == position in this list
        node_layers = [self._code_layer(code) for code in node_codes]

        self.node_code_node_id = {code: node_id for node_id, code in enumerate(node_codes)}
        # Build lookup tensors on the host and move them once. Writing element
        # by element into a device tensor costs one transfer per write.
        array_size = node_codes[-1] + k
        code_to_id = torch.full((array_size,), -1, dtype=torch.int64)
        code_to_id[torch.tensor(node_codes, dtype=torch.int64)] = torch.arange(
            len(node_codes), dtype=torch.int64
        )
        self.node_code_node_id_array = code_to_id.to(device)
        self.node_id_layer_id = torch.tensor(node_layers, dtype=torch.int64, device=device)
        # NOTE(review): unlike the other tensors this one is created on the CPU; confirm intended.
        self.node_sampled_indicator = torch.full((array_size,), 0, dtype=torch.bool)

        # Codes are sorted, so layers are non-decreasing and the last is the deepest.
        self.layer_node_num_list = [0] * (node_layers[-1] + 1)
        for layer in node_layers:
            self.layer_node_num_list[layer] += 1

        self.max_layer_id = self._code_layer(int(leaf_codes[0]))
        assert self.max_layer_id == len(self.layer_node_num_list) - 1
        assert sum(self.layer_node_num_list) == len(node_codes)
        assert len(self.node_code_node_id) == len(node_codes)
        # Same height formula as init_k_ary_tree; keep the two in sync.
        assert self.max_layer_id == int(math.ceil(math.log(len(leaf_codes), k)) + 0.0000001)
        # every layer except the last one must be full (k ** layer nodes)
        for i in range(len(self.layer_node_num_list) - 1):
            assert self.layer_node_num_list[i] == k ** i

        self.node_num = len(node_codes)
        self.item_id_node_ancestor_id = torch.zeros(
            (len(item_ids), self.max_layer_id + 1), dtype=torch.int64, device=device
        )
        self.generate_item_id_ancestor_node_id()

        # Bottom-up (descending codes): leaves count 1, inner nodes sum their children.
        for code in sorted(self.node_code_node_id, reverse=True):
            if self._code_layer(code) == self.max_layer_id:  # leaf node
                self.maximum_assigned_item_num[code] = 1
            else:
                self.maximum_assigned_item_num[code] = sum(
                    self.maximum_assigned_item_num[child]
                    for child in self._existing_children(code)
                )

        assert self.maximum_assigned_item_num[0] == len(self.item_id_leaf_code)
        print(
            "Tree {},node number is {}, tree height is {}".format(
                tree_id, len(node_codes), self.max_layer_id
            )
        )

    def _parent_code(self, code):
        """Return the parent code of ``code``; the root (0) maps to itself."""
        return (code - 1) // self.k if code > 0 else 0

    def _code_layer(self, code):
        """Return the layer of ``code`` (root = 0) using exact integer math."""
        layer = 0
        next_layer_first = 1  # first code of layer ``layer + 1``
        while code >= next_layer_first:
            layer += 1
            next_layer_first = next_layer_first * self.k + 1
        return layer

    def _existing_children(self, code):
        """Return the materialized children of ``code`` in ascending order."""
        first_child = self.k * code + 1
        return [
            child
            for child in range(first_child, first_child + self.k)
            if child in self.node_code_node_id
        ]

    def generate_item_id_ancestor_node_id(self):
        """Fill ``item_id_node_ancestor_id[item_id, layer]`` with the node id of
        the item's ancestor on each layer (the leaf itself on the last layer).

        Rows are written on the host and copied into the tensor in one step.
        """
        ancestor_ids = self.item_id_node_ancestor_id.cpu().numpy().copy()
        for item_id, leaf_code in self.item_id_leaf_code.items():
            code = leaf_code
            for layer in range(self.max_layer_id, -1, -1):
                ancestor_ids[item_id, layer] = self.node_code_node_id[code]
                code = self._parent_code(code)
            assert code == 0
        self.item_id_node_ancestor_id.copy_(torch.from_numpy(ancestor_ids))

    def get_ancestor(self, code, level):
        """Return the ancestor of ``code`` on layer ``level``.

        ``code`` is returned unchanged if it already lies on or above that layer.
        """
        code_max = (self.k ** (level + 1) - 1) // (self.k - 1)  # first code below ``level``
        while code >= code_max:
            code = self._parent_code(code)
        return code

    def get_nodes_given_level(self, level):
        """Return the materialized codes on layer ``level`` in ascending order."""
        code_min = (self.k ** level - 1) // (self.k - 1)
        code_max = (self.k ** (level + 1) - 1) // (self.k - 1) - 1
        return [code for code in range(code_min, code_max + 1) if code in self.node_code_node_id]

    def get_children_given_ancestor_and_level(self, ancestor, level):
        """Return the materialized descendants of ``ancestor`` on layer ``level``.

        Returns ``[ancestor]`` when it lies on ``level`` itself and ``[]`` when it
        lies below it. Order is breadth-first, i.e. ascending code.
        """
        code_min = (self.k ** level - 1) // (self.k - 1)
        code_max = (self.k ** (level + 1) - 1) // (self.k - 1) - 1

        res = []
        frontier = [ancestor]
        # Layer-by-layer expansion visits nodes in the same order as a FIFO queue.
        while frontier:
            next_frontier = []
            for parent_code in frontier:
                if code_min <= parent_code <= code_max:
                    if parent_code in self.node_code_node_id:
                        res.append(parent_code)
                elif parent_code < code_min:
                    # Codes above ``code_max`` are already deeper than ``level``,
                    # so none of their descendants can qualify.
                    next_frontier.extend(self._existing_children(parent_code))
            frontier = next_frontier
        return res

    def get_parent_path(self, child, ancestor):
        """Return codes from ``child`` upward while they are greater than
        ``ancestor`` (``ancestor`` itself excluded)."""
        res = []
        while child > ancestor:
            res.append(child)
            if child == 0:
                break  # root reached; avoids an endless loop when ancestor < 0
            child = self._parent_code(child)
        return res

    def _ancessors(self, code):
        """Return the ancestors of ``code`` from its parent up to the root (0)."""
        ancs = []
        while code > 0:
            code = self._parent_code(code)
            ancs.append(code)
        return ancs

    def get_offspring_number(self, code):
        """Return the number of materialized descendants of ``code`` (itself excluded)."""
        num = 0
        frontier = [code]
        while frontier:
            frontier = [child for parent in frontier for child in self._existing_children(parent)]
            num += len(frontier)
        return num


if __name__ == "__main__":
    # Smoke test: build a ternary tree over the item ids listed in kv_file and
    # print a few queries. Each non-blank line is expected to be "<id>::<code>".
    ids = []
    codes = []
    kv_file = '../../data/mind/processed_dataset/kv.txt'
    with open(kv_file, encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank / trailing lines instead of crashing
            id_code = line.split('::')
            ids.append(int(id_code[0]))
            codes.append(int(id_code[1]))
    # NOTE(review): KAryTree indexes tensors by item id, so these ids are
    # presumably 0 .. N-1 — confirm against the file.
    ids = np.array(ids, dtype=np.int32)
    # NOTE(review): `codes` is parsed but unused; init_k_ary_tree assigns fresh random codes.
    codes = np.array(codes, dtype=np.int32)

    item_ids, leaf_codes = init_k_ary_tree(ids, 3)
    print(item_ids)
    print(leaf_codes)
    tree = KAryTree(item_ids, leaf_codes, 3)
    print(tree.get_nodes_given_level(2))
    print(tree.get_children_given_ancestor_and_level(1, 2))
    print(tree.get_parent_path(7, 1))
    print(tree.get_offspring_number(1))
    print(tree.maximum_assigned_item_num)
    print(tree.get_ancestor(7, 2))
    print(tree.get_ancestor(7, 1))

