from typing import Sequence, Dict, Union, Set
import torch
import numpy as np
from lib.utils import tensor2npy


class Evaluate:
    """Top-k ranking evaluator for per-user score matrices.

    Each name in ``metrics`` must be one of the metric methods on this class
    (``ndcg``, ``recall``, ``precision``, ``map``, ``mrr``, ``hr``). Every
    metric is reported at every cut-off in ``topks``.

    Args:
        metrics: Metric method names, looked up with ``getattr``.
        topks: Cut-offs ``k`` at which each metric is reported.
        device: Device the flattened index tensors are moved to. Presumably
            this must match the device of the ``rating_mat`` passed to
            :meth:`evaluate` (TODO confirm with callers).
        uni_neg_sample_num: When > 0, each user's test items are ranked only
            against this many uniformly sampled negative items instead of the
            whole item set.
    """

    def __init__(self, metrics, topks, device='cuda:0', uni_neg_sample_num=-1):
        self.metrics = metrics
        self.topks = np.array(topks, dtype=np.int32)
        self.device = device
        self.uni_neg_sample_num = uni_neg_sample_num
        # Column labels of evaluate()'s output: metric-major, then cut-off.
        self.metric_names = [
            f"{metric}@{topk}" for metric in self.metrics for topk in self.topks
        ]

    def dict2list(self, uid_list, u2i):
        """Flatten per-user item collections into parallel (row, item) index tensors.

        Row ``r`` refers to the r-th user of ``uid_list`` in iteration order,
        so the result can index a matrix whose rows follow ``uid_list``.

        Args:
            uid_list: Users whose items are flattened.
            u2i: Mapping from user id to an iterable of integer item ids.

        Returns:
            ``(u, i)`` long tensors of equal length on ``self.device``. Both
            are empty when ``uid_list`` is empty.
        """
        item_lists = [torch.tensor(list(u2i[uid]), dtype=torch.long) for uid in uid_list]
        if not item_lists:
            # torch.cat rejects an empty list, so build the empty result directly.
            empty = torch.zeros(0, dtype=torch.long, device=self.device)
            return empty, empty.clone()
        counts = torch.tensor([len(items) for items in item_lists], dtype=torch.long)
        u = torch.repeat_interleave(torch.arange(len(item_lists)), counts).to(self.device)
        i = torch.cat(item_lists).to(self.device)
        return u, i

    def _sample_negatives(self, uid_list, item_count, except_u2i, test_u2i):
        """Draw up to ``uni_neg_sample_num`` uniform negative item ids per user.

        Negatives never include a user's ``except_u2i`` or ``test_u2i`` items.
        When too few items remain to sample from, every remaining item is
        used instead. Draws are with replacement and use NumPy's global RNG.
        """
        n_neg = self.uni_neg_sample_num
        all_items = set(range(item_count))
        neg_u2i = {}
        for uid in uid_list:
            # Built once per user so each rejection test is O(1).
            excluded = set(except_u2i[uid]) | set(test_u2i[uid])
            if len(except_u2i[uid]) + len(test_u2i[uid]) + n_neg > item_count:
                neg_u2i[uid] = np.array(list(all_items - excluded), dtype=np.int64)
                continue
            samples = np.zeros(n_neg, dtype=np.int64)
            pending = np.arange(n_neg)
            # Rejection sampling: redraw only the slots that hit an excluded item.
            while len(pending) > 0:
                samples[pending] = np.random.randint(0, item_count, len(pending))
                rejected = np.fromiter(
                    (v in excluded for v in samples[pending]), dtype=bool, count=len(pending)
                )
                pending = pending[rejected]
            neg_u2i[uid] = samples
        return neg_u2i

    @torch.no_grad()
    def evaluate(self, uid_list: Union[Sequence, Set], rating_mat: torch.Tensor, except_u2i: Dict, test_u2i: Dict):
        """Rank items by ``rating_mat`` and score the ranking against test items.

        Row ``r`` of ``rating_mat`` is treated as the scores of the r-th user
        of ``uid_list`` in iteration order (see :meth:`dict2list`). Its second
        dimension indexes items.

        Args:
            uid_list: Users to evaluate, in the row order of ``rating_mat``.
            rating_mat: 2-D score tensor indexed ``[row, item]``.
            except_u2i: Per-user items removed from the ranking. These are
                presumably training interactions; confirm with callers.
            test_u2i: Per-user held-out positive items.

        Returns:
            Array with one row per row of ``rating_mat`` and
            ``len(metrics) * len(topks)`` columns, ordered as
            ``self.metric_names``.

        Note:
            Without negative sampling, ``rating_mat`` is modified in place:
            its ``except_u2i`` entries are set to ``-inf``.
        """
        test_u, test_i = self.dict2list(uid_list, test_u2i)

        if self.uni_neg_sample_num > 0:
            neg_sample_u2i = self._sample_negatives(uid_list, rating_mat.shape[-1], except_u2i, test_u2i)
            neg_sample_u, neg_sample_i = self.dict2list(uid_list, neg_sample_u2i)
            # Keep only test and sampled-negative scores; all others are unrankable.
            new_mat = torch.full_like(rating_mat, -np.inf)
            new_mat[test_u, test_i] = rating_mat[test_u, test_i]
            new_mat[neg_sample_u, neg_sample_i] = rating_mat[neg_sample_u, neg_sample_i]
            rating_mat = new_mat
        else:
            except_u, except_i = self.dict2list(uid_list, except_u2i)
            rating_mat[except_u, except_i] = -np.inf

        _, topk_idx = torch.topk(rating_mat, int(self.topks.max()), dim=-1)  # rows x max_k
        pos_matrix = torch.zeros_like(rating_mat, dtype=torch.int)
        pos_matrix[test_u, test_i] = 1
        # tensor2npy presumably converts a tensor to a NumPy array (defined in lib.utils).
        pos_len_list = tensor2npy(pos_matrix.sum(dim=1))
        pos_idx = tensor2npy(torch.gather(pos_matrix, dim=1, index=topk_idx))
        cutoff_cols = self.topks - 1
        res_list = [
            getattr(self, metric)(pos_len_list, pos_idx)[:, cutoff_cols] for metric in self.metrics
        ]
        return np.hstack(res_list)

    def ndcg(self, pos_len, pos_index):
        """NDCG at every cut-off ``1..k``.

        Args:
            pos_len: Per-row number of test items.
            pos_index: ``(rows, k)`` 0/1 hit matrix of the top-k ranking.

        Returns:
            ``(rows, k)`` float array. The ideal DCG at cut-off ``j`` uses
            ``min(j, pos_len)`` ideal hits. Rows with no test items score 0.
        """
        k = pos_index.shape[1]
        discount = 1.0 / np.log2(np.arange(1, k + 1, dtype=np.float64) + 1)
        ideal_len = np.minimum(pos_len, k)
        # Position of the last ideal hit counted at each cut-off. Clamping the
        # length to >= 1 keeps the index valid; those rows have dcg == 0 anyway.
        last_ideal = np.minimum(np.arange(k)[None, :], np.maximum(ideal_len, 1)[:, None] - 1)
        idcg = np.cumsum(discount)[last_ideal]
        dcg = np.cumsum(np.where(pos_index, discount, 0.0), axis=1)
        return dcg / idcg

    def recall(self, pos_len, pos_index):
        """Recall at every cut-off: hits so far divided by the row's test-item count.

        Rows with no test items score 0 instead of ``nan``, matching the
        other metrics.
        """
        hits = np.cumsum(pos_index, axis=1)
        denom = np.asarray(pos_len).reshape(-1, 1)
        return np.divide(hits, denom, out=np.zeros(hits.shape, dtype=np.float64), where=denom > 0)

    def precision(self, pos_len, pos_index):
        """Precision at every cut-off ``j``: hits in the first ``j`` positions divided by ``j``."""
        return pos_index.cumsum(axis=1) / np.arange(1, pos_index.shape[1] + 1)

    def map(self, pos_len, pos_index):
        """Mean average precision at every cut-off.

        At cut-off ``j`` the summed precision-at-hit is divided by
        ``min(j, pos_len)``. Rows with no test items score 0.
        """
        k = pos_index.shape[1]
        ranks = np.arange(1, k + 1)
        sum_pre = np.cumsum(self.precision(pos_len, pos_index) * pos_index.astype(np.float64), axis=1)
        actual_len = np.minimum(pos_len, k)
        # A clamped length of 1 only affects rows whose sum_pre is all zeros.
        denom = np.minimum(ranks[None, :], np.maximum(actual_len, 1)[:, None])
        return sum_pre / denom

    def mrr(self, pos_len, pos_index):
        """Reciprocal rank of the first hit, reported from that position onward (0 before it or if none)."""
        n_rows, k = pos_index.shape
        first_hit = pos_index.argmax(axis=1)
        has_hit = pos_index[np.arange(n_rows), first_hit] > 0
        reached = np.arange(k)[None, :] >= first_hit[:, None]
        reciprocal = 1.0 / (first_hit + 1)
        return np.where(reached & has_hit[:, None], reciprocal[:, None], 0.0)

    def hr(self, pos_len, pos_index):
        """Hit ratio at every cut-off: 1 if any hit occurred so far, else 0."""
        result = np.cumsum(pos_index, axis=1)
        return (result > 0).astype(int)
