from .entrypoint import EntryPoint
from lib.callback import ModelCheckPoint, EarlyStopping, EarlyStoppingPlusValue, History, BaseLogger, TensorboardCallback
from lib.utils import set_same_seeds
from torch.utils.data import DataLoader
import torch
from lib.utils import PathUtil, tensor2npy
import numpy as np
from dataset import VAEDataset
from evaluate import Evaluate, doa_report
from tqdm import tqdm
import shutil
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import json
import os


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that can also serialise NumPy arrays and NumPy scalars."""

    def default(self, obj):
        """Convert NumPy containers/scalars to plain Python objects.

        Args:
            obj: the object the standard encoder could not serialise.

        Returns:
            A nested list for ``np.ndarray`` input, or the equivalent Python
            scalar for NumPy scalar input (``np.int64``, ``np.float32``,
            ``np.bool_``, ...).

        Raises:
            TypeError: for any other type (raised by ``JSONEncoder.default``).
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            value = obj.item()
            # Some scalar types have no native Python equivalent and ``item()``
            # hands back another NumPy scalar; returning it would make the
            # encoder call ``default`` again forever, so fall through instead.
            if not isinstance(value, np.generic):
                return value
        return super().default(obj)

class ENTRY_EDUVAE_FOLD(EntryPoint):
    def __init__(self, cfg):
        """Initialise the entry point by forwarding ``cfg`` to the ``EntryPoint`` base class."""
        super().__init__(cfg=cfg)

    def start(self):
        """Run the fold loop and dump the aggregated results as JSON.

        The data-format object is built from the config, then ``one_fold`` is
        called for each (train, val) dataset pair until
        ``task_cfg['fold_exec_num']`` folds (default 1000) have been executed.
        The mean of every metric over the executed folds is logged, and both
        the per-fold value lists (``result-list.json``) and the means
        (``result-mean.json``) are written into ``cfg.temp_folder_path``.
        """
        datafmt = self.datafmt_cls.from_cfg(self.cfg)
        fold_pairs = zip(datafmt.train_dataset_list, datafmt.val_dataset_list)
        max_folds = self.cfg.task_cfg.get('fold_exec_num', 1000)

        # metric name -> list with one entry per executed fold
        per_fold_values = defaultdict(list)
        for fold_idx, (train_dt, val_dt) in enumerate(fold_pairs):
            # Stop as soon as the configured number of folds has been run.
            if (fold_idx + 1) > max_folds:
                break
            self.logger.info(f"====== [FOLD ID]: {fold_idx} ======")
            fold_metrics = self.one_fold(train_dt, val_dt, datafmt, fold_idx)
            for metric_name, metric_value in fold_metrics.items():
                per_fold_values[metric_name].append(metric_value)
            self.logger.info("==" * 10)

        # Average every metric across folds, logging each mean as it is computed.
        mean_by_metric = {}
        for metric_name, fold_values in per_fold_values.items():
            fold_mean = np.mean(fold_values)
            mean_by_metric[metric_name] = fold_mean
            self.logger.info(f"All Fold Mean {metric_name} = {fold_mean}")

        outputs = (
            ("result-list.json", per_fold_values),
            ("result-mean.json", mean_by_metric),
        )
        for file_name, payload in outputs:
            with open(f"{self.cfg.temp_folder_path}/{file_name}", 'w', encoding='utf8') as f:
                json.dump(payload, fp=f, indent=2, ensure_ascii=False, cls=NumpyEncoder)

    def one_fold(self, train_dataset, val_dataset, datafmt, fold_id):
        """Train, checkpoint and evaluate a fresh model on a single fold.

        Steps: build the callbacks, seed and fit a new model, reload the best
        checkpoint for ``cfg['inf_cfg']['metric']``, collect the history entry
        of every checkpoint metric's best epoch, and, when the model exposes
        ``get_user_emb`` / ``get_item_emb``, save and evaluate those embeddings.

        Args:
            train_dataset: training dataset of this fold, passed to ``model.fit``.
            val_dataset: validation dataset of this fold, passed to ``model.fit``.
            datafmt: data-format object; ``df_Q``, ``Q_mat`` and
                ``df_interaction`` are read from it (and
                ``dict_cpt_affiliation`` when present). Note that a
                ``'knowledge'`` column is written into ``datafmt.df_Q`` in place.
            fold_id: fold index, used to name the per-fold output folders.

        Returns:
            dict mapping metric name to value for this fold.

        Raises:
            NotImplementedError: if ``train_cfg['es_cb_name']`` is not one of
                ``'EarlyStopping'`` / ``'EarlyStoppingPlusValue'``.

        Side effects:
            Writes under ``{temp_folder_path}/history/{fold_id}`` and
            ``{temp_folder_path}/embs/{fold_id}``, and finally removes
            ``{temp_folder_path}/pths/{fold_id}``.
        """
        # Build the callbacks.
        num_stop_rounds = self.train_cfg['num_stop_rounds']
        check_point_metrics = self.train_cfg.get("check_point_metric", None)
        es_metrics = self.train_cfg["early_stop_metric"]
        es_cb_name = self.train_cfg.get('es_cb_name', 'EarlyStopping')
        if es_cb_name == 'EarlyStopping':
            earlystopping = EarlyStopping(es_metrics, num_stop_rounds=num_stop_rounds, start_round=1)
        elif es_cb_name == 'EarlyStoppingPlusValue':
            earlystopping = EarlyStoppingPlusValue(
                es_metrics, val_metric_name=self.train_cfg['val_metric_name'],
                val_metric_dist=self.train_cfg['val_metric_dist'],
                num_stop_rounds=num_stop_rounds
            )
        else:
            raise NotImplementedError

        # Checkpoint on the dedicated metrics if configured, otherwise on the
        # early-stopping metrics.
        modelCheckPoint = ModelCheckPoint(
            check_point_metrics or es_metrics, save_folder_path=f"{self.cfg.temp_folder_path}/pths/{fold_id}"
        )
        # tensorboard_cb = TensorboardCallback(log_dir=f"{self.cfg.temp_folder_path}/tensorboard/")
        history = History(folder_path=f"{self.cfg.temp_folder_path}/history/{fold_id}", plot_curve=True,filename_suffix=fold_id)
        callbacks = [
            modelCheckPoint, earlystopping, history,
            BaseLogger(self.logger, group_by_contains=['loss', 'user_', 'item_'])
        ]

        # train
        set_same_seeds(self.train_cfg['seed'])
        model = self.model_cls(cfg=self.cfg)

        if hasattr(datafmt, 'dict_cpt_affiliation') and hasattr(model, 'set_dict_cpt_affiliation'):
            model.set_dict_cpt_affiliation(datafmt.dict_cpt_affiliation)

        # model.df_Q = datafmt.df_Q_train
        # model.df_Q_eval = datafmt.df_Q_train_eval
        # model.df_Q_eval['knowledge'] = [datafmt.Q_mat_train_eval[i,:].tolist() for i in model.df_Q_eval['iid']]
        # model.df_interact = pd.concat([datafmt.df_interaction_train, datafmt.df_interaction_val], axis=0).reset_index(drop=True)
        # print(model.df_interact)

        # Attach the full Q table and interactions to the model so the real
        # official DOA can be computed conveniently.
        model.df_Q_final = datafmt.df_Q
        model.df_Q_final['knowledge'] = [datafmt.Q_mat[i,:].tolist() for i in model.df_Q_final['iid']]
        model.df_interact_final = datafmt.df_interaction

        model.fit(train_dataset=train_dataset, val_dataset=val_dataset, callbacks=callbacks)

        # load best params
        metric_name = self.cfg['inf_cfg']['metric']
        metric = [m for m in modelCheckPoint.metric_list if m.name == metric_name][0]
        fpth =  f"{self.cfg.temp_folder_path}/pths/{fold_id}/best-epoch-{metric.best_epoch:03d}-for-{metric.name}.pth"
        model.load_state_dict(torch.load(fpth))
        self.logger.info(f"Load {fpth} !")

        # For every checkpoint metric, record its best epoch together with all
        # values the history logged at that epoch.
        metrics = {}
        for m in modelCheckPoint.metric_list:
            best_his = history.log_as_time[m.best_epoch]
            metrics[f"{m.name}-best-epoch"] = m.best_epoch
            for k in best_his:
                metrics[f"{m.name}-best-epoch-{k}"] = best_his[k]

        if hasattr(model, "get_user_emb"):
            # Get the latent representations of all users.
            with torch.no_grad():
                user_emb = model.get_user_emb()

            save_folder = f"{self.cfg.temp_folder_path}/embs/{fold_id}/"
            PathUtil.auto_create_folder_path(save_folder)
            user_emb_npy = tensor2npy(user_emb)
            mean_abs_user_corrcoef = self.save_user_emb(user_emb_npy, save_folder)
            if self.cfg['task_cfg'].get('save_emb_include_log', False):
                # Also analyse the log-embedding scaled by the model's tau.
                tau = self.model_cfg['tau']
                tmp_emb = np.log(user_emb_npy) * tau
                PathUtil.auto_create_folder_path(save_folder+"/log/")
                coef = self.save_user_emb(tmp_emb, save_folder+"/log/")
                metrics.update({
                    'log_mean_abs_user_corrcoef': coef
                })

            # 2. Interpretability evaluation of the user representations.
            ## official DOA
            # One row per user: the row index is the uid and 'theta' holds that
            # user's embedding as a plain list of Python floats. The list is
            # built directly from the array instead of round-tripping through
            # str()/eval(), which breaks on nan/inf values and executes text.
            df_user = pd.DataFrame({
                'uid': list(range(user_emb_npy.shape[0])),
                'theta': user_emb_npy.tolist(),
            })

            df_Q = datafmt.df_Q
            df_Q['knowledge'] = [datafmt.Q_mat[i,:].tolist() for i in df_Q['iid']]

            df = datafmt.df_interaction.merge(df_user, on=['uid']).merge(df_Q, on=['iid'])
            df = df.rename(columns={"uid": 'user_id', 'iid':'item_id', 'label': 'score'})
            official_doa = doa_report(df)

            metrics.update({
                'official_doa_gt': official_doa['doa'],
                'doa_know_support': official_doa['doa_know_support'],
                'doa_z_support': official_doa['doa_z_support'],
                'mean_abs_user_corrcoef': mean_abs_user_corrcoef,
            })
            for i,v in enumerate(official_doa['doa_list']):
                metrics[f'doa_k_{i}'] = v

        if hasattr(model, "get_item_emb"):
            # Get the latent representations of all items.
            with torch.no_grad():
                item_emb = model.get_item_emb()

            save_folder = f"{self.cfg.temp_folder_path}/embs/{fold_id}"
            PathUtil.auto_create_folder_path(save_folder)
            item_emb_npy = tensor2npy(item_emb)
            mean_abs_item_corrcoef = self.save_item_emb(item_emb_npy, save_folder)

            if self.cfg['task_cfg'].get('save_emb_include_log', False):
                # Also analyse the log-embedding scaled by the model's tau.
                tau = self.model_cfg['tau']
                tmp_emb = np.log(item_emb_npy) * tau
                PathUtil.auto_create_folder_path(save_folder+"/log/")
                coef = self.save_item_emb(tmp_emb, save_folder+"/log/")
                metrics.update({
                    'log_mean_abs_item_corrcoef': coef
                })

            m = self.evaluate_item_emb(item_emb, datafmt.Q_mat, datafmt, fold_id)

            metrics.update({
                'mean_abs_item_corrcoef': mean_abs_item_corrcoef
            })
            metrics.update(m)

        for name in metrics:
            self.logger.info(f"{name}: {metrics[name]}")
        History.dump_json(metrics, f"{self.cfg.temp_folder_path}/history/{fold_id}/result.json")
        # Delete this fold's model parameter files.
        shutil.rmtree(f"{self.cfg.temp_folder_path}/pths/{fold_id}")

        return metrics

    def evaluate_item_emb(self, item_emb, Q, datafmt, fold_id):
        """Evaluate item embeddings against the Q matrix with ``Evaluate``.

        Each item row of ``item_emb`` is scored against the set of column
        indices that are non-zero in the matching row of ``Q``. The metrics
        and top-k values come from ``eval_cfg`` (``item_align_metrics``,
        default ``['recall']``; ``item_align_topk``, default ``[1, 2]``).
        When ``data_cfg['Q_delete_ratio'] > 0`` the evaluation is repeated on
        the items whose id is absent from ``datafmt.missing_df_Q['iid']`` and
        reported under ``"<metric>_missing"`` keys.

        Args:
            item_emb: item embedding matrix, one row per item.
            Q: Q matrix indexed as ``Q[item, :]``; must be accepted by
                ``torch.argwhere``.
            datafmt: data-format object; only ``missing_df_Q`` is read here.
            fold_id: fold index (not used inside this method).

        Returns:
            dict mapping metric name to its NaN-ignoring mean over items.
        """
        exer_count = item_emb.shape[0]

        # item id -> set of column indices that are non-zero in its Q row
        Q_u2i = {}
        for exer_id in range(exer_count):
            nonzero_cols = torch.argwhere(Q[exer_id, :]).flatten()
            Q_u2i[exer_id] = set(tensor2npy(nonzero_cols).tolist())

        evaluator = Evaluate(
            metrics=self.eval_cfg.get('item_align_metrics', ['recall']),
            topks=self.eval_cfg.get('item_align_topk', [1, 2]),
            device=self.environ_cfg['device'],
            uni_neg_sample_num=-1,
        )

        def run_and_average(uid_list, rating_mat):
            # Nothing is excluded for any item: every item maps to a fresh empty list.
            no_exclusions = {exer_id: [] for exer_id in range(exer_count)}
            perf = evaluator.evaluate(
                uid_list=uid_list, rating_mat=rating_mat,
                except_u2i=no_exclusions, test_u2i=Q_u2i,
            )
            return np.nanmean(perf.astype("float64"), axis=0)

        all_items = np.arange(exer_count)
        overall = run_and_average(all_items, item_emb)
        ret = {name: overall[idx] for idx, name in enumerate(evaluator.metric_names)}

        if not (self.data_cfg['Q_delete_ratio'] > 0.0):
            return ret

        # Keep only the items whose id does not appear in missing_df_Q.
        listed_iids = datafmt.missing_df_Q['iid'].unique()
        candidate_iids = np.arange(self.data_cfg['dt_info']['item_count'])
        not_listed = ~np.isin(candidate_iids, listed_iids)
        subset = run_and_average(candidate_iids[not_listed], item_emb[not_listed])
        ret.update({
            f"{name}_missing": subset[idx]
            for idx, name in enumerate(evaluator.metric_names)
        })
        return ret

    def save_user_emb(self, user_emb_npy, save_folder):
        """Save the user embedding and a heatmap of its correlation matrix.

        Args:
            user_emb_npy: NumPy array of user embeddings; its transpose is
                passed to ``np.corrcoef``, so columns are treated as variables.
            save_folder: output directory, created if it does not exist.

        Returns:
            Mean of the absolute values of the correlation matrix.

        Side effects:
            Writes ``user_emb.npy`` (only when ``cfg['task_cfg']['save_emb']``
            is truthy) and ``user-emb-corrcoef-heatmap.png`` into
            ``save_folder``.
        """
        os.makedirs(save_folder, exist_ok=True)
        if self.cfg['task_cfg']['save_emb']:
            np.save(f"{save_folder}/user_emb.npy", user_emb_npy)

        user_corrcoef = np.corrcoef(user_emb_npy.T)
        mean_abs_user_corrcoef = np.abs(user_corrcoef).mean()

        fig = plt.figure()
        try:
            sns.heatmap(user_corrcoef, cmap='YlGnBu',)
            plt.title("user-emb corrcoef")
            plt.savefig(f"{save_folder}/user-emb-corrcoef-heatmap.png", dpi=500, bbox_inches='tight', pad_inches=0.1)
        finally:
            # Close this figure so it does not stay open across folds;
            # pyplot otherwise keeps every created figure alive.
            plt.close(fig)

        return mean_abs_user_corrcoef

    def save_item_emb(self, item_emb_npy, save_folder):
        """Save the item embedding and a heatmap of its correlation matrix.

        Args:
            item_emb_npy: NumPy array of item embeddings; its transpose is
                passed to ``np.corrcoef``, so columns are treated as variables.
            save_folder: output directory, created if it does not exist.

        Returns:
            Mean of the absolute values of the correlation matrix.

        Side effects:
            Writes ``item_emb.npy`` (only when ``cfg['task_cfg']['save_emb']``
            is truthy) and ``item-emb-corrcoef-heatmap.png`` into
            ``save_folder``, then closes all open pyplot figures.
        """
        folder_exists = os.path.exists(save_folder)
        if not folder_exists:
            os.makedirs(save_folder)

        should_save_array = self.cfg['task_cfg']['save_emb']
        if should_save_array:
            np.save(f"{save_folder}/item_emb.npy", item_emb_npy)

        corr_matrix = np.corrcoef(item_emb_npy.T)
        mean_abs_corr = np.abs(corr_matrix).mean()

        heatmap_path = f"{save_folder}/item-emb-corrcoef-heatmap.png"
        plt.figure()
        sns.heatmap(corr_matrix, cmap='YlGnBu')
        plt.title("item-emb corrcoef")
        plt.savefig(heatmap_path, dpi=500, bbox_inches='tight', pad_inches=0.1)
        # Closes every open pyplot figure, not only the one created above.
        plt.close('all')
        return mean_abs_corr
