import torch
from ..grace.utils import param_subset, get_logits, parent_module, brackets_to_periods
import transformers
import copy


class Defer(torch.nn.Module):
    """Model editor that installs a ``DeferAdaptor`` on one layer of ``model``.

    The target layer is ``config.inner_params[0]`` with a trailing ``.weight`` or
    ``.bias`` stripped. Every parameter of ``model`` is frozen at construction, so
    ``edit`` only updates parameters that the adaptor explicitly makes trainable.
    """

    def __init__(self, config, model, device):
        super(Defer, self).__init__()
        self.config = config
        self.model = model
        layer = config.inner_params[0]
        self.device = device

        # Strip parameter suffixes so self.layer names the module, not a tensor.
        suffixes = (".weight", ".bias")
        self.layer = layer.rsplit(".", 1)[0] if layer.endswith(suffixes) else layer

        for p in self.model.parameters():
            p.requires_grad = False

        # GPT-2 weights are stored un-transposed, unlike torch.nn.Linear; the
        # adaptor needs to know which axis is the input dimension.
        transpose = not isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel)

        # --- Install the adaptor on the chosen layer ---
        self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        self.layer_name = self.layer.rsplit(".", 1)[-1]
        original_layer = getattr(self.edit_module, self.layer_name)
        if type(original_layer) is not DeferAdaptor:
            adaptor = DeferAdaptor(config, original_layer, transpose=transpose).to(self.device)
            setattr(self.edit_module, self.layer_name, adaptor)
            self.original_layer = copy.deepcopy(original_layer)

    def _adapter(self):
        """Return whatever module currently sits at the edited layer's position."""
        return getattr(self.edit_module, self.layer_name)

    @staticmethod
    def _key_id(labels):
        """Index of the last position labelled -100 (count of -100 entries minus one).

        NOTE(review): the count is taken over the whole ``labels`` tensor, so this
        presumably assumes batch size 1 and that all -100 labels precede the
        target tokens -- confirm with callers.
        """
        return (labels == -100).sum() - 1

    def get_inner_layer(self, named_parameters):
        """Return the parameters whose names contain "defer" or "predict".

        Args:
            named_parameters: iterable of ``(name, parameter)`` pairs, e.g. the
                output of ``module.named_parameters()``.

        Returns:
            A list of the matching parameter objects themselves (not slices of
            them), suitable for handing to an optimizer.
        """
        return [p for n, p in named_parameters if "defer" in n or "predict" in n]

    def generate(self, *args, **kwargs):
        """Run ``model.generate`` with the adaptor querying the last token."""
        self._adapter().key_id = -1
        return self.model.generate(*args, **kwargs)

    def __call__(self, **kwargs):
        """Forward to the wrapped model, pointing the adaptor at the query token.

        When ``labels`` is supplied, the adaptor's ``key_id`` is set from it;
        otherwise the previously set ``key_id`` is left unchanged.
        """
        if "labels" in kwargs:
            self._adapter().key_id = self._key_id(kwargs["labels"])
        return self.model(**kwargs)

    def get_params(self, named_parameters, names):
        """Look up the parameters named in ``names`` from ``named_parameters``."""
        param_dict = dict(named_parameters)
        return [param_dict[n] for n in names]

    def get_adapter_layer(self):
        """Return the installed ``DeferAdaptor``.

        Raises:
            AssertionError: if the module at the edited position is not a
                ``DeferAdaptor`` (raised explicitly so it survives ``python -O``).
        """
        adapter_layer = self._adapter()
        if type(adapter_layer) is not DeferAdaptor:
            raise AssertionError("Adapter Layer is not added correctly....")
        return adapter_layer

    def reset_layer(self):
        """Remove the adaptor by reinstalling the layer copy it wraps."""
        setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)

    def edit(self, config, tokens):
        """Train the adaptor heads on ``tokens`` for ``config.n_iter`` Adam steps.

        Args:
            config: must provide ``edit_lr`` and ``n_iter``.
            tokens: keyword arguments for the wrapped model; must include ``labels``.

        Side effects:
            Sets ``self.losses`` (one numpy loss per iteration) and ``self.loss``
            (the last loss tensor, or None if ``n_iter`` is 0). The adaptor's
            ``training`` flag is reset to False even if an iteration raises.
        """
        adapter_layer = self.get_adapter_layer()
        # Tell the adaptor which token's hidden state to read.
        adapter_layer.key_id = self._key_id(tokens["labels"])
        adapter_layer.untrained = False
        adapter_layer.training = True  # soft gating path in DeferAdaptor.forward
        optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr, weight_decay=1e-4)
        adapter_layer.set_param_trainable()
        self.losses = []

        loss = None
        try:
            for _ in range(config.n_iter):
                loss = self.model(**tokens).loss
                self.losses.append(loss.detach().cpu().numpy())
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                print(loss.item())
        finally:
            adapter_layer.training = False
        self.loss = loss


