import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn import functional as F
from torch.nn import init
import math
from timm.models.layers import trunc_normal_
from torch import Tensor, Size
from typing import Union, List, Tuple
import utils

def power_up_method(x, power, method='power'):
    """Return the basis coefficients ``[b_0(x), ..., b_power(x)]``.

    The layers in this module combine ``power + 1`` stacked weight slices with
    these coefficients.

    Args:
        x: Point at which the basis is evaluated (callers here pass a block/layer
            index). ``'fourier'`` needs a Python scalar because it uses ``math.sin``.
        power: Highest basis order. For ``power >= 0`` the result has
            ``power + 1`` entries.
        method: Basis family:
            ``'power'``       -> monomials ``x**i``;
            ``'chebyscheff'`` -> Chebyshev polynomials of ``y = 2x - 1``;
            ``'legendra'``    -> Legendre polynomials of ``y = 2x - 1``;
            ``'fourier'``     -> ``1`` followed by ``sin(k * pi / 2 * y)``, k = 1..power.

    Returns:
        A Python list whose first entry is ``1`` (or ``x**0`` for ``'power'``).

    Raises:
        NotImplementedError: if ``method`` is not one of the names above.
    """
    if method == 'power':
        return [x ** i for i in range(power + 1)]
    if method == 'fourier':
        y = 2 * x - 1
        return [1] + [math.sin(k * math.pi / 2.0 * y) for k in range(1, power + 1)]
    if method not in ('chebyscheff', 'legendra'):
        raise NotImplementedError('unknown power method {!r}'.format(method))
    # Both polynomial families start from [1, y]. A zeroth-order request must
    # still yield exactly one coefficient, to match a one-slice weight stack.
    if power == 0:
        return [1]
    y = 2 * x - 1
    res = [1, y]
    for n in range(2, power + 1):
        if method == 'chebyscheff':
            # T_n = 2 y T_{n-1} - T_{n-2}
            res.append(2 * y * res[-1] - res[-2])
        else:
            # P_n = ((2n - 1) y P_{n-1} - (n - 1) P_{n-2}) / n
            res.append((2 * n - 1) / n * y * res[-1] - (n - 1) / n * res[-2])
    return res


