import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
import torch.optim as optim
import numpy as np

__all__ = ['relu_net', 'quadratic_net', 'plain_quadratic']


def gauss_norm(M):
    """Return sqrt(2 * ||(M + M^T) / 2||_F^2 + trace(M)^2) as a 0-dim tensor.

    For a square 2-D tensor ``M`` this expression equals the root second
    moment of the quadratic form ``x^T M x`` when ``x`` has i.i.d. standard
    normal entries (presumably the origin of the name).
    """
    symmetrized = M + M.t()
    # Squared Frobenius norm of the symmetric part (M + M^T) / 2.
    sym_sq = symmetrized.norm().pow(2) / 4
    trace_sq = M.trace().pow(2)
    return (2 * sym_sq + trace_sq).sqrt()

def cube_norm(M):
    """Return sqrt(2 * ||(M + M^T) / 2||_F^2 + trace(M)^2 - 2 * sum_i M_ii^2).

    Same as the Gaussian expression minus twice the sum of squared diagonal
    entries.  For a square 2-D tensor ``M`` this equals the root second
    moment of ``x^T M x`` when ``x`` has independent +/-1 entries
    (presumably the "cube" in the name).
    """
    symmetrized = M + M.t()
    # Squared Frobenius norm of the symmetric part (M + M^T) / 2.
    sym_sq = symmetrized.norm().pow(2) / 4
    trace_sq = M.trace().pow(2)
    # Trace of the element-wise square, i.e. sum of squared diagonal entries.
    diag_sq = M.pow(2).trace()
    return (2 * sym_sq + trace_sq - 2 * diag_sq).sqrt()


class relu_net(nn.Module):
    """Two-layer ReLU model: difference of two summed ReLU feature banks.

    The raw output for an input batch ``x`` is
    ``(relu(x @ U).sum(1) - relu(x @ V).sum(1)) / sqrt(M)`` with trainable
    ``U`` and ``V`` of shape ``(D, M)``.

    Args:
        D: input dimension.
        normalization_method: default output normalization used by
            ``forward``; ``None`` or ``False`` for none, ``'BN'`` for
            per-batch normalization, ``'FBN'`` for normalization by the raw
            outputs on ``full_trainset``.
        M: number of units per bank; defaults to ``D``.
        full_trainset: inputs used by the ``'FBN'`` mode.
        **kwargs: accepted and ignored.
    """

    def __init__(self, D = 40, normalization_method = False, M=None, full_trainset=None, **kwargs):
        super(relu_net, self).__init__()
        self.M = D if M is None else M
        self.D = D
        self.U = nn.Parameter(torch.randn(self.D, self.M), requires_grad=True)
        self.V = nn.Parameter(torch.randn(self.D, self.M), requires_grad=True)
        self.normalization_method = normalization_method
        self.full_trainset = full_trainset

    def _forward_no_normalization(self, x):
        """Return the raw (un-normalized) model output, one value per row of ``x``."""
        positive = nn.functional.relu(x @ self.U).sum(dim=1)
        negative = nn.functional.relu(x @ self.V).sum(dim=1)
        output = positive - negative
        output /= np.sqrt(self.M)
        return output

    def forward(self, x, normalization_method=None):
        """Evaluate the model on ``x`` with the requested normalization.

        Args:
            x: input batch; the matmul with ``U`` and ``sum(dim=1)`` expect
                a 2-D tensor whose last dimension is ``D``.
            normalization_method: per-call override; ``None`` falls back to
                the method given at construction time.

        Returns:
            1-D tensor of outputs, rescaled according to the method.

        Raises:
            AttributeError: if the effective method is not recognised.
        """
        if normalization_method is None:
            normalization_method = self.normalization_method
        output = self._forward_no_normalization(x)
        # The constructor default is False, so both None and False mean
        # "no normalization".
        if normalization_method is None or normalization_method is False:
            return output
        if normalization_method == 'BN':
            # Rescale so the batch outputs have unit root-mean-square.
            return output / output.norm() * math.sqrt(len(output))
        if normalization_method == 'FBN':  # full batch normalization
            reference = self._forward_no_normalization(self.full_trainset)
            return output / reference.norm() * math.sqrt(len(self.full_trainset))
        raise AttributeError(f"Unknown task type {normalization_method}")
    
    
