# -*- coding:utf-8 -*- 
from transformers import RobertaModel, BertPreTrainedModel, RobertaConfig
from transformers.modeling_roberta import RobertaLMHead
import torch
import torch.nn as nn
import torch.nn.functional as F
from .grl import WarmStartGradientReverseLayer

# Maps each pretrained RoBERTa shortcut name to the URL of its PyTorch weights.
# Every checkpoint lives under the same S3 prefix and uses the same
# "<shortcut>-pytorch_model.bin" file name, so the URLs are derived from the
# shortcut names (kept in their original order).
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut: "https://s3.amazonaws.com/models.huggingface.co/bert/{}-pytorch_model.bin".format(shortcut)
    for shortcut in (
        "roberta-base",
        "roberta-large",
        "roberta-large-mnli",
        "distilroberta-base",
        "roberta-base-openai-detector",
        "roberta-large-openai-detector",
    )
}

class RobertaForTokenClassification_Modified(BertPreTrainedModel):
    """RoBERTa encoder with a factorized token-classification output.

    Each selected token is scored by two main heads: a binary head (one
    logit, passed through a sigmoid) and a type head
    (``config.num_labels - 1`` logits, passed through a softmax). Their
    outputs are combined into a single ``config.num_labels``-wide
    distribution. Extra "pseudo" and "adv" heads are evaluated only when
    ``forward`` is called with ``train=True``; the "adv" heads read the
    features through ``WarmStartGradientReverseLayer`` instances.
    """

    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        """Create the encoder and every head described by ``config``.

        Reads ``config.num_labels``, ``config.hidden_size`` and
        ``config.hidden_dropout_prob``.
        """
        super().__init__(config)
        hidden_size = config.hidden_size
        num_type_labels = config.num_labels - 1  # width of the type heads
        # Both gradient-reversal layers share the same settings.
        grl_options = dict(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)

        # NOTE: submodules are registered in the same order as before; that
        # order fixes parameter ordering and initialisation order.
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # LM head kept for contextualized augmentation; not used in forward().
        self.lm_head = RobertaLMHead(config)
        # Main heads: type logits and the single binary logit.
        self.classifier = nn.Linear(hidden_size, num_type_labels)
        self.bin_classifier = nn.Linear(hidden_size, 1)

        # Auxiliary "pseudo" heads (same output widths as the main heads).
        self.pseudo_head = nn.Linear(hidden_size, num_type_labels)
        self.bin_pseudo_head = nn.Linear(hidden_size, 1)
        # Auxiliary "adv" heads, each fed through its own reversal layer.
        self.grl_layer = WarmStartGradientReverseLayer(**grl_options)
        self.bin_grl_layer = WarmStartGradientReverseLayer(**grl_options)
        self.adv_head = nn.Linear(hidden_size, num_type_labels)
        self.bin_adv_head = nn.Linear(hidden_size, 1)

        self.init_weights()
        # Freeze the LM head so its parameters receive no gradient.
        for lm_param in self.lm_head.parameters():
            lm_param.requires_grad_(False)

    def forward(self, input_ids, attention_mask, valid_pos, train=False):
        """Encode the input and score every position selected by ``valid_pos``.

        Args:
            input_ids: token ids passed to the RoBERTa encoder.
            attention_mask: attention mask passed to the RoBERTa encoder.
            valid_pos: positions with a value ``> 0`` are kept; all others
                are dropped before the heads are applied. Presumably shaped
                like ``input_ids`` -- TODO confirm against the caller.
            train: when truthy, also return the "adv" and "pseudo" logits.

        Returns:
            If ``train`` is falsy::

                (logits, bin_logits, type_pred_label, bin_pred_label,
                 total_prob)

            If ``train`` is truthy::

                (logits, bin_logits, logits_adv, bin_logits_adv,
                 logits_pseudo, bin_logits_pseudo, type_pred_label,
                 bin_pred_label, total_prob)

            ``total_prob`` is ``[1 - p, softmax(logits) * p]`` concatenated
            on the last dimension, with ``p = sigmoid(bin_logits)``.
        """
        encoder_outputs = self.roberta(input_ids, attention_mask=attention_mask)
        # Boolean-mask indexing keeps only the selected positions.
        selected = encoder_outputs[0][valid_pos > 0]
        features = self.dropout(selected)

        logits = self.classifier(features)
        bin_logits = self.bin_classifier(features)

        # Combine both heads into one distribution over all labels.
        entity_prob = torch.sigmoid(bin_logits)
        non_entity_prob = 1 - entity_prob
        type_dist = F.softmax(logits, dim=-1)
        total_prob = torch.cat([non_entity_prob, type_dist * entity_prob], dim=-1)

        # Hard predictions from each head on its own.
        type_pred_label = torch.argmax(type_dist, dim=-1)
        bin_dist = torch.cat((non_entity_prob, entity_prob), dim=-1)
        bin_pred_label = torch.argmax(bin_dist, dim=-1)

        if not train:
            return logits, bin_logits, type_pred_label, bin_pred_label, total_prob

        # Training-only outputs: the "adv" heads see the reversal-layer output.
        reversed_bin = self.bin_grl_layer(features)
        reversed_type = self.grl_layer(features)
        bin_logits_adv = self.bin_adv_head(reversed_bin)
        logits_adv = self.adv_head(reversed_type)
        bin_logits_pseudo = self.bin_pseudo_head(features)
        logits_pseudo = self.pseudo_head(features)
        return (
            logits,
            bin_logits,
            logits_adv,
            bin_logits_adv,
            logits_pseudo,
            bin_logits_pseudo,
            type_pred_label,
            bin_pred_label,
            total_prob,
        )