class LELayerNorm(Module):
    """LayerNorm whose affine parameters are a basis expansion over layer index.

    Weight and bias are stored as stacks with one slice per basis order
    ``0..power``. They are combined with the coefficients from ``power_up_method``,
    evaluated at this module's layer index. Construction modes:

    * ``learngene_d=False`` (default): the stack itself is trained. If ``gene_dict``
      holds a stack for ``serve_module_name``, that stack is kept frozen in
      ``pre_weights``/``pre_biases``. Only the ``power - pre_power`` higher orders
      are new, and ``reset_parameters`` zero-initialises them.
    * ``learngene_d=True``: the inherited stack is collapsed once, at construction
      time, into ``weights``/``biases`` of shape ``normalized_shape``.
      With ``only_rel_last`` and more than one entry in ``use_power_list``, the
      last used order becomes the trainable ``unfrozen_*`` term and the rest
      becomes a detached ``frozen_*`` term. With ``constraint_d`` the stack stays
      trainable and is collapsed on every forward pass. Without ``gene_dict``,
      weight/bias start at ones/zeros.

    ``elementwise_affine`` is only consulted when ``learngene_d`` is False.

    NOTE(review): ``frozen_weight``/``frozen_bias`` are plain tensors, not buffers.
    ``Module.to()`` does not move them, and they are not part of ``state_dict``.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, iblock=0, serve_module_name='', normalized_shape=None, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None, power = 5, power_method = 'power', use_power_list=None, gene_dict = None,
                 learngene_d=False, only_rel_last=False, constraint_d=False) -> None:
        """Build the module; see the class docstring for the modes.

        Args:
            iblock: Index at which the basis is evaluated (stored as ``_iblock``).
            serve_module_name: Key of this module in ``gene_dict`` and in the
                dictionary returned by ``build_weight_dict``.
            normalized_shape: Int or sequence of ints, as for ``torch.nn.LayerNorm``.
            eps: Passed to ``F.layer_norm``.
            elementwise_affine: Whether to learn weight/bias (non-``learngene_d`` only).
            device, dtype: Factory arguments for created parameters.
            power: Highest basis order; stacks have ``power + 1`` slices.
            power_method: Basis family forwarded to ``power_up_method``.
            use_power_list: Stack slices used in ``learngene_d`` mode. ``None``
                means ``[0, 1]``.
            gene_dict: Optional ``{serve_module_name: {'weights': T, 'biases': T}}``.
            learngene_d, only_rel_last, constraint_d: Mode flags.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        # For parameters whose ``.data`` is overwritten right after creation.
        grad_kwargs = {"requires_grad": True, 'device': device, 'dtype': dtype}
        super(LELayerNorm, self).__init__()
        if use_power_list is None:
            use_power_list = [0, 1]  # fresh list per instance (no shared mutable default)
        self.serve_module_name = serve_module_name
        print("serve_module_name: {} and iblock {}".format(self.serve_module_name, iblock))
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        # The leading axis indexes the basis orders: (power + 1, *normalized_shape).
        self.normalized_shape = (power+1,) + tuple(normalized_shape)  # type: ignore[arg-type]
        left_shape = tuple(normalized_shape)  # shape of one collapsed weight / bias
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.power = power
        self.power_method = power_method
        self.use_power_list = use_power_list
        self.learngene_d = learngene_d

        self.only_rel_last = only_rel_last
        self.constraint_d = constraint_d

        # Plain ``None`` attributes until a branch below replaces them with Parameters.
        self.pre_weights, self.pre_biases = None, None
        self.weights, self.biases = None, None

        self.device = device
        self._iblock = iblock
        print(self.constraint_d)
        if self.learngene_d:
            if self.only_rel_last:
                assert gene_dict is not None
                print("========================= have learngene 2===============")
                self.pre_weights = Parameter(gene_dict[self.serve_module_name]['weights'].to(dtype).to(self.device))
                self.pre_biases = Parameter(gene_dict[self.serve_module_name]['biases'].to(dtype).to(self.device))
                pre_power = self.pre_weights.shape[0] - 1
                assert pre_power == self.power
                # Match the inherited stack's dtype/device. torch.matmul does not
                # promote dtypes, and integer block indices give an int64 basis.
                p_list = torch.tensor(power_up_method(self._iblock, self.power, self.power_method)).to(self.pre_weights)
                n_used = len(self.use_power_list)
                # NOTE(review): coefficients are taken positionally (p_list[:n_used]),
                # but stack slices are taken by index (use_power_list). The two only
                # line up when use_power_list is a prefix [0, 1, ..., n_used - 1].
                if n_used == 1:
                    self.weights = Parameter(torch.empty(left_shape, **grad_kwargs))
                    self.biases = Parameter(torch.empty(left_shape, **grad_kwargs))
                    self.weights.data = torch.matmul(p_list[:n_used], self.pre_weights[self.use_power_list])
                    self.biases.data = torch.matmul(p_list[:n_used], self.pre_biases[self.use_power_list])
                else:
                    self.unfrozen_weight = Parameter(torch.empty(left_shape, **grad_kwargs))
                    self.unfrozen_bias = Parameter(torch.empty(left_shape, **grad_kwargs))

                    print('this is for layernorm')
                    print(p_list[:n_used-1].shape)
                    print(self.pre_weights[self.use_power_list[:-1]].shape)
                    print(p_list[n_used-1].shape)
                    print(self.pre_weights[self.use_power_list[-1]].shape)
                    # All used orders except the last are folded into a constant term.
                    self.frozen_weight = torch.matmul(p_list[:n_used-1], self.pre_weights[self.use_power_list[:-1]])
                    self.frozen_bias = torch.matmul(p_list[:n_used-1], self.pre_biases[self.use_power_list[:-1]])

                    self.frozen_weight = self.frozen_weight.detach()
                    self.frozen_bias = self.frozen_bias.detach()
                    self.frozen_weight.requires_grad = False
                    self.frozen_bias.requires_grad = False

                    # The last used order stays trainable.
                    self.unfrozen_weight.data = p_list[n_used-1] * self.pre_weights[self.use_power_list[-1]]
                    self.unfrozen_bias.data = p_list[n_used-1] * self.pre_biases[self.use_power_list[-1]]
                    print(self.frozen_weight.shape)
                    print(self.frozen_bias.shape)
                    print(self.unfrozen_weight.shape)
                    print(self.unfrozen_bias.shape)
                self.pre_weights.requires_grad = self.pre_biases.requires_grad = False

            elif self.constraint_d:
                assert gene_dict is not None
                if gene_dict is not None:
                    print("========================= have learngene ===============")
                    # Keep the first power + 1 slices of the inherited stack, trainable.
                    self.weights = Parameter((gene_dict[self.serve_module_name]['weights'].cpu()[:self.power+1]).to(dtype).to(device))
                    self.biases = Parameter((gene_dict[self.serve_module_name]['biases'].cpu()[:self.power+1]).to(dtype).to(device))

                    pre_power = self.weights.shape[0] - 1
                    print('The previous max power is ', pre_power, ' for', self.serve_module_name, 'and now we use it with constraint!')
                    assert power == pre_power
                else:
                    # NOTE(review): unreachable because of the assert above.
                    print("========================= not have learngene ===============")
                    print('have constraint and is for downsteam')
                    self.weights = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
                    self.biases = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            else:
                if gene_dict is not None:
                    print("========================= have learngene 1===============")
                    self.pre_weights = Parameter(gene_dict[self.serve_module_name]['weights'].to(dtype).to(self.device))
                    self.pre_biases = Parameter(gene_dict[self.serve_module_name]['biases'].to(dtype).to(self.device))

                    self.weights = Parameter(torch.empty(left_shape, **grad_kwargs))
                    self.biases = Parameter(torch.empty(left_shape, **grad_kwargs))

                    p_list = torch.tensor(power_up_method(self._iblock, self.power, self.power_method)).to(self.pre_weights)
                    n_used = len(self.use_power_list)
                    self.weights.data = torch.matmul(p_list[:n_used], self.pre_weights[self.use_power_list])
                    self.biases.data = torch.matmul(p_list[:n_used], self.pre_biases[self.use_power_list])

                    self.pre_weights.requires_grad = self.pre_biases.requires_grad = False
                else:
                    # No inherited stack. Coefficient 0 of every basis is 1, so the
                    # collapsed affine map starts as the identity (ones / zeros), which
                    # is also torch.nn.LayerNorm's default.
                    self.weights = Parameter(torch.ones(left_shape, **grad_kwargs))
                    self.biases = Parameter(torch.zeros(left_shape, **grad_kwargs))

        else:
            if self.elementwise_affine:
                if gene_dict is not None:
                    print("========================= have learngene ===============")
                    self.pre_weights = Parameter(gene_dict[self.serve_module_name]['weights'].to(dtype).to(device))
                    self.pre_biases = Parameter(gene_dict[self.serve_module_name]['biases'].to(dtype).to(device))
                    self.pre_weights.requires_grad = self.pre_biases.requires_grad = False

                    pre_power = self.pre_weights.shape[0] - 1
                    print('The previous max power is ', pre_power, ' for', self.serve_module_name)
                    assert power > pre_power
                    # Only the orders above the inherited ones are new parameters.
                    new_shape = (power-pre_power,) + left_shape
                    self.weights = Parameter(torch.empty(new_shape, **factory_kwargs))
                    self.biases = Parameter(torch.empty(new_shape, **factory_kwargs))
                else:
                    print("========================= not have learngene ===============")
                    self.weights = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
                    self.biases = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            # Without elementwise_affine, weights/biases stay the plain ``None``
            # attributes set above. Calling register_parameter here would raise
            # KeyError("attribute 'weights' already exists").
            self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the trainable weight/bias stacks (elementwise_affine only).

        For ``power == 0``, slice 0 of ``weights`` is set to ones (plain LayerNorm).
        Otherwise the whole ``weights`` stack is zeroed. ``biases`` is always zeroed.
        """
        if self.elementwise_affine:       # wether use 0 to init or split init
            if self.power == 0:
                init.ones_(self.weights[0])
                init.zeros_(self.weights[1:])
            else:
                print('Zero-init for ', self.serve_module_name, '\'s LELayerNorm')
                init.zeros_(self.weights)
            init.zeros_(self.biases)

    def build_weight_dict(self) -> dict:
        """Export the stacks as ``{serve_module_name: {'weights': T|None, 'biases': T|None}}``.

        Outside ``learngene_d`` mode, inherited slices are prepended, so the export
        holds the full ``power + 1`` stack. Tensors are moved to CPU.
        """
        result = {self.serve_module_name:{}}
        result[self.serve_module_name]['weights'] = result[self.serve_module_name]['biases'] =  None
        if self.weights is not None:
            weights = self.weights.data
            biases = self.biases.data
            if self.pre_weights is not None and not self.learngene_d:
                weights = torch.cat((self.pre_weights.data, weights), 0)
                biases = torch.cat((self.pre_biases.data, biases), 0)
            result[self.serve_module_name]['weights'] = weights.cpu()
            result[self.serve_module_name]['biases'] =  biases.cpu()
        return result

    def _layer_index(self):
        """Return the index at which the basis is evaluated.

        An externally assigned ``_ilayer`` takes precedence. Otherwise the
        constructor's ``iblock`` is used (this class never sets ``_ilayer``).
        """
        return getattr(self, '_ilayer', self._iblock)

    def forward(self, input: Tensor) -> Tensor:
        """Apply layer normalisation with the basis-combined weight and bias."""
        if not self.learngene_d:
            if self.weights is not None:
                p_list = torch.tensor(power_up_method(self._layer_index(), self.power, self.power_method)).to(self.weights)
                if self.pre_weights is None:
                    # Weighted sum over the stack slices, one coefficient per order.
                    weight = p_list[0] * self.weights[0]
                    bias = p_list[0] * self.biases[0]
                    if self.power > 0:
                        for i, lay in enumerate(p_list[1:]):
                            weight += lay * self.weights[i+1]
                            bias += lay * self.biases[i+1]
                else:
                    assert self.pre_weights.requires_grad == False
                    weights = torch.cat((self.pre_weights, self.weights), 0)
                    biases = torch.cat((self.pre_biases, self.biases), 0)
                    weight = torch.matmul(p_list, weights)
                    bias = torch.matmul(p_list, biases)
            else:
                weight = bias = None
            return F.layer_norm(input, self.normalized_shape[1:], weight, bias, self.eps)

        if self.constraint_d:
            p_list = torch.tensor(power_up_method(self._layer_index(), self.power, self.power_method)).to(self.weights)
            weight = torch.matmul(p_list, self.weights)
            bias = torch.matmul(p_list, self.biases)
            return F.layer_norm(input, self.normalized_shape[1:], weight, bias, self.eps)

        # With a single-entry use_power_list, only_rel_last builds ordinary
        # ``weights``/``biases``. Only the multi-entry case leaves ``weights`` as
        # None and uses the frozen/unfrozen split.
        if self.only_rel_last and self.weights is None:
            return F.layer_norm(
                input, self.normalized_shape[1:], self.frozen_weight+self.unfrozen_weight, self.frozen_bias+self.unfrozen_bias, self.eps)
        return F.layer_norm(
            input, self.normalized_shape[1:], self.weights, self.biases, self.eps)

    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

class LELinear(Module):
    """Linear layer whose weight/bias are a basis expansion over layer index.

    Outside ``learngene_d`` mode the weight is a ``(power + 1, out_features,
    in_features)`` stack and the bias a ``(power + 1, out_features)`` stack. Both
    are combined on every forward pass with the coefficients from
    ``power_up_method``, evaluated at this module's layer index. If ``gene_dict``
    holds a stack for ``serve_module_name``, it is kept frozen in
    ``pre_weights``/``pre_biases``, and only the ``power - pre_power`` higher
    orders are new parameters.

    In ``learngene_d`` mode the inherited stack is collapsed at construction time
    into an ``(out_features, in_features)`` weight and an ``(out_features,)`` bias.
    ``only_rel_last`` with more than one entry in ``use_power_list`` splits each
    into a detached ``frozen_*`` term plus a trainable ``unfrozen_*`` last-order
    term. ``constraint_d`` keeps the whole stack trainable and collapses it on
    every forward pass.

    NOTE(review): ``reset_parameters`` is not called from ``__init__``, so the
    ``torch.empty`` stacks created when ``learngene_d`` is False stay uninitialised
    until a caller invokes it. Presumably the model's init code does this — confirm.
    ``frozen_weight``/``frozen_bias`` are plain tensors, not buffers, so
    ``Module.to()`` does not move them and they are not part of ``state_dict``.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, iblock=0, serve_module_name='', in_features: int=0, out_features: int=0, bias: bool = True,
                 device=None, dtype=None, power = 5, power_method = 'power', gene_dict = None, use_power_list=None,
                 learngene_d=False, LELinear_trunc_normal_std = -1.0, only_rel_last=False, constraint_d=False) -> None:
        """Build the layer; see the class docstring for the modes.

        Args:
            iblock: Index at which the basis is evaluated (stored as ``_iblock``).
            serve_module_name: Key of this module in ``gene_dict`` and in the
                dictionary returned by ``build_weight_dict``.
            in_features, out_features, bias: As for ``torch.nn.Linear``.
            device, dtype: Factory arguments for created parameters.
            power: Highest basis order; stacks have ``power + 1`` slices.
            power_method: Basis family forwarded to ``power_up_method``.
            gene_dict: Optional ``{serve_module_name: {'weights': T, 'biases': T}}``.
            use_power_list: Stack slices used in ``learngene_d`` mode. ``None``
                means ``[0, 1]``.
            learngene_d, only_rel_last, constraint_d: Mode flags.
            LELinear_trunc_normal_std: Used by ``reset_parameters``. A negative
                value means zero-init for ``power > 0``.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        # For parameters whose ``.data`` is overwritten right after creation.
        grad_kwargs = {"requires_grad": True, 'device': device, 'dtype': dtype}
        super(LELinear, self).__init__()
        if use_power_list is None:
            use_power_list = [0, 1]  # fresh list per instance (no shared mutable default)
        self.serve_module_name = serve_module_name
        print("serve_module_name: {} and iblock {}".format(self.serve_module_name, iblock))
        self.in_features = in_features
        self.out_features = out_features
        self.power = power
        self.power_method = power_method
        self.use_power_list = use_power_list
        self.learngene_d = learngene_d

        self.only_rel_last = only_rel_last
        self.constraint_d = constraint_d

        # Plain ``None`` attributes until a branch below replaces them with Parameters.
        self.pre_weights, self.pre_biases = None, None
        self.weights, self.biases = None, None
        self.bias = bias
        self._iblock = iblock

        self.device = device
        left_shape = (out_features, in_features)  # shape of one collapsed weight

        if self.learngene_d:
            if self.only_rel_last:
                assert gene_dict is not None
                print("========================= have learngene 2===============")
                self.pre_weights = Parameter(gene_dict[self.serve_module_name]['weights'].to(dtype).to(self.device))
                pre_power = self.pre_weights.shape[0] - 1
                assert pre_power == self.power
                # Match the inherited stack's dtype/device. torch.matmul does not
                # promote dtypes, and integer block indices give an int64 basis.
                p_list = torch.tensor(power_up_method(self._iblock, self.power, self.power_method)).to(self.pre_weights)
                n_used = len(self.use_power_list)
                # NOTE(review): coefficients are taken positionally (p_list[:n_used]),
                # but stack slices are taken by index (use_power_list). The two only
                # line up when use_power_list is a prefix [0, 1, ..., n_used - 1].
                if n_used == 1:
                    self.weights = Parameter(torch.empty(left_shape, **grad_kwargs))
                    self.weights.data = torch.matmul(p_list[:n_used], self.pre_weights[self.use_power_list].reshape(n_used, -1)).reshape(self.out_features, -1)
                else:
                    self.unfrozen_weight = Parameter(torch.empty(left_shape, **grad_kwargs))
                    print('this is for linear')
                    print(p_list[:n_used-1].shape)
                    print(self.pre_weights[self.use_power_list][:-1].shape)
                    print(p_list[n_used-1].shape)
                    print(self.pre_weights[self.use_power_list][-1].shape)

                    # All used orders except the last are folded into a constant term.
                    self.frozen_weight = torch.matmul(p_list[:n_used-1], self.pre_weights[self.use_power_list[:-1]].reshape(n_used-1, -1)).reshape(self.out_features, -1)
                    self.frozen_weight = self.frozen_weight.detach()

                    # The last used order stays trainable.
                    self.unfrozen_weight.data = p_list[n_used-1] * self.pre_weights[self.use_power_list[-1]]
                    print(self.frozen_weight.shape)
                    print(self.unfrozen_weight.shape)
                self.pre_weights.requires_grad = False
            elif self.constraint_d:
                assert gene_dict is not None
                if gene_dict is not None:
                    print("========================= have learngene ===============")
                    # Keep the first power + 1 slices of the inherited stack, trainable.
                    self.weights = Parameter((gene_dict[self.serve_module_name]['weights'].cpu()[:self.power+1]).to(dtype).to(device))
                    pre_power = self.weights.shape[0] - 1
                    print('The previous max power is ', pre_power, ' for', self.serve_module_name, 'and now we use it with constraint!')
                    assert power == pre_power
                else:
                    # NOTE(review): unreachable because of the assert above.
                    print("========================= not have learngene ===============")
                    print('have constraint and is for downsteam')
                    self.weights = Parameter(torch.empty((power+1, out_features, in_features), **factory_kwargs)) # (12, in, out)
            else:
                if gene_dict is not None:
                    print("========================= have learngene 1===============")
                    self.pre_weights = Parameter(gene_dict[self.serve_module_name]['weights'].to(dtype).to(device))

                    self.weights = Parameter(torch.empty(left_shape, **grad_kwargs))

                    p_list = torch.tensor(power_up_method(self._iblock, self.power, self.power_method)).to(self.pre_weights)
                    n_used = len(self.use_power_list)
                    self.weights.data = torch.matmul(p_list[:n_used], self.pre_weights[self.use_power_list].reshape(n_used, -1)).reshape(self.out_features, -1)
                    self.pre_weights.requires_grad = False
                else:
                    self.weights = Parameter(torch.randn(left_shape, **grad_kwargs))
        else:
            if gene_dict is not None:
                print("========================= have learngene ===============")
                self.pre_weights = Parameter(gene_dict[self.serve_module_name]['weights'].to(dtype).to(device))
                self.pre_weights.requires_grad = False

                pre_power = self.pre_weights.shape[0] - 1
                print('The previous max power is ', pre_power, ' for', self.serve_module_name)
                assert power > pre_power
                # Only the orders above the inherited ones are new parameters.
                self.weights = Parameter(torch.empty((power-pre_power, out_features, in_features), **factory_kwargs))
            else:
                print("========================= not have learngene ===============")
                self.weights = Parameter(torch.empty((power+1, out_features, in_features), **factory_kwargs)) # (12, in, out)

        if bias:
            if gene_dict is not None:
                assert gene_dict[self.serve_module_name]['biases'] is not None
                print('============================', gene_dict[self.serve_module_name]['biases'].shape)
                if self.learngene_d:
                    if self.only_rel_last:
                        print("========================= have learngene 2===============")
                        self.pre_biases = Parameter(gene_dict[self.serve_module_name]['biases'].to(dtype).to(self.device))
                        pre_power = self.pre_biases.shape[0] - 1
                        assert pre_power == self.power
                        p_list = torch.tensor(power_up_method(self._iblock, self.power, self.power_method)).to(self.pre_biases)
                        n_used = len(self.use_power_list)
                        if n_used == 1:
                            self.biases = Parameter(torch.empty((out_features,), **grad_kwargs))
                            self.biases.data = torch.matmul(p_list[:n_used], self.pre_biases[self.use_power_list])
                        else:
                            # ``biases`` stays None; the bias is frozen_bias + unfrozen_bias.
                            self.unfrozen_bias = Parameter(torch.empty((out_features,), **grad_kwargs))
                            print('this is for linear')

                            self.frozen_bias = torch.matmul(p_list[:n_used-1], self.pre_biases[self.use_power_list[:-1]])
                            self.frozen_bias = self.frozen_bias.detach()

                            self.unfrozen_bias.data = p_list[n_used-1] * self.pre_biases[self.use_power_list[-1]]
                            print(self.frozen_bias.shape)
                            print(self.unfrozen_bias.shape)
                        self.pre_biases.requires_grad = False
                    elif self.constraint_d:
                        self.biases = Parameter((gene_dict[self.serve_module_name]['biases'].cpu()[:self.power+1]).to(dtype).to(device))
                        pre_power = self.biases.shape[0] - 1
                        assert power == pre_power
                    else:
                        print("=============linear bias learngene")
                        self.pre_biases = Parameter(gene_dict[self.serve_module_name]['biases'].to(dtype).to(device))
                        p_list = torch.tensor(power_up_method(self._iblock, self.power, self.power_method)).to(self.pre_biases)
                        n_used = len(self.use_power_list)
                        self.biases = Parameter(torch.empty(1, out_features, **grad_kwargs))
                        self.biases.data = torch.matmul(p_list[:n_used], self.pre_biases[self.use_power_list].reshape(n_used, -1)).reshape(-1)

                        self.pre_biases.requires_grad = False
                else:
                    self.pre_biases = Parameter(gene_dict[self.serve_module_name]['biases'].to(dtype).to(device))
                    self.pre_biases.requires_grad = False

                    pre_power = self.pre_biases.shape[0] - 1
                    assert power > pre_power
                    self.biases = Parameter(torch.empty(power-pre_power, out_features, **grad_kwargs))
            else:
                if self.learngene_d:
                    if self.constraint_d:
                        self.biases = Parameter(torch.empty(power+1, out_features, **grad_kwargs))
                    else:
                        self.biases = Parameter(torch.randn(out_features, **grad_kwargs))
                else:
                    self.biases = Parameter(torch.empty(power+1, out_features, **grad_kwargs))
        # Without ``bias``, ``biases`` stays the plain ``None`` attribute set above.
        self.LELinear_trunc_normal_std = LELinear_trunc_normal_std

    def reset_parameters(self) -> None:
        """Initialise the weight stack (and zero the bias stack, if any).

        For ``power == 0`` the weight uses a truncated normal with std 0.02.
        Otherwise it uses a truncated normal with ``LELinear_trunc_normal_std``
        when that value is >= 0, and zeros when it is negative.
        """
        if self.power == 0:
            trunc_normal_(self.weights, std=.02)
        else:
            if self.LELinear_trunc_normal_std >= 0:
                print('No zero-init for ', self.serve_module_name, '\'s LELinear and the start acc would be low, the std is ', self.LELinear_trunc_normal_std)
                trunc_normal_(self.weights, std=self.LELinear_trunc_normal_std)
            else:
                print('Zero-init for ', self.serve_module_name, '\'s LELinear and the start acc would be high')
                init.zeros_(self.weights)
        if self.biases is not None:
            init.zeros_(self.biases)

    def build_weight_dict(self) -> dict:
        """Export the stacks as ``{serve_module_name: {'weights': T|None, 'biases': T|None}}``.

        Outside ``learngene_d`` mode, inherited slices are prepended, so the export
        holds the full ``power + 1`` stack. Biases are never exported in
        ``learngene_d`` mode. Tensors are moved to CPU.
        """
        result = {self.serve_module_name: {'weights': None, 'biases': None}}

        if self.weights is not None:
            weights = self.weights.data
            if self.pre_weights is not None and not self.learngene_d:
                weights = torch.cat((self.pre_weights.data, weights), 0)
            result[self.serve_module_name]['weights'] = weights.cpu()

        if not self.learngene_d and self.biases is not None:
            biases = self.biases.data
            if self.pre_biases is not None:
                biases = torch.cat((self.pre_biases.data, biases), 0)
            result[self.serve_module_name]['biases'] = biases.cpu()
        return result

    def _layer_index(self):
        """Return the index at which the basis is evaluated.

        An externally assigned ``_ilayer`` takes precedence. Otherwise the
        constructor's ``iblock`` is used (this class never sets ``_ilayer``).
        """
        return getattr(self, '_ilayer', self._iblock)

    def forward(self, input: Tensor) -> Tensor:
        """Apply ``F.linear`` with the basis-combined weight and bias."""
        weight = bias = None

        if not self.learngene_d:
            p_list = torch.tensor(power_up_method(self._layer_index(), self.power, self.power_method)).to(self.weights)
            if self.pre_weights is None:
                # Weighted sum over the stack slices, one coefficient per order.
                weight = p_list[0] * self.weights[0]
                if self.biases is not None:
                    bias = p_list[0] * self.biases[0]
                if self.power > 0:
                    for i, lay in enumerate(p_list[1:]):
                        weight += self.weights[i+1]*lay
                        if self.biases is not None:
                            bias += self.biases[i+1]*lay
            else:
                assert self.pre_weights.requires_grad == False
                weights = torch.cat((self.pre_weights, self.weights), 0)
                weight = torch.matmul(p_list, weights.reshape(self.power+1, -1))
                weight = weight.reshape(self.out_features, -1)
                if self.biases is not None:
                    assert self.pre_biases is not None
                    biases = torch.cat((self.pre_biases, self.biases), 0)
                    bias = torch.matmul(p_list, biases.reshape(self.power+1, -1))
                    bias = bias.reshape(-1)
            return F.linear(input, weight, bias)

        if self.constraint_d:
            p_list = torch.tensor(power_up_method(self._layer_index(), self.power, self.power_method)).to(self.weights)
            weight = torch.matmul(p_list, self.weights.reshape(self.power+1, -1))
            weight = weight.reshape(self.out_features, -1)
            if self.biases is not None:
                bias = torch.matmul(p_list, self.biases.reshape(self.power+1, -1))
                bias = bias.reshape(-1)
            return F.linear(input, weight, bias)

        # With a single-entry use_power_list, only_rel_last builds ordinary
        # ``weights``/``biases``. Only the multi-entry case leaves ``weights`` as
        # None and uses the frozen/unfrozen split.
        if self.only_rel_last and self.weights is None:
            weight = self.frozen_weight + self.unfrozen_weight
            bias = self.frozen_bias + self.unfrozen_bias if self.bias else self.biases
            return F.linear(input, weight, bias)
        return F.linear(input, self.weights, self.biases)

    def extra_repr(self) -> str:
        # A split (frozen + unfrozen) bias also counts as having a bias.
        has_bias = self.biases is not None or getattr(self, 'unfrozen_bias', None) is not None
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, has_bias
        )