class quadratic_net(nn.Module):
    """Quadratic model: difference of two sums of squared linear features.

    The raw output for an input batch ``x`` is
    ``((x @ U)**2 - (x @ V)**2).sum(1) / sqrt(M)``, i.e. the quadratic form
    of ``equiv_mat()`` scaled by ``1 / sqrt(M)``.

    Args:
        D: input dimension.
        normalization_method: default output normalization used by
            ``forward``: ``None``/``False`` (none), ``'BN'`` (per batch),
            ``'FBN'`` (by raw outputs on ``full_trainset``), ``'PN_gauss'``
            or ``'PN_cube'`` (divide by ``gauss_norm`` / ``cube_norm`` of
            ``equiv_mat()``), or ``'PN'`` which resolves to ``'PN_gauss'``
            when ``gauss`` is truthy and to ``'PN_cube'`` otherwise.
        M: number of units per bank; defaults to ``D``.  It is also the
            constant used in the ``1 / sqrt(M)`` output scaling.
        gauss: selects which population norm ``'PN'`` stands for.
        full_trainset: inputs used by the ``'FBN'`` mode.
        **kwargs: accepted and ignored.
    """

    def __init__(self, D = 40, normalization_method = None, M=None, gauss=None, full_trainset=None, **kwargs):
        super(quadratic_net, self).__init__()
        self.M = D if M is None else M
        self.D = D
        self.gauss = gauss
        self.U = nn.Parameter(self._weight_init(self.D, self.M), requires_grad=True)
        self.V = nn.Parameter(self._weight_init(self.D, self.M), requires_grad=True)
        # Store the resolved name so 'PN' is understood by forward().
        self.normalization_method = self._resolve_method(normalization_method)
        self.full_trainset = full_trainset

    @staticmethod
    def _weight_init(D, M):
        """Draw an initial weight matrix for a bank of ``M`` units.

        For ``M <= D`` this is a plain ``(D, M)`` Gaussian matrix.  For
        ``M > D`` a ``(D, M)`` Gaussian matrix is drawn and replaced by the
        ``(D, D)`` factor ``left @ diag(sqrt(S))`` obtained from the SVD of
        its Gram matrix ``original @ original.T``.
        """
        if M <= D:
            return torch.randn(D, M)
        original = torch.randn(D, M)
        # One decomposition is enough; its left factor and singular values
        # are both taken from the same call.
        left, singular_values, _ = torch.svd(original @ original.t())
        return left @ torch.diag(torch.sqrt(singular_values))

    def _resolve_method(self, method):
        """Map the shorthand ``'PN'`` to ``'PN_gauss'`` or ``'PN_cube'``."""
        if method == 'PN':
            return 'PN_gauss' if self.gauss else 'PN_cube'
        return method

    def _forward_no_normalization(self, x):
        """Return the raw (un-normalized) model output, one value per row of ``x``."""
        output = (x.matmul(self.U) ** 2 - x.matmul(self.V) ** 2).sum(1)
        output /= np.sqrt(self.M)
        return output

    def forward(self, x, normalization_method=None):
        """Evaluate the model on ``x`` with the requested normalization.

        Args:
            x: input batch; the matmul with ``U`` and ``sum(1)`` expect a
                2-D tensor whose last dimension is ``D``.
            normalization_method: per-call override; ``None`` falls back to
                the method stored at construction time.

        Returns:
            1-D tensor of outputs, rescaled according to the method.

        Raises:
            AttributeError: if the effective method is not recognised.
        """
        if normalization_method is None:
            normalization_method = self.normalization_method
        normalization_method = self._resolve_method(normalization_method)
        output = self._forward_no_normalization(x)
        if normalization_method is None or normalization_method is False:
            return output  # no normalization
        if normalization_method == 'BN':
            # Rescale so the batch outputs have unit root-mean-square.
            return output / output.norm() * math.sqrt(len(output))
        if normalization_method == 'FBN':  # full batch normalization
            reference = self._forward_no_normalization(self.full_trainset)
            return output / reference.norm() * math.sqrt(len(self.full_trainset))
        if normalization_method == 'PN_gauss':  # Gaussian population normalization
            # Multiplying by sqrt(M) cancels the 1/sqrt(M) in the raw output.
            return output / gauss_norm(self.equiv_mat()) * np.sqrt(self.M)
        if normalization_method == 'PN_cube':  # cube population normalization
            return output / cube_norm(self.equiv_mat()) * np.sqrt(self.M)
        raise AttributeError(f"Unknown task type {normalization_method}")

    def equiv_mat(self):
        """Return ``U U^T - V V^T``, the matrix of the model's quadratic form."""
        return self.U.matmul(self.U.t()) - self.V.matmul(self.V.t())

    def set_gt(self):
        """Replace ``U`` and ``V`` with fixed ``(D, D)`` ground-truth weights.

        With these weights ``equiv_mat()`` has 0.5 at positions ``[0, 1]``
        and ``[1, 0]`` and zeros elsewhere, so the raw output becomes
        ``x[:, 0] * x[:, 1] / sqrt(M)``.  Requires ``D >= 2``.
        NOTE(review): ``self.M`` is left unchanged even though the new
        weights are ``(D, D)`` -- confirm this is intended.
        """
        del self.U
        del self.V
        tmp_a = torch.zeros(self.D, self.D)
        tmp_a[0, 0] = 0.5
        tmp_a[1, 0] = 0.5
        tmp_b = torch.zeros(self.D, self.D)
        tmp_b[0, 0] = 0.5
        tmp_b[1, 0] = -0.5
        self.U = nn.Parameter(tmp_a, requires_grad=True)
        self.V = nn.Parameter(tmp_b, requires_grad=True)
        
        
        