class DeferAdaptor(torch.nn.Module):
    """Gated replacement for a single layer.

    Wraps a frozen copy of ``layer`` and adds two heads that read the hidden
    state of one token (position ``key_id``, default -1 i.e. the last):

    * ``defer`` -- ``Linear(input_dim, 1)`` followed by a sigmoid: the gate.
    * ``predict_values`` -- ``Linear(input_dim, output_dim)``: the replacement output.

    While ``untrained`` is True the wrapped layer's output is returned unchanged.
    In training mode the output is the soft blend
    ``gate * values + (1 - gate) * layer_out``; otherwise every position of an
    example is replaced by ``values`` when that example's gate is ``>= threshold``.
    """

    def __init__(self, config, layer, transpose):
        """
        Args:
            config: not used here; accepted for interface compatibility.
            layer: module to wrap; must expose ``.weight``. Its parameters are
                frozen in place before it is copied.
            transpose: True when ``layer.weight`` has shape (out, in) as in
                ``torch.nn.Linear``; False for the (in, out) layout.
        """
        super(DeferAdaptor, self).__init__()

        self.key_id = -1  # Default to using last token
        # Freeze before copying so the wrapped copy is frozen as well (the copy
        # inherits requires_grad); the caller's layer stays frozen too.
        layer.requires_grad_(False)
        self.original_layer = copy.deepcopy(layer)
        self.device = layer.weight.device
        self.untrained = True

        # GPT has non-transposed weights
        if transpose:
            output_dim, input_dim = layer.weight.shape[0], layer.weight.shape[1]
        else:
            input_dim, output_dim = layer.weight.shape[0], layer.weight.shape[1]

        self.defer = torch.nn.Linear(input_dim, 1).to(self.device)
        self.predict_values = torch.nn.Linear(input_dim, output_dim).to(self.device)
        self.threshold = 0.5

    def set_param_trainable(self):
        """Enable gradients for the gate and value heads only."""
        self.defer.requires_grad_(True)
        self.predict_values.requires_grad_(True)

    def forward(self, *args):
        """Run the wrapped layer, then optionally override its output.

        ``args[0]`` is indexed as a 3-D tensor (presumably (batch, seq, input_dim)
        -- confirm); the wrapped layer's output is assumed to share its first two
        axes. The sigmoid gate is stored on ``self.deferral_val`` with shape
        (batch, 1).
        """
        layer_out = self.original_layer(*args)  # Precompute model's prediction

        if self.untrained:
            return layer_out

        hidden = args[0]
        token_to_edit = min(self.key_id, hidden.shape[1] - 1)
        query = hidden[:, token_to_edit, :]  # one query vector per example

        defer_prob = torch.sigmoid(self.defer(query))  # (batch, 1)
        self.deferral_val = defer_prob
        values = self.predict_values(query)  # (batch, output_dim)

        # Add a sequence axis so each example's gate/values broadcast over all of
        # its own positions instead of aligning batch against sequence.
        gate = defer_prob.unsqueeze(1)  # (batch, 1, 1)
        new_values = values.unsqueeze(1)  # (batch, 1, output_dim)

        if self.training:
            return gate * new_values + (1 - gate) * layer_out
        return torch.where(gate >= self.threshold, new_values, layer_out)