import torch.nn as nn
import torch
from sklearn.metrics import mean_squared_error, roc_auc_score, accuracy_score, f1_score, label_ranking_loss, coverage_error
import numpy as np
from abc import abstractmethod
import logging
from lib.utils import UnifyConfig
from .init import xavier_normal_initialization


class BaseModel(nn.Module):
    """Base class for models configured through a ``UnifyConfig`` object.

    The constructor unpacks the sub-configs from ``cfg``, then runs the
    subclass hooks in a fixed order: ``build_cfg`` -> ``build_model`` ->
    (optional) Xavier-normal initialisation -> ``init_params`` -> move to
    the configured device.

    Subclasses must override ``build_cfg`` and ``build_model``.
    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced at instantiation time; the ``NotImplementedError`` raised by the
    default hook bodies is what actually guards against a missing override.
    """

    def __init__(self, cfg, xavier_init=True):
        """Unpack ``cfg`` and build the model.

        Args:
            cfg: config object exposing ``data_cfg``, ``eval_cfg``,
                ``train_cfg``, ``environ_cfg``, ``model_cfg`` and ``logger``
                attributes.
            xavier_init: when true, ``xavier_normal_initialization`` is
                applied to every sub-module after ``build_model``.
        """
        super(BaseModel, self).__init__()
        self.cfg: UnifyConfig = cfg
        self.data_cfg: UnifyConfig = cfg.data_cfg
        self.eval_cfg: UnifyConfig = cfg.eval_cfg
        self.train_cfg: UnifyConfig = cfg.train_cfg
        self.environ_cfg: UnifyConfig = cfg.environ_cfg
        self.model_cfg: UnifyConfig = cfg.model_cfg
        self.logger: logging.Logger = cfg.logger

        self.device = self.environ_cfg['device']
        # Mutable state holder; presumably shared with external training
        # code (e.g. to request early stopping) -- confirm against callers.
        self.share_obj_dict = {
            "stop_training": False
        }

        self.build_cfg()
        self.build_model()
        if xavier_init:
            self.apply(xavier_normal_initialization)
        # Runs after the generic initialisation so subclasses can override it.
        self.init_params()
        self.to(self.device)

    @abstractmethod
    def build_cfg(self):
        """
            construct model config
        """
        raise NotImplementedError

    @abstractmethod
    def build_model(self):
        """
            construct model component
        """
        raise NotImplementedError

    def init_params(self):
        """Optional hook for custom parameter initialisation (no-op here)."""
        pass

    def _get_optim(self, optimizer, lr, weight_decay=0.0, eps=1e-8):
        """Return an optimizer over this model's parameters.

        Args:
            optimizer: one of ``"sgd"``, ``"adam"``, ``"adagrad"``,
                ``"rmsprop"``, or any non-string object, which is returned
                unchanged (assumed to be an already-built optimizer).
            lr: learning rate.
            weight_decay: weight-decay coefficient.
            eps: numerical-stability term. Ignored for ``"sgd"`` because
                ``torch.optim.SGD`` has no ``eps`` parameter.

        Raises:
            NotImplementedError: if ``optimizer`` is an unknown string.
        """
        if not isinstance(optimizer, str):
            return optimizer

        if optimizer == "sgd":
            # SGD does not accept ``eps``; passing it raises TypeError.
            return torch.optim.SGD(
                self.parameters(), lr=lr, weight_decay=weight_decay
            )

        if optimizer == "adam":
            optim_cls = torch.optim.Adam
        elif optimizer == "adagrad":
            optim_cls = torch.optim.Adagrad
        elif optimizer == "rmsprop":
            optim_cls = torch.optim.RMSprop
        else:
            raise NotImplementedError(f"Unsupported optimizer: {optimizer!r}")
        return optim_cls(
            self.parameters(), lr=lr, weight_decay=weight_decay, eps=eps
        )

    def _get_metrics(self, metric):
        """Resolve a metric name to a callable.

        Every returned callable takes ``(y_gt, y_pd)`` except ``"uauc"``,
        which takes ``(u, gt, pd)`` (see ``uauc``). A non-string ``metric``
        is returned unchanged.

        Raises:
            NotImplementedError: if ``metric`` is an unknown string.
        """
        if not isinstance(metric, str):
            return metric

        if metric == "auc":
            return roc_auc_score
        elif metric == "uauc":
            return self.uauc
        elif metric == "mse":
            return mean_squared_error
        elif metric == 'rmse':
            return lambda y_gt, y_pd: mean_squared_error(y_gt, y_pd) ** 0.5
        elif metric == "acc":
            # Binarise scores at 0.5; asarray lets plain lists be compared.
            return lambda y_gt, y_pd: accuracy_score(
                y_gt, np.where(np.asarray(y_pd) >= 0.5, 1, 0)
            )
        elif metric == "f1_macro":
            return lambda y_gt, y_pd: f1_score(y_gt, y_pd, average='macro')
        elif metric == "f1_micro":
            return lambda y_gt, y_pd: f1_score(y_gt, y_pd, average='micro')
        elif metric == "ranking_loss":
            return lambda y_gt, y_pd: label_ranking_loss(y_gt, y_pd)
        elif metric == 'coverage_error':
            return lambda y_gt, y_pd: coverage_error(y_gt, y_pd)
        elif metric == 'samples_auc':
            return lambda y_gt, y_pd: roc_auc_score(y_gt, y_pd, average='samples')
        raise NotImplementedError(f"Unsupported metric: {metric!r}")

    @staticmethod
    def uauc(u, gt, pd):
        """Return the mean of per-group ROC AUC scores.

        Args:
            u: group id for each sample (presumably user ids).
            gt: ground-truth labels, aligned with ``u``.
            pd: predicted scores, aligned with ``u``.

        NOTE(review): groups are not filtered, so a group whose labels hold
        a single class is passed to ``roc_auc_score`` as-is -- confirm that
        callers pre-filter such groups if that is undesired.
        """
        # Coerce to arrays so boolean-mask indexing also works for lists.
        u = np.asarray(u)
        gt = np.asarray(gt)
        pd = np.asarray(pd)

        scores = []
        for uid in np.unique(u):
            mask = u == uid  # computed once per group
            scores.append(roc_auc_score(gt[mask], pd[mask]))
        return np.array(scores).mean()