class plain_quadratic(nn.Module):
    """Quadratic form model ``x^T W x`` with a single trainable ``(D, D)`` matrix.

    Args:
        D: input dimension.
        normalization_method: when falsy (the default) the output is divided
            by ``norm_fn(equiv_mat())``; when truthy the output is
            normalized per batch instead.
        **kwargs: accepted and ignored.
    """

    def __init__(self, D = 40, normalization_method = False,**kwargs):
        super(plain_quadratic, self).__init__()
        self.D = D
        # forward() reads this attribute, so it must be stored here.
        self.normalization_method = normalization_method
        self.W = nn.Parameter(torch.randn(self.D, self.D), requires_grad=True)

    def forward(self, x, norm_fn=gauss_norm):
        """Evaluate the normalized quadratic form on each row of ``x``.

        Args:
            x: input batch; the matmul with ``W`` and ``sum(1)`` expect a
                2-D tensor whose last dimension is ``D``.
            norm_fn: callable applied to ``equiv_mat()`` to obtain the
                divisor when ``normalization_method`` is falsy.

        Returns:
            1-D tensor with one normalized output per row of ``x``.
        """
        output = (x.matmul(self.W) * x).sum(1)
        if not self.normalization_method:  # normalization with distribution
            return output / norm_fn(self.equiv_mat())
        # Otherwise rescale so the batch outputs have unit root-mean-square.
        return output / output.norm() * math.sqrt(len(output))

    def equiv_mat(self):
        """Return the symmetric part of ``W``, ``(W + W^T) / 2``."""
        return self.W / 2 + self.W.t() / 2

    def set_gt(self):
        """Replace ``W`` with fixed ground-truth weights.

        The new ``W`` has 0.5 at ``[0, 1]`` and ``[1, 0]`` and zeros
        elsewhere, so the un-normalized quadratic form is
        ``x[:, 0] * x[:, 1]``.  Requires ``D >= 2``.
        """
        del self.W
        tmp = torch.zeros(self.D, self.D)
        tmp[0, 1] = 0.5
        tmp[1, 0] = 0.5
        self.W = nn.Parameter(tmp, requires_grad=True)
        

