from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.distributed as dist
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.common.templates import DATA_TYPE_DICT
from src.model.modeling_utils import concat_all_gather


class MLPLayer(nn.Module):
    """
    Projection head: a square linear layer followed by a tanh non-linearity.

    Input and output both have ``hidden_size`` features in the last dimension.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Attribute names are part of the checkpoint layout; keep them stable.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, features, **kwargs):
        """Return ``tanh(dense(features))``; extra keyword arguments are ignored."""
        return self.activation(self.dense(features))


class ProjectionMLP(nn.Module):
    """
    Two-layer projection head with batch normalisation.

    Layout: Linear(h -> 2h, no bias) -> BatchNorm1d -> ReLU ->
    Linear(2h -> h, no bias) -> BatchNorm1d(affine=False).
    """

    def __init__(self, hidden_size):
        super().__init__()
        expanded_size = 2 * hidden_size
        # Layer order (and therefore the ``net.<idx>`` parameter names) must
        # stay as-is for existing checkpoints to load.
        self.net = nn.Sequential(
            nn.Linear(hidden_size, expanded_size, bias=False),
            nn.BatchNorm1d(expanded_size),
            nn.ReLU(inplace=True),
            nn.Linear(expanded_size, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size, affine=False),
        )

    def forward(self, x):
        """Run ``x`` through the projection stack and return the result."""
        projected = self.net(x)
        return projected


class Pooler(nn.Module):
    """
    Reduce per-token hidden states to one vector per sequence.

    Supported ``pooler_type`` values:

    * ``"last"``   - hidden state at the last non-masked position, passed
      through the classification head.
    * ``"mean"``   - masked mean over the sequence, passed through the
      classification head.
    * ``"mean*"``  - any other value starting with ``"mean"``: the masked mean
      is returned as-is (no head is created for these types).

    Args:
        pooler_type: one of the values described above.
        hidden_size: feature size of the hidden states / head.
        pad_token_id: stored on the module; not used by ``forward``.
        use_bn: build a ``ProjectionMLP`` head instead of an ``MLPLayer``.
    """

    def __init__(self, pooler_type, hidden_size, pad_token_id, use_bn=False) -> None:
        super().__init__()
        self.pooler_type = pooler_type
        self.pad_token_id = pad_token_id
        if pooler_type in ("last", "mean"):
            self.classification_head = (
                ProjectionMLP(hidden_size) if use_bn else MLPLayer(hidden_size)
            )

    def forward(self, hidden_states, token_mask):
        """
        Pool ``hidden_states`` according to ``self.pooler_type``.

        Args:
            hidden_states: indexed as ``[batch, position, feature]``.
            token_mask: ``[batch, position]``; non-zero / True marks tokens
                that take part in pooling.

        Returns:
            One pooled vector per batch row.

        Raises:
            NotImplementedError: if ``pooler_type`` is not supported.
        """
        device = hidden_states.device
        if self.pooler_type == "last":
            batch_size = hidden_states.shape[0]
            # Index of the last counted token = (#non-zero mask entries) - 1.
            # NOTE(review): this is the last real token only with right-side
            # padding - confirm the tokenizer's padding side.
            sequence_lengths = (torch.ne(token_mask, 0).sum(-1) - 1).to(device)
            sentence_representation = hidden_states[
                torch.arange(batch_size, device=device), sequence_lengths
            ]
            return self.classification_head(sentence_representation)

        if self.pooler_type.startswith("mean"):
            mask = token_mask.unsqueeze(-1)
            token_count = torch.sum(mask, dim=1)
            # A fully masked row would divide by zero and yield NaN; use 1 as
            # the denominator there so the row pools to zeros instead.
            token_count = token_count.masked_fill(token_count == 0, 1)
            sentence_representation = (
                torch.sum(hidden_states * mask, dim=1) / token_count
            )
            if self.pooler_type == "mean":
                return self.classification_head(sentence_representation)
            # Head-less mean variants: previously this path left the result
            # unassigned and crashed with UnboundLocalError.
            return sentence_representation

        raise NotImplementedError(f"Unsupported pooler_type: {self.pooler_type!r}")


class ContraCLMSeqLoss(nn.Module):
    """
    Sequence-level contrastive loss over paired rows.

    Rows of ``pos_hidden_states`` are treated as adjacent pairs
    ``(0, 1), (2, 3), ...``: the target of each row is its partner. Optional
    ``neg_hidden_states`` rows are appended as extra candidates.
    """

    def __init__(self, temperature=0.05):
        super().__init__()
        self.temperature = temperature
        print(f"Sequence-Level Contrastive Loss:\t temperature: {temperature}")

    def forward(self, pos_hidden_states, neg_hidden_states=None):
        """
        Args:
            pos_hidden_states: 2-D tensor, one row per sequence; the row count
                must be even so that rows pair up.
            neg_hidden_states: optional 2-D tensor of extra negative rows.

        Returns:
            Scalar cross-entropy loss over the temperature-scaled similarities.
        """
        num_rows = pos_hidden_states.shape[0]
        dev = pos_hidden_states.device

        # Targets are the partner index of each row: [1, 0, 3, 2, ...].
        odd_idx = torch.arange(1, num_rows, step=2, dtype=torch.long, device=dev)
        even_idx = torch.arange(0, num_rows, step=2, dtype=torch.long, device=dev)
        targets = torch.stack([odd_idx, even_idx], dim=1).reshape([num_rows])

        similarity = pos_hidden_states @ pos_hidden_states.transpose(0, 1)
        # Push the diagonal far down so a row can never select itself.
        similarity -= torch.eye(num_rows).to(dev) * 5e5

        if neg_hidden_states is None:
            candidates = similarity
        else:
            neg_similarity = pos_hidden_states @ neg_hidden_states.transpose(0, 1)
            candidates = torch.cat([similarity, neg_similarity], dim=-1)

        candidates /= self.temperature
        return nn.CrossEntropyLoss()(candidates, targets)


class ContraCrossLoss(nn.Module):
    """
    Cross-modal contrastive loss between two interleaved views.

    Inputs are regrouped as ``[group, modal, feature]``; modal 0 is called
    "cot" and modal 1 "pal" in the code. Each view is scored against the
    other view's positives followed by the other view's negatives, and the
    matching positive (same group index) is the target.
    """

    def __init__(self, temperature=0.05, num_modals=2):
        super().__init__()
        self.temperature = temperature
        self.num_modals = num_modals
        print(f"Sequence-Level Contrastive Loss:\t temperature: {temperature}")

    def forward(self, pos_logits, neg_logits):
        """
        Args:
            pos_logits: 2-D tensor whose rows interleave the modals.
            neg_logits: 2-D tensor laid out the same way for negatives.

        Returns:
            Mean of the cot->pal and pal->cot cross-entropy losses.
        """
        # The group count is derived from the positives and reused for the
        # negatives, exactly as the reshape below requires.
        num_groups = pos_logits.shape[0] // self.num_modals
        grouped_pos = pos_logits.view(num_groups, self.num_modals, -1).contiguous()
        grouped_neg = neg_logits.view(num_groups, self.num_modals, -1).contiguous()

        cot_pos_logits = grouped_pos[:, 0]
        pal_pos_logits = grouped_pos[:, 1]
        cot_neg_logits = grouped_neg[:, 0]
        pal_neg_logits = grouped_neg[:, 1]

        pal_bank = torch.cat([pal_pos_logits, pal_neg_logits], dim=0)
        cot_bank = torch.cat([cot_pos_logits, cot_neg_logits], dim=0)
        cot_score = torch.matmul(cot_pos_logits, pal_bank.T) / self.temperature
        pal_score = torch.matmul(pal_pos_logits, cot_bank.T) / self.temperature

        targets = torch.arange(
            cot_pos_logits.shape[0], device=cot_pos_logits.device, dtype=torch.long
        )
        criterion = nn.CrossEntropyLoss()
        total = criterion(cot_score, targets) + criterion(pal_score, targets)
        return total / 2


class ContraLlamaForCausalLM(LlamaForCausalLM):
    """
    ``LlamaForCausalLM`` with auxiliary contrastive and ranking losses.

    ``forward`` computes the language-modelling loss on the first
    ``pos_batch_size`` rows of the batch and can add:

    * a contrastive loss over pooled last-layer hidden states, scaled by
      ``cl_weight``;
    * a ranking loss over length-normalised sequence log-likelihoods, scaled
      by ``rank_weight``.

    Rows ``[:pos_batch_size]`` are treated as positives and the remaining rows
    as negatives; the split point is derived from ``model_args`` in ``forward``.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """
        Args:
            config: model configuration forwarded to ``LlamaForCausalLM``.
            *model_args: accepted for call compatibility; unused.
            **model_kargs: must contain ``"model_args"``, an object exposing
                ``cl_data_type``, ``pooler_type``, ``use_bn``, ``single_modal``,
                ``temp``, ``cl_weight``, ``rank_weight``, ``cl_length_penalty``
                and ``rank_margin`` (``forward`` additionally reads
                ``no_hard_neg``, ``no_add_pos`` and ``convert_bf``).

        Raises:
            KeyError: if ``"model_args"`` is missing or ``cl_data_type`` is not
                a key of ``DATA_TYPE_DICT``.
        """
        super().__init__(config)
        self.model_args = model_kargs["model_args"]
        self.data_types = DATA_TYPE_DICT[self.model_args.cl_data_type]
        self.pooler = Pooler(
            self.model_args.pooler_type,
            self.config.hidden_size,
            self.config.pad_token_id,
            self.model_args.use_bn,
        )
        self.single_modal = self.model_args.single_modal
        if len(self.data_types) > 1 and not self.model_args.single_modal:
            # Cross-modal setup: the loss is computed in compute_cross_cl and
            # only needs the temperature.
            self.temp = self.model_args.temp
        elif len(self.data_types) == 1 or self.model_args.single_modal:
            self.cl_loss_fct = ContraCLMSeqLoss(self.model_args.temp)
        else:
            # Only reachable when data_types is empty.
            self.cl_loss_fct = ContraCrossLoss(
                self.model_args.temp, len(self.data_types)
            )
        self.cl_weight = self.model_args.cl_weight
        self.rank_weight = self.model_args.rank_weight
        self.cl_length_penalty = self.model_args.cl_length_penalty
        self.rank_margin = self.model_args.rank_margin
        self.init_weights()

    def gather_tensor(self, tensor):
        """
        All-gather ``tensor`` across ranks and concatenate along dim 0.

        The slot belonging to the current rank is replaced by the original
        ``tensor`` object, so the local shard keeps its autograd history.
        Requires an initialised process group.
        """
        tensor_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list=tensor_list, tensor=tensor.contiguous())
        tensor_list[dist.get_rank()] = tensor
        return torch.cat(tensor_list, 0)

    def gather_logits(self, logits, output_ids, token_mask):
        """
        Pick the value at ``output_ids`` along the last dim of ``logits`` and
        zero out positions where ``token_mask`` is 0.

        ``forward`` passes log-softmax values, so the result there is a
        per-token log-probability.
        """
        picked = torch.gather(logits, dim=-1, index=output_ids[:, :, None]).squeeze(-1)
        return picked * token_mask

    def get_score(self, probs, token_mask):
        """
        Return ``probs.sum(-1) / token_mask.sum(-1) ** cl_length_penalty``,
        i.e. a length-penalised sequence score per row.
        """
        scores = probs.sum(-1) / (token_mask.sum(-1) ** self.cl_length_penalty)
        return scores

    def compute_margin_rank_loss(self, scores, rw_scores, pos_batch_size):
        """
        Margin ranking loss between positive and negative sequence scores.

        Positives (``scores[:pos_batch_size]``) are reshaped to
        ``[neg_batch_size, -1]`` so each negative lines up with its
        positive(s); at most the first two positive columns are used. The
        ranking target is ``1 - rw_scores[pos_batch_size:]``.
        """
        neg_batch_size = scores.shape[0] - pos_batch_size
        pos_scores = scores[:pos_batch_size].contiguous().view(neg_batch_size, -1)
        neg_scores = scores[pos_batch_size:]
        rank_loss_fn = nn.MarginRankingLoss(margin=self.rank_margin)
        sign = 1 - rw_scores[pos_batch_size:]
        rank_loss = rank_loss_fn(pos_scores[:, 0], neg_scores, sign)
        if pos_scores.size(1) > 1:
            rank_loss = rank_loss + rank_loss_fn(pos_scores[:, 1], neg_scores, sign)
        return rank_loss

    def compute_detach_rank_loss(self, scores, rw_scores, pos_batch_size):
        """
        Two-way cross-entropy ranking loss with detached negative scores.

        For every negative row a pair ``[pos_score, neg_score]`` is classified
        with the negative's reward as target; rewards equal to 1 are turned
        into the ignore index (-100) and therefore skipped. Gradients only
        flow through the positive scores.

        Args:
            scores: 1-D sequence scores, positives first.
            rw_scores: 1-D rewards aligned with ``scores``; not modified.
            pos_batch_size: number of positive rows at the front.

        Returns:
            Scalar loss; a zero tensor (still attached to ``scores``) when
            every target is ignored.
        """
        neg_batch_size = scores.shape[0] - pos_batch_size
        pos_scores = scores[:pos_batch_size].contiguous().view(neg_batch_size, -1)
        neg_column = scores[pos_batch_size:].detach().unsqueeze(-1)

        # Work on a copy: slicing returns a view, and writing -100 into it
        # would silently corrupt the caller's reward tensor.
        targets = rw_scores[pos_batch_size:].to(
            device=scores.device, dtype=torch.long, copy=True
        )
        targets[targets == 1] = -100
        if not bool((targets != -100).any()):
            # CrossEntropyLoss (mean reduction) is NaN when all targets are
            # ignored; contribute nothing instead.
            return scores.sum() * 0.0

        rank_loss_fn = nn.CrossEntropyLoss()
        rank_loss = rank_loss_fn(
            torch.cat([pos_scores[:, 0].unsqueeze(-1), neg_column], dim=-1),
            targets,
        )
        if pos_scores.size(1) > 1:
            rank_loss = rank_loss + rank_loss_fn(
                torch.cat([pos_scores[:, 1].unsqueeze(-1), neg_column], dim=-1),
                targets,
            )
        return rank_loss

    def compute_rank_loss(self, scores, rw_scores):
        """
        Pairwise ranking loss over mis-ordered pairs.

        A pair ``(i, j)`` counts when ``rw_scores[j] > rw_scores[i]`` but
        ``scores[j] < scores[i]``; the loss is the mean of
        ``scores[i] - scores[j]`` over those pairs. Returns the integer ``0``
        when no pair is mis-ordered.
        """
        diff = scores.unsqueeze(0) - scores.unsqueeze(-1)  # [b, b]: s[j] - s[i]
        rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1)  # [b, b]
        misordered = torch.bitwise_and(rw_diff > 0, diff < 0)
        num_eles = misordered.long().sum()
        if num_eles.item() == 0:
            return 0
        return -diff[misordered].sum() / num_eles

    def stable_alignment(
        self, logits: torch.Tensor, labels: torch.Tensor, feedback_scores: torch.Tensor
    ) -> torch.Tensor:
        """
        Rating-modulated margin penalty between per-sequence LM losses.

        The sequence with the highest rating is the reference. For every other
        sequence the term ``best_loss - loss + (best_rating - rating) *
        rank_margin`` is collected; the result is
        ``rank_weight * max(mean(terms), 0)``. Not called by ``forward`` in
        this file.

        Args:
            logits: per-sequence logits, iterated along dim 0.
            labels: per-sequence labels aligned with ``logits``; -100 is ignored.
            feedback_scores: ratings, flattened to one value per sequence
                (assumes exactly one rating per row of ``logits``).

        Returns:
            Scalar penalty tensor.
        """
        ratings = feedback_scores.reshape(-1)
        sorted_ratings, indices = torch.sort(ratings, descending=True)

        loss_fct = CrossEntropyLoss(ignore_index=-100)
        vocab_size = logits.size(-1)
        batch_loss = torch.stack(
            [
                loss_fct(logit.reshape(-1, vocab_size), label.reshape(-1))
                for logit, label in zip(logits, labels)
            ],
            dim=0,
        )

        best_score = sorted_ratings[0]
        min_loss = batch_loss[indices[0]]

        # Walk indices and ratings together: ``indices`` holds original batch
        # positions while ``sorted_ratings`` is in rank order, so the rating
        # must not be looked up with the batch index.
        neg_losses = [
            min_loss - batch_loss[idx] + (best_score - rating) * self.rank_margin
            for idx, rating in zip(indices[1:], sorted_ratings[1:])
        ]

        if neg_losses:
            diff = torch.stack(neg_losses).mean().clamp(min=0.0)
        else:
            diff = torch.tensor(0.0, device=min_loss.device)

        return self.rank_weight * diff

    def compute_cross_cl(self, cot_logits, pal_logits, pos_batch_size):
        """
        Cross-modal contrastive loss against candidates gathered from all ranks.

        The local positive rows of each view are scored against the gathered
        rows of the other view; the target of local row ``k`` is ``k`` plus a
        per-rank offset.

        NOTE(review): the offset assumes ``concat_all_gather`` concatenates
        rank by rank, each rank contributing ``batch_size`` rows without hard
        negatives and ``2 * batch_size`` rows with them - confirm against
        ``concat_all_gather``. Requires an initialised process group.
        """
        pos_cot_logits = cot_logits[:pos_batch_size]
        pos_pal_logits = pal_logits[:pos_batch_size]
        all_cot_logits = concat_all_gather(cot_logits)
        all_pal_logits = concat_all_gather(pal_logits)
        cot_score = pos_cot_logits @ all_pal_logits.T
        cot_score /= self.temp
        pal_score = pos_pal_logits @ all_cot_logits.T
        pal_score /= self.temp
        batch_size = cot_score.shape[0]
        labels = torch.arange(batch_size, device=cot_score.device, dtype=torch.long)
        rows_per_rank = batch_size if self.model_args.no_hard_neg else batch_size * 2
        labels += rows_per_rank * dist.get_rank()
        return (
            F.cross_entropy(cot_score, labels) + F.cross_entropy(pal_score, labels)
        ) / 2

    def _contrastive_loss(self, last_hidden_states, token_mask, pos_batch_size):
        """Pool, L2-normalise and score the hidden states with the configured contrastive loss."""
        pooled = self.pooler(last_hidden_states, token_mask)
        # Optionally normalise in float32 and cast back afterwards.
        if self.model_args.convert_bf:
            pooled = pooled.to(torch.float32)
        pooled = F.normalize(pooled, dim=1)
        if self.model_args.convert_bf:
            pooled = pooled.to(last_hidden_states.dtype)

        if len(self.data_types) > 1 and not self.model_args.single_modal:
            # Rows alternate between the two views.
            cot_logits = pooled[::2].contiguous()
            pal_logits = pooled[1::2].contiguous()
            return self.compute_cross_cl(cot_logits, pal_logits, pos_batch_size // 2)

        pos_hidden_states = pooled[:pos_batch_size]
        neg_hidden_states = (
            pooled[pos_batch_size:] if pooled.shape[0] != pos_batch_size else None
        )
        if dist.is_initialized():
            pos_hidden_states = self.gather_tensor(pos_hidden_states)
            if neg_hidden_states is not None:
                neg_hidden_states = self.gather_tensor(neg_hidden_states)
        return self.cl_loss_fct(pos_hidden_states, neg_hidden_states)

    def _sequence_scores(self, logits, labels, input_ids):
        """Return the length-penalised log-likelihood of every row's labelled tokens."""
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        token_mask = shift_labels.ne(-100).float()
        output_ids = input_ids[..., 1:].contiguous().to(shift_logits.device)
        token_log_probs = self.gather_logits(
            F.log_softmax(shift_logits, dim=-1), output_ids, token_mask
        )
        return self.get_score(token_log_probs, token_mask)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        reward=None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the base model and, when ``labels`` is given, build the combined loss.

        The base model is always called without labels, with hidden states
        enabled and with a dict-style output, because the losses below need
        ``logits`` and the last hidden state regardless of what the caller
        requested. ``output_hidden_states`` and ``return_dict`` are accepted
        for signature compatibility; the return value is always a
        ``CausalLMOutputWithPast`` that includes the hidden states.

        Args:
            reward: per-row rewards used by the ranking loss; only read when
                ``rank_weight > 0``.
            (remaining arguments are forwarded to ``LlamaForCausalLM.forward``)

        Returns:
            ``CausalLMOutputWithPast`` whose ``loss`` is ``None`` when
            ``labels`` is ``None``.
        """
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        logits, last_hidden_states = outputs.logits, outputs.hidden_states[-1]

        # Number of positive rows at the front of the batch. Taken from the
        # logits so it also works when only ``inputs_embeds`` is supplied.
        total_batch_size = logits.shape[0]
        if self.model_args.no_hard_neg:
            pos_batch_size = total_batch_size
        elif len(self.data_types) > 1 or self.model_args.no_add_pos:
            pos_batch_size = total_batch_size // 2
        else:
            pos_batch_size = total_batch_size * 2 // 3

        loss = None
        loss_dict = {}
        if labels is not None:
            # Language-modelling loss on the positive rows only.
            # Shift so that tokens < n predict n.
            shift_logits = logits[:pos_batch_size, :-1, :].contiguous()
            shift_labels = labels[:pos_batch_size, 1:].contiguous()
            # Enable model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            # Flatten the tokens.
            loss = CrossEntropyLoss()(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )
            loss_dict["lm"] = loss.item()

            if self.cl_weight > 0.0:
                # "last" pooling indexes by attention mask; mean pooling
                # averages over the labelled tokens.
                if self.pooler.pooler_type == "last":
                    token_mask = attention_mask
                else:
                    token_mask = labels.ne(-100)
                cl_loss = self._contrastive_loss(
                    last_hidden_states, token_mask, pos_batch_size
                )
                loss_dict["cl"] = cl_loss.item()
                loss = loss + self.cl_weight * cl_loss

            if self.rank_weight > 0.0:
                scores = self._sequence_scores(logits, labels, input_ids)
                rank_loss = self.compute_detach_rank_loss(
                    scores, reward, pos_batch_size
                )
                if rank_loss != 0:
                    loss_dict["rank"] = rank_loss.item()
                    loss = loss + self.rank_weight * rank_loss

            # Log from a single process; also works without torch.distributed.
            if not dist.is_initialized() or dist.get_rank() == 0:
                print(loss_dict)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
