import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from lib.components import MLP, NormalDistUtil, BernoulliUtil, PosMLP
from torch.utils.data import DataLoader
from lib.utils import tensor2npy
from .BaseModel import BaseModel
from lib.callback import CallbackList, Callback
import pandas as pd
from evaluate import doa_report
import math
from .losses import BPRLoss, MarginLossZeroOne


class DisenCD(BaseModel):
    """Encoder/decoder model over a user-item interaction matrix.

    ``build_model`` creates three MLP encoders (user, item, item "diff") and
    ``forward`` combines a reconstruction loss, Q-matrix alignment losses and
    MI / TC / TC_G / KL terms, each weighted by the coefficients below.
    """
    # Default hyper-parameters. The same keys are read from ``self.model_cfg``
    # throughout this class (presumably merged there by BaseModel — confirm).
    expose_default_cfg = {
        # Hidden layer sizes (``dnn_units``) of the user encoder and of the two
        # item encoders (EncoderItem and EncoderItemDiff share the item setting).
        'EncoderUserHidden': [512],
        'EncoderItemHidden': [512],
        # Weights of the reconstruction loss and of the Q-matrix MSE loss.
        'lambda_main': 1.0,
        'lambda_q': 1.0,
        # Settings read by get_align_item_loss / forward:
        #   margin        -> margin passed to MarginLossZeroOne
        #   topk, d1      -> how many top entries are selected / used as positives
        #   margin_lambda -> weight of the margin loss
        #   norm          -> order p of the norm penalty; norm_lambda -> its weight
        #   start_epoch   -> first epoch at which the margin loss is computed
        'align_margin_loss_kwargs': {'margin': 0.7, 'topk': 2, "d1":1, 'margin_lambda': 0.5, 'norm': 1, 'norm_lambda': 1.0, 'start_epoch': 1},
        # Estimator used in get_tcvae_terms: 'mws' or 'mss'.
        'sampling_type': 'mws',
        # Forwarded to ``item_dist.sample`` as ``type_`` and extra keyword args.
        'b_sample_type': 'gumbel_softmax',
        'b_sample_kwargs': {'tau': 1.0, 'hard': True},
        # ``p`` given to BernoulliUtil; when ``bernoulli_prior_auto`` is true,
        # fit() recomputes p from the Q matrix instead.
        'bernoulli_prior_p': 0.1,
        'bernoulli_prior_auto': False,
        # Branch selected in get_align_item_loss: 'mse_margin' or 'mse_margin_mean'.
        'align_type': 'mse_margin',
        # Loss weights used in forward(): alpha -> MI, gamma -> KL, beta -> TC,
        # g_beta -> TC_G. The item weights are also applied to the "diff" terms.
        'alpha_user': 0.0,
        'alpha_item': 0.0,
        'gamma_user': 1.0,
        'gamma_item': 1.0,
        'beta_user': 0.0,
        'beta_item': 0.0,
        'g_beta_user': 1.0,
        'g_beta_item': 1.0,
        # Only referenced from commented-out code in decode().
        'disc_scale': 10,
        # Configuration of the PosMLP prediction head (``pd_net``).
        'pred_dnn_units': [256, 128],
        'pred_dropout_rate': 0.5,
        'pred_activation': 'sigmoid',
        # Interaction function chosen in decode():
        # 'irt_wo_disc', 'irt', 'ncdm', 'mf' or 'mirt'.
        'interact_type': 'irt_wo_disc',
    }
    def __init__(self, cfg):
        """Initialise the model from ``cfg``.

        Delegates to the parent constructor with ``xavier_init=True``.
        """
        super().__init__(cfg, xavier_init=True)

    def build_cfg(self):
        self.user_count = self.cfg.data_cfg['dt_info']['user_count']
        self.item_count = self.cfg.data_cfg['dt_info']['item_count']
        self.cpt_count = self.cfg.data_cfg['dt_info']['cpt_count']

    def build_model(self):
        """Create the encoders, the prediction head, the margin loss and the latent helpers.

        Sub-modules are created in a fixed order (user encoder, item encoder,
        item "diff" encoder, discrimination embedding, prediction head).
        """
        cfg = self.model_cfg
        n_cpt = self.cpt_count
        item_hidden = cfg['EncoderItemHidden']

        # User encoder: input width item_count, output width 2 * cpt_count
        # (forward() splits the output into two halves).
        self.EncoderUser = MLP(
            input_dim=self.item_count,
            output_dim=n_cpt * 2,
            dnn_units=cfg['EncoderUserHidden'],
        )
        # Item encoder: input width user_count, output width cpt_count.
        self.EncoderItem = MLP(
            input_dim=self.user_count,
            output_dim=n_cpt,
            dnn_units=item_hidden,
        )
        # Item "diff" encoder: same input, output width 2 * cpt_count.
        self.EncoderItemDiff = MLP(
            input_dim=self.user_count,
            output_dim=n_cpt * 2,
            dnn_units=item_hidden,
        )

        # One scalar per item, used by the 'irt' / 'ncdm' / 'mirt' decoders.
        self.ItemDisc = nn.Embedding(self.item_count, 1)
        # Prediction head used by the 'ncdm' decoder.
        self.pd_net = PosMLP(
            input_dim=n_cpt,
            output_dim=1,
            activation=cfg['pred_activation'],
            dnn_units=cfg['pred_dnn_units'],
            dropout_rate=cfg['pred_dropout_rate'],
        )

        self.margin_loss_zero_one = MarginLossZeroOne(
            reduction='none',
            margin=cfg['align_margin_loss_kwargs']['margin'],
        )

        # Sampling / density helpers for the three latents.
        self.user_dist = NormalDistUtil()
        self.item_dist = BernoulliUtil(p=cfg['bernoulli_prior_p'], stgradient=True)
        self.item_dist_diff = NormalDistUtil()
    
    def get_align_item_loss(self, item_emb, item_idx):
        if self.model_cfg['align_type'] == 'mse_margin':
            flag = self.Q_mat[item_idx, :].sum(dim=1) > 0
            left_emb = item_emb[~flag]
            p = self.model_cfg['align_margin_loss_kwargs']['norm']
            t_loss = torch.norm(left_emb, dim=0, p=p).pow(p).sum()
            if left_emb.shape[0] != 0 and self.callback_list.curr_epoch >= self.model_cfg['align_margin_loss_kwargs']['start_epoch']:
                # topk_idx = torch.topk(left_emb, self.model_cfg['align_margin_loss_kwargs']['topk']).indices
                # bottomk_idx = torch.ones_like(left_emb).scatter(1, topk_idx, 0).nonzero()[:, 1].reshape(-1, left_emb.size(1) - topk_idx.size(1))
                # pos = torch.gather(left_emb, 1, topk_idx[:,[-1]])
                # neg = torch.gather(left_emb, 1, bottomk_idx[:,torch.randperm(bottomk_idx.shape[1],dtype=torch.long)[0:int(bottomk_idx.shape[1]*0.5)]])
                topk_idx = torch.topk(left_emb, self.model_cfg['align_margin_loss_kwargs']['topk']+1).indices
                pos = torch.gather(left_emb, 1, topk_idx[:,0:self.model_cfg['align_margin_loss_kwargs']['d1']])
                neg = torch.gather(left_emb, 1, topk_idx[:,[-1]])
                margin_loss = self.margin_loss_zero_one(pos, neg).mean(dim=1).sum()
            else:
                margin_loss = torch.tensor(0.0).to(self.device)
            return {
                "mse_loss": F.mse_loss(item_emb[flag], self.Q_mat[item_idx[flag], :].float(), reduction='sum'),
                "margin_loss": margin_loss,
                "norm_loss": t_loss,
            }
        elif self.model_cfg['align_type'] == 'mse_margin_mean':
            flag = self.Q_mat[item_idx, :].sum(dim=1) > 0
            left_emb = item_emb[~flag]
            p = self.model_cfg['align_margin_loss_kwargs']['norm']
            t_loss = torch.norm(left_emb, dim=0, p=p).pow(p).sum()
            if left_emb.shape[0] != 0 and self.callback_list.curr_epoch >= self.model_cfg['align_margin_loss_kwargs']['start_epoch']:
                # topk_idx = torch.topk(left_emb, self.model_cfg['align_margin_loss_kwargs']['topk']).indices
                # bottomk_idx = torch.ones_like(left_emb).scatter(1, topk_idx, 0).nonzero()[:, 1].reshape(-1, left_emb.size(1) - topk_idx.size(1))
                # pos = torch.gather(left_emb, 1, topk_idx[:,[-1]])
                # neg = torch.gather(left_emb, 1, bottomk_idx[:,torch.randperm(bottomk_idx.shape[1],dtype=torch.long)[0:int(bottomk_idx.shape[1]*0.5)]])
                topk_idx = torch.topk(left_emb, self.model_cfg['align_margin_loss_kwargs']['topk']+1).indices
                bottomk_idx = torch.topk(-left_emb, left_emb.shape[1] -  self.model_cfg['align_margin_loss_kwargs']['topk']).indices
                pos = torch.gather(left_emb, 1, topk_idx[:,0:self.model_cfg['align_margin_loss_kwargs']['d1']])
                neg = torch.gather(left_emb, 1, bottomk_idx).mean(dim=1)
                margin_loss = self.margin_loss_zero_one(pos, neg).mean(dim=1).sum()
            else:
                margin_loss = torch.tensor(0.0).to(self.device)
            return {
                "mse_loss": F.mse_loss(item_emb[flag], self.Q_mat[item_idx[flag], :].float(), reduction='sum'),
                "margin_loss": margin_loss,
                "norm_loss": t_loss,
            }
        else:
            raise ValueError(f"Unknown align type: {self.model_cfg['align_type']}")


    def decode(self, user_emb, item_emb, item_emb_diff, item_id, **kwargs):
        if self.model_cfg['interact_type'] == 'irt_wo_disc':
            return ((user_emb - item_emb_diff)*item_emb).sum(dim=1)
        elif self.model_cfg['interact_type'] == 'irt':
            item_disc = self.ItemDisc(item_id).sigmoid() #* self.model_cfg['disc_scale']
            return ((user_emb - item_emb_diff)*item_emb*item_disc).sum(dim=1)
        elif self.model_cfg['interact_type'] == 'ncdm':
            item_disc = self.ItemDisc(item_id).sigmoid()# * self.model_cfg['disc_scale']
            input = (user_emb - item_emb_diff)*item_emb*item_disc
            return self.pd_net(input).flatten()
        elif self.model_cfg['interact_type'] == 'mf':
            return ((user_emb.sigmoid()*item_emb)*(item_emb*item_emb_diff)).sum(dim=1)
        elif self.model_cfg['interact_type'] == 'mirt': # 就是mf加了个disc
            item_disc = self.ItemDisc(item_id).sigmoid() #* self.model_cfg['disc_scale']
            return ((user_emb.sigmoid()*item_emb)*(item_emb*item_emb_diff)).sum(dim=1) + item_disc.flatten()
        else:
            raise NotImplementedError

    def set_dict_cpt_affiliation(self, aff):
        self.dict_cpt_affiliation = {
            k:torch.LongTensor(v).to(self.device) for k,v in aff.items()
        }

    def forward(self, users, items, labels):
        user_unique, user_unique_idx = users.unique(sorted=True, return_inverse=True)
        item_unique, item_unique_idx = items.unique(sorted=True, return_inverse=True)

        user_mix = self.EncoderUser(self.interact_mat[user_unique, :])
        user_mu, user_logvar = torch.chunk(user_mix, 2, dim=-1)
        user_emb_ = self.user_dist.sample(user_mu, user_logvar)
        user_emb = user_emb_[user_unique_idx, :]

        item_mu = self.EncoderItem(self.interact_mat[:, item_unique].T).sigmoid()
        item_emb_ = self.item_dist.sample(None, item_mu, type_=self.model_cfg['b_sample_type'], **self.model_cfg['b_sample_kwargs'])
        item_emb = item_emb_[item_unique_idx, :]


        item_diff_mix = self.EncoderItemDiff(self.interact_mat[:, item_unique].T)
        item_mu_diff, item_logvar_diff = torch.chunk(item_diff_mix, 2, dim=-1)
        item_emb_diff_ = self.item_dist_diff.sample(item_mu_diff, item_logvar_diff)
        item_emb_diff = item_emb_diff_[item_unique_idx, :]

        loss_main = F.binary_cross_entropy_with_logits(self.decode(user_emb, item_emb, item_emb_diff, item_id=items), labels, reduction='sum') # 重构 loss
        align_loss_dict = self.get_align_item_loss(item_mu, item_unique)
        # align_loss_dict_diff = self.get_align_item_loss(item_mu_diff, item_unique)

        user_terms = self.get_tcvae_terms(user_emb_, params=(user_mu, user_logvar), dist=self.user_dist, dataset_size=self.user_count)
        item_terms = self.get_tcvae_terms(item_emb_, params=item_mu, dist=self.item_dist, dataset_size=self.item_count)
        item_terms_diff = self.get_tcvae_terms(item_emb_diff_, params=(item_mu_diff, item_logvar_diff), dist=self.item_dist_diff, dataset_size=self.item_count)
        
        return {
            'loss_main': loss_main * self.model_cfg['lambda_main'],
            'loss_mse': align_loss_dict['mse_loss'] * self.model_cfg['lambda_q'],
            'loss_margin': align_loss_dict['margin_loss'] * self.model_cfg['align_margin_loss_kwargs']['margin_lambda'],
            'loss_norm': align_loss_dict['norm_loss'] * self.model_cfg['align_margin_loss_kwargs']['norm_lambda'],
            # 'loss_mse_diff': align_loss_dict_diff['mse_loss'] * self.model_cfg['lambda_q'],
            # 'loss_margin_diff': align_loss_dict_diff['margin_loss'] * self.model_cfg['align_margin_loss_kwargs']['margin_lambda'],
            # 'loss_norm_diff': align_loss_dict_diff['norm_loss'] * self.model_cfg['align_margin_loss_kwargs']['norm_lambda'],
            
            'user_MI': user_terms['MI'] * self.model_cfg['alpha_user'],
            'user_TC': user_terms['TC'] * self.model_cfg['beta_user'],
            'user_TC_G': user_terms['TC_G'] * self.model_cfg['g_beta_user'],
            'user_KL': user_terms['KL'] * self.model_cfg['gamma_user'],
            'item_MI': item_terms['MI'] * self.model_cfg['alpha_item'],
            'item_TC': item_terms['TC'] * self.model_cfg['beta_item'],
            'item_TC_G': item_terms['TC_G'] * self.model_cfg['g_beta_item'],
            'item_KL': item_terms['KL'] * self.model_cfg['gamma_item'],
            
            'item_MI_diff': item_terms_diff['MI'] * self.model_cfg['alpha_item'],
            'item_TC_diff': item_terms_diff['TC'] * self.model_cfg['beta_item'],
            'item_TC_G_diff': item_terms_diff['TC_G'] * self.model_cfg['g_beta_item'],
            'item_KL_diff': item_terms_diff['KL'] * self.model_cfg['gamma_item'],
        }

    def get_tcvae_terms(self, z, params, dist, dataset_size):
        """Estimate the MI / TC / group-TC / dimension-wise KL terms for one batch.

        Args:
            z: sampled latents, 2-D ``(batch_size, latent_dim)``.
            params: posterior parameters; a ``(mu, logvar)`` pair when ``dist``
                is a ``NormalDistUtil``, a single tensor when it is a
                ``BernoulliUtil``.
            dist: helper object providing ``log_density``.
            dataset_size: total number of samples the batch was drawn from.

        Returns:
            dict with ``'MI'``, ``'TC'``, ``'TC_G'`` and ``'KL'``. ``'TC_G'`` is
            a one-element zero tensor when no group affiliation has been set
            via ``set_dict_cpt_affiliation``; every other entry is a batch mean.

        Raises:
            ValueError: for an unsupported ``dist`` type or an unknown
                ``model_cfg['sampling_type']``.
        """
        batch_size, latent_dim = z.shape

        if isinstance(dist, NormalDistUtil):
            mu, logvar = params
            zero = torch.FloatTensor([0.0]).to(self.device)
            logpz = dist.log_density(X=z, MU=zero, LOGVAR=zero).sum(dim=1)
            logqz_condx = dist.log_density(X=z, MU=mu, LOGVAR=logvar).sum(dim=1)
            # Element (i, j, k) of _logqz is log q(z(n_i)_k | n_j).
            _logqz = dist.log_density(
                z.reshape(batch_size, 1, latent_dim),
                mu.reshape(1, batch_size, latent_dim),
                logvar.reshape(1, batch_size, latent_dim)
            )
        elif isinstance(dist, BernoulliUtil):
            logpz = dist.log_density(z, params=None).sum(dim=1)
            logqz_condx = dist.log_density(z, params=params).sum(dim=1)
            _logqz = dist.log_density(
                z.reshape(batch_size, 1, latent_dim),
                params=params.reshape(1, batch_size, latent_dim),
                is_check=False,
            )
        else:
            raise ValueError("unknown base class of dist")

        sampling_type = self.model_cfg['sampling_type']
        if sampling_type == 'mws':
            # Minibatch weighted sampling.
            log_norm = math.log(batch_size * dataset_size)

            def joint_log_density(pairwise):
                # pairwise: (B, B, d) -> log q over the d selected dims, shape (B,)
                return torch.logsumexp(pairwise.sum(dim=2), dim=1, keepdim=False) - log_norm

            logqz_prodmarginals = (torch.logsumexp(_logqz, dim=1, keepdim=False) - log_norm).sum(1)
        elif sampling_type == 'mss':
            # Minibatch stratified sampling.
            logiw_mat = self._log_importance_weight_matrix(batch_size, dataset_size).to(z.device)

            def joint_log_density(pairwise):
                # pairwise: (B, B, d) -> log q over the d selected dims, shape (B,)
                return torch.logsumexp(logiw_mat + pairwise.sum(dim=-1), dim=-1)

            logqz_prodmarginals = torch.logsumexp(
                logiw_mat.reshape(batch_size, batch_size, -1) + _logqz,
                dim=1,
            ).sum(dim=-1)
        else:
            raise ValueError("Unknown Sampling Type")

        logqz = joint_log_density(_logqz)

        has_groups = hasattr(self, 'dict_cpt_affiliation')
        if has_groups:
            # Sum over groups of the joint log-density of each group's dims.
            # Both estimators use the group joint here. The 'mss' path used to
            # sum per-dimension marginals instead, which differs from 'mws' and
            # makes TC_G equal to TC whenever the groups partition the latent.
            logqz_group = torch.stack(
                [joint_log_density(_logqz[:, :, group_idx])
                 for group_idx in self.dict_cpt_affiliation.values()],
                dim=0,
            ).sum(dim=0)

        IndexCodeMI = logqz_condx - logqz
        TC = logqz - logqz_prodmarginals
        if has_groups:
            TC_G = (logqz - logqz_group).mean()
        else:
            TC_G = torch.FloatTensor([0.0]).to(self.device)
        DW_KL = logqz_prodmarginals - logpz
        return {
            'MI': IndexCodeMI.mean(),
            'TC': TC.mean(),
            'TC_G': TC_G,
            'KL': DW_KL.mean()
        }

    def fit(self, train_dataset, val_dataset=None, callbacks: Sequence[Callback] = ()):
        """Train the model and report per-epoch logs through the callbacks.

        Args:
            train_dataset: dataset yielding rows whose first three columns are
                (user, item, label); it must also expose ``interact_mat`` and
                ``Q_mat`` tensors.
            val_dataset: optional dataset evaluated with ``evaluate`` after
                every epoch.
            callbacks: callbacks wrapped in a ``CallbackList``.

        Each epoch logs the mean of every loss term, the validation metrics
        (prefixed with ``val_``) and ``official_doa_gt``. Training stops early
        when ``self.share_obj_dict['stop_training']`` is truthy.
        """
        if not hasattr(self, 'dict_cpt_affiliation'):
            # Without a group affiliation TC_G is a constant zero, so its
            # weights must be zero as well.
            assert self.model_cfg['g_beta_user'] == 0.0, \
                "g_beta_user must be 0.0 when no cpt affiliation is set"
            assert self.model_cfg['g_beta_item'] == 0.0, \
                "g_beta_item must be 0.0 when no cpt affiliation is set"
        lr = self.train_cfg['lr']
        epoch_num = self.train_cfg['epoch_num']
        batch_size = self.train_cfg['batch_size']
        num_workers = self.train_cfg['num_workers']
        eval_batch_size = self.train_cfg['eval_batch_size']
        weight_decay = self.train_cfg['weight_decay']
        eps = self.train_cfg['eps']

        model = self.train()
        optimizer = self._get_optim(optimizer=self.train_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer = optimizer

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
        if val_dataset is not None:
            val_loader = DataLoader(val_dataset, shuffle=False, batch_size=eval_batch_size, num_workers=num_workers)

        callback_list = CallbackList(callbacks=callbacks, model=model, logger=self.logger)
        self.callback_list = callback_list
        callback_list.on_train_begin()

        self.interact_mat = train_dataset.interact_mat.to(self.device, dtype=torch.float32)
        self.Q_mat = train_dataset.Q_mat.to(self.device)
        if self.model_cfg['bernoulli_prior_auto']:
            # Mean density of the Q-matrix rows that have at least one non-zero entry.
            p = self.Q_mat[self.Q_mat.sum(dim=1) != 0].float().mean(dim=0).mean().item()
            self.item_dist = BernoulliUtil(p=p, stgradient=True)
            self.logger.info(f"[bernoulli_prior_auto]: {p=}")

        for epoch in range(epoch_num):
            # evaluate() leaves the module in eval mode, so training mode has
            # to be restored at the start of every epoch. Otherwise every
            # epoch after the first would train with dropout disabled.
            model.train()
            callback_list.on_epoch_begin(epoch + 1)
            logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32))
            for batch_id, batch in enumerate(
                    tqdm(train_loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[EPOCH={:03d}]".format(epoch + 1))
            ):
                batch = batch.to(self.device)
                users = batch[:, 0]
                items = batch[:, 1]
                labels = batch[:, 2].float()
                loss_dict = model(users, items, labels)
                # Total loss is the sum of all (already weighted) terms.
                loss = torch.hstack([i for i in loss_dict.values() if i is not None]).sum()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                for k in loss_dict:
                    logs[k][batch_id] = loss_dict[k].item() if loss_dict[k] is not None else np.nan

            for name in logs:
                logs[name] = float(np.nanmean(logs[name]))

            # Metrics are computed in eval mode whether or not a validation
            # set was supplied.
            model.eval()
            if val_dataset is not None:
                val_metrics = self.evaluate(val_loader)
                logs.update({f"val_{metric}": val_metrics[metric] for metric in val_metrics})

            logs.update({f"official_doa_gt": self.get_doa(gt=True)})

            callback_list.on_epoch_end(epoch + 1, logs=logs)
            if self.share_obj_dict.get('stop_training', False):
                break

        callback_list.on_train_end()

    @torch.no_grad()
    def predict(self, users, items):
        user_emb = None
        if users is None:
            user_mix = self.EncoderUser(self.interact_mat)
            user_emb, _ = torch.chunk(user_mix, 2, dim=-1)
        else:
            user_mix = self.EncoderUser(self.interact_mat[users, :])
            user_emb, _ = torch.chunk(user_mix, 2, dim=-1)

        item_emb = None
        if items is None:
            item_emb = self.EncoderItem(self.interact_mat.T).sigmoid()
        else:
            item_emb = self.EncoderItem(self.interact_mat[:, items].T).sigmoid()

        item_emb_diff = None
        if items is None:
            item_emb_diff_mix = self.EncoderItemDiff(self.interact_mat.T)
            item_emb_diff, _ = torch.chunk(item_emb_diff_mix, 2, dim=-1)
        else:
            item_emb_diff_mix = self.EncoderItemDiff(self.interact_mat[:, items].T)
            item_emb_diff, _ = torch.chunk(item_emb_diff_mix, 2, dim=-1)

        return self.decode(user_emb, item_emb, item_emb_diff, item_id=items).sigmoid()

    @torch.no_grad()
    def get_user_emb(self, users=None):
        user_emb = None
        if users is None:
            user_mix = self.EncoderUser(self.interact_mat)
            user_emb, _ = torch.chunk(user_mix, 2, dim=-1)
        else:
            user_mix = self.EncoderUser(self.interact_mat[users, :])
            user_emb, _ = torch.chunk(user_mix, 2, dim=-1)
        
        return user_emb

    @torch.no_grad()
    def get_item_emb(self, items=None):
        item_emb = None
        if items is None:
            item_emb = self.EncoderItem(self.interact_mat.T)
        else:
            item_emb = self.EncoderItem(self.interact_mat[:, items].T)
        
        return item_emb.sigmoid()

    @torch.no_grad()
    def evaluate(self, loader):
        """Run ``predict`` over ``loader`` and compute the configured metrics.

        Switches the module to eval mode and leaves it there.

        Args:
            loader: iterable of batches whose first three columns are
                (user, item, label).

        Returns:
            dict mapping each name in ``eval_cfg['metrics']`` to its value.
        """
        self.eval()
        # Collect per-batch results by appending, so the lists only ever hold
        # tensors (no integer placeholders reaching torch.hstack).
        predictions = []
        targets = []
        for batch in tqdm(loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[PREDICT]"):
            batch = batch.to(self.device)
            u = batch[:, 0]
            i = batch[:, 1]
            r = batch[:, 2]
            predictions.append(self.predict(u, i).flatten())
            targets.append(r.flatten())
        y_pd = tensor2npy(torch.hstack(predictions))
        y_gt = tensor2npy(torch.hstack(targets))
        eval_result = {
            metric: self._get_metrics(metric)(y_gt, y_pd) for metric in self.eval_cfg['metrics']
        }
        return eval_result


    def get_doa(self, gt=False):
        """Compute the DOA metric from the current user embeddings.

        Args:
            gt: only ``True`` is supported. It uses ``self.df_Q_final`` and
                ``self.df_interact_final``, which are not assigned anywhere in
                this class (presumably attached by the caller — confirm).

        Returns:
            The ``'doa'`` entry of ``doa_report`` as a float.

        Raises:
            NotImplementedError: when ``gt`` is false.
        """
        if not gt:
            raise NotImplementedError

        user_emb = tensor2npy(self.get_user_emb())
        df_Q = self.df_Q_final
        df_interact = self.df_interact_final

        # Build the per-user ``theta`` lists directly instead of serialising
        # each row to a string and parsing it back with ``eval``.
        # NOTE(review): assumes tensor2npy returns a 2-D numpy array.
        df_user = pd.DataFrame({
            'uid': np.arange(user_emb.shape[0], dtype=np.int64),
            'theta': [row.tolist() for row in user_emb],
        })
        df = df_interact.merge(df_user, on=['uid']).merge(df_Q, on=['iid'])
        df = df.rename(columns={"uid": 'user_id', 'iid': 'item_id', 'label': 'score'})
        official_doa = doa_report(df)
        return float(official_doa['doa'])

    def _log_importance_weight_matrix(self, batch_size, dataset_size):
        """Compute importance weigth matrix for MSS
        Code from (https://github.com/rtqichen/beta-tcvae/blob/master/vae_quant.py)
        """

        N = dataset_size
        M = batch_size - 1
        strat_weight = (N - M) / (N * M)
        W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
        W.view(-1)[:: M + 1] = 1 / N
        W.view(-1)[1 :: M + 1] = strat_weight
        W[M - 1, 0] = strat_weight
        return W.log()
