from dataclasses import dataclass
import eval_glue
from transformers import Trainer, TrainingArguments, TrainerCallback
from collections import defaultdict
import torch
from torch import nn
from param import param

@dataclass
class RegMeanMerge:
    """Merge models by weighting each weight matrix with its input Gram matrix.

    ``get_coefficient`` records, for every ``nn.Linear`` module of each model,
    the running mean of ``x^T x`` over the rows of the module's input.
    ``get_merged`` then combines the models' parameters: a 2-D parameter that
    comes with one matching Gram matrix per model is merged as
    ``(sum_i G_i)^-1 (sum_i G_i W_i^T)`` (transposed back), everything else is
    averaged element-wise.

    NOTE(review): ``get_coefficient`` reads ``self.names``, which is not a
    declared field; it has to be assigned on the instance before ``merge()``
    is called. Confirm whether it should become a field.
    """

    # Parameter containers handed to ``param.vectorize_reduce``.
    models_to_merge: list
    # NOTE(review): not read anywhere in this class.
    trainers: list
    # Number of leading training examples used per model. Either one int
    # shared by all models or one entry per element of ``self.names``.
    data_nums: list
    # Factor applied to the off-diagonal entries of every Gram matrix
    # (1.0 keeps the matrices unchanged).
    reduce_non_diag_val: float = 1.0

    def merge(self):
        """Collect the per-model Gram matrices, then merge the parameters."""
        regmean_weight = self.get_coefficient()
        return self.get_merged(regmean_weight)

    @staticmethod
    def _make_gram_hook(mod_name, regmean_weights, num_item):
        """Build a forward hook that accumulates the mean of ``x^T x``.

        Args:
            mod_name: Key under which the statistics are stored.
            regmean_weights: Mapping updated in place with the running mean
                Gram matrix; a missing key must read as ``0``.
            num_item: Mapping updated in place with the number of input rows
                seen so far; a missing key must read as ``0``.
        """

        def hook(mod, inp, out):
            x = inp[0].detach()
            # Collapse every leading dimension so each row is one input vector.
            x = x.reshape(-1, x.shape[-1])
            rows = x.shape[0]
            if rows == 0:
                # Nothing to add; also avoids a 0/0 on the very first update.
                return
            xtx = torch.matmul(x.transpose(0, 1), x)

            seen = num_item[mod_name]
            # Running mean: undo the previous normalisation, add, renormalise.
            regmean_weights[mod_name] = (
                (regmean_weights[mod_name] * seen + xtx) / (seen + rows)
            )
            num_item[mod_name] = seen + rows

        return hook

    @torch.inference_mode()
    def get_coefficient(self):
        """Return one ``{module_name: Gram matrix}`` mapping per model.

        Each model is loaded through ``eval_glue``, hooks are attached to all
        of its ``nn.Linear`` modules, and the first ``data_num`` training
        examples are pushed through it with forward passes only.

        Raises:
            AttributeError: If ``self.names`` has not been set.
            ValueError: If ``data_nums`` is a sequence whose length differs
                from ``self.names``.
        """
        names = getattr(self, "names", None)
        if names is None:
            raise AttributeError(
                "RegMeanMerge.names is not set; assign the names passed to "
                "eval_glue.load_glue_classifier before calling get_coefficient()"
            )
        names = list(names)

        data_nums = self.data_nums
        if isinstance(data_nums, int):
            data_nums = [data_nums] * len(names)
        if len(data_nums) != len(names):
            raise ValueError(
                f"data_nums has {len(data_nums)} entries but names has {len(names)}"
            )

        models_regmean_weights = []
        for name, data_num in zip(names, data_nums):
            model, tokenizer = eval_glue.load_glue_classifier(name)
            train_dataset = eval_glue.load_glue_dataset(
                tokenizer, name, split='train'
            ).select(range(data_num))

            regmean_weights = defaultdict(int)
            num_item = defaultdict(int)
            handles = [
                mod.register_forward_hook(
                    self._make_gram_hook(n, regmean_weights, num_item)
                )
                for n, mod in model.named_modules()
                if isinstance(mod, nn.Linear)
            ]

            try:
                trainer = Trainer(
                    model=model,
                    args=TrainingArguments(
                        per_device_eval_batch_size=16,
                        report_to=[],  # disable wandb
                    ),
                    train_dataset=train_dataset,
                    tokenizer=tokenizer,
                )
                # Forward passes only: this method runs under inference_mode,
                # where a training step cannot back-propagate, and the weights
                # being measured must not change while statistics are taken.
                trainer.predict(train_dataset)
            finally:
                # Always detach the hooks, even if the pass above fails.
                for handle in handles:
                    handle.remove()

            models_regmean_weights.append(regmean_weights)
            del trainer, model

        return models_regmean_weights

    @staticmethod
    def _reduce_non_diag(tensor: torch.Tensor, ratio: float) -> torch.Tensor:
        """Scale the off-diagonal entries of a square matrix by ``ratio``.

        The mask is created from ``tensor`` itself so it shares its device
        and dtype.
        """
        mask = torch.full_like(tensor, ratio)
        mask.fill_diagonal_(1.0)
        return tensor * mask

    @staticmethod
    def _has_matching_grams(ps, rw) -> bool:
        """Tell whether ``rw`` holds one square Gram matrix per 2-D parameter."""
        if not isinstance(rw, (list, tuple)) or not ps or len(rw) != len(ps):
            return False
        return all(
            isinstance(g, torch.Tensor)
            and isinstance(p, torch.Tensor)
            and p.dim() == 2
            and g.dim() == 2
            and g.shape[0] == g.shape[1] == p.shape[1]
            for g, p in zip(rw, ps)
        )

    def _regmean_process(self, ps):
        """Merge one parameter across models.

        Args:
            ps: Sequence whose last element is the collection of Gram
                matrices and whose other elements are the per-model tensors
                of the parameter.

        Returns:
            The Gram-weighted merge when every model supplies a matching
            Gram matrix, otherwise the element-wise mean of the tensors.
        """
        ps, rw = list(ps[:-1]), ps[-1]

        if not self._has_matching_grams(ps, rw):
            # directly average
            return torch.stack(ps, dim=0).mean(dim=0)

        # Gram matrices produced under inference_mode cannot be saved for
        # backward, so keep autograd out of this computation.
        with torch.no_grad():
            # The same reduced matrices feed both the inverse and the weighted
            # sum; otherwise merging identical weights would not return them.
            grams = [
                self._reduce_non_diag(g, self.reduce_non_diag_val) for g in rw
            ]
            # Weights are (out, in) while Gram matrices are (in, in): transpose.
            weighted = [
                torch.matmul(g, p.transpose(0, 1)) for g, p in zip(grams, ps)
            ]
            merged = torch.matmul(torch.inverse(sum(grams)), sum(weighted))
        return merged.transpose(0, 1)

    def get_merged(self, regmean_weights):
        """Merge ``models_to_merge`` parameter by parameter.

        NOTE(review): relies on ``param.vectorize_reduce`` calling the
        callback with a sequence that ends with the Gram matrices of the
        current parameter; that contract is not visible here - confirm.
        """
        merged_param = param.vectorize_reduce(
            self._regmean_process,
            self.models_to_merge + [regmean_weights],
        )
        return merged_param
