import torch
from manifold_flow.utils import various

class ELCD(torch.nn.Module):
    """Learned vector field dx/dt = A(x) (x - x_eq) with a guaranteed-stable A(x).

    A(x) = -P_s(x)^T P_s(x) - eps * I + P_a(x) - P_a(x)^T, where P_s and P_a are
    small MLPs mapping a state of dimension d to a d x d matrix. The symmetric
    part of A(x) is -P_s^T P_s - eps * I, hence for any v:
    v^T A(x) v <= -eps * ||v||^2, i.e. the field contracts toward x_eq.

    Args:
        d: state dimension.
        x_eq: optional equilibrium point (expected shape (1, d)). It is stored
            detached from any autograd graph. If omitted it must be supplied
            later via ``set_x_eq`` or per call via ``forward(x, x_eq)``.
        hidden_dim: width of the hidden layer of P_s and P_a.
        train_x_eq: currently unused; kept for interface compatibility.
    """

    def __init__(self, d, x_eq=None, hidden_dim=10, train_x_eq=False):
        super().__init__()
        self.d = d

        # Always define the attribute so forward() can report a clear error
        # instead of an opaque AttributeError when no equilibrium was given.
        self.x_eq = None
        if x_eq is not None:
            # detach() rather than flipping requires_grad in place: it neither
            # mutates the caller's tensor nor fails on non-leaf tensors.
            self.x_eq = torch.as_tensor(x_eq).detach()

        self.build_model(d, hidden_dim)

    def build_model(self, d, hidden_dim=10):
        """Create the P_s / P_a networks (each R^d -> R^(d*d)) and ``eps``."""

        def matrix_net():
            return torch.nn.Sequential(
                torch.nn.Linear(d, hidden_dim),
                torch.nn.Tanh(),
                torch.nn.Linear(hidden_dim, d ** 2),
            )

        self.P_s = matrix_net()
        self.P_a = matrix_net()

        # Minimum contraction rate: keeps A(x) strictly stable even when
        # P_s(x) is rank-deficient.
        self.eps = 0.01

        # Re-initialise both networks' Linear layers (see init_weights).
        self.P_s.apply(self.init_weights)
        self.P_a.apply(self.init_weights)

    def set_x_eq(self, x_eq):
        """Store an independent, grad-free copy of the equilibrium point.

        Args:
            x_eq: tensor, numpy array or nested list of shape (1, d).
        """
        x_eq = torch.as_tensor(x_eq)
        assert x_eq.shape[1] == self.d
        assert x_eq.shape[0] == 1
        # torch.tensor() on an existing tensor emits a UserWarning; detach +
        # clone produces the same independent copy silently.
        self.x_eq = x_eq.detach().clone()

    def init_weights(self, m):
        """Xavier-uniform weights and a constant 0.01 bias for Linear layers."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def get_A_matrix(self, x):
        """Return the batch of matrices A(x), shape (b, d, d), for x of shape (b, d)."""
        b = x.shape[0]
        p_s = self.P_s(x).reshape(b, self.d, self.d)
        p_a = self.P_a(x).reshape(b, self.d, self.d)

        # Built directly on the right device/dtype; broadcasts over the batch.
        eye = torch.eye(self.d, device=p_s.device, dtype=p_s.dtype)
        # -P_s^T P_s - eps*I is the negative-definite symmetric part;
        # P_a - P_a^T is skew-symmetric and adds no growth.
        A = (
            -torch.bmm(p_s.transpose(1, 2), p_s)
            - self.eps * eye
            + p_a
            - p_a.transpose(1, 2)
        )

        if torch.isnan(A).any():
            print("x", x)
            print("p_s", p_s)
            print("p_a", p_a)
            print("A", A)
        return A

    def forward(self, x, x_eq=None):
        """Return dx/dt = A(x) (x - x_eq), shape (b, d), for x of shape (b, d).

        Args:
            x: batch of states, shape (b, d).
            x_eq: optional equilibrium overriding the stored one; expected
                shape (1, d) or (d,) so it can be tiled over the batch.

        Raises:
            ValueError: if no equilibrium is given and none is stored.
        """
        assert x.shape[1] == self.d

        if x_eq is None:
            x_eq = getattr(self, "x_eq", None)
            if x_eq is None:
                raise ValueError(
                    "No equilibrium point: pass x_eq or call set_x_eq() first."
                )

        A = self.get_A_matrix(x)

        # Move the single equilibrium row first, then tile it over the batch.
        x_eq = x_eq.to(x.device).repeat(x.shape[0], 1)
        assert x_eq.shape[0] == x.shape[0]
        assert x_eq.shape[1] == x.shape[1]

        mat = torch.bmm(A, (x - x_eq).unsqueeze(-1)).squeeze(-1)

        if torch.isnan(mat).any():
            print("x", x)
            print("x_eq", x_eq)
            print("A", A)
            print("mat", mat)
        return mat

    def forward_discrete(self, x, dt=1, x_eq=None):
        """Advance x by one explicit Euler step of size ``dt``."""
        dx_dt = self.forward(x, x_eq)
        return x + dx_dt * dt
    
class ELCD_Simple(ELCD):
    """State-independent variant of ELCD: dx/dt = A (x - x_eq) with one constant A.

    A = -P_s^T P_s - eps * I + P_a - P_a^T, where P_s and P_a are free d x d
    parameters instead of networks, so the same matrix is applied to every
    sample of the batch. ``hidden_dim`` is accepted only for signature
    compatibility with ``ELCD.build_model`` and is ignored.
    """

    def build_model(self, d, hidden_dim=10):
        """Create the Xavier-initialised d x d parameters P_s, P_a and ``eps``."""
        self.P_s = torch.nn.Parameter(torch.zeros(d, d))
        self.P_a = torch.nn.Parameter(torch.zeros(d, d))
        torch.nn.init.xavier_uniform_(self.P_s)
        torch.nn.init.xavier_uniform_(self.P_a)

        # Minimum contraction rate; keeps A strictly stable even if P_s is singular.
        self.eps = 0.01

    def get_A_matrix(self, x):
        """Return the single (d, d) matrix A; ``x`` does not influence it."""
        eye = torch.eye(self.d, device=self.P_s.device, dtype=self.P_s.dtype)
        return -self.P_s.T @ self.P_s - self.eps * eye + self.P_a - self.P_a.T

    def forward(self, x, x_eq=None):
        """Return dx/dt = A (x - x_eq), shape (b, d), for x of shape (b, d).

        Args:
            x: batch of states, shape (b, d).
            x_eq: optional equilibrium of shape (1, d) overriding the stored one.

        Raises:
            ValueError: if no equilibrium is given and none is stored.
        """
        assert x.shape[1] == self.d

        if x_eq is None:
            x_eq = getattr(self, "x_eq", None)
            if x_eq is None:
                raise ValueError(
                    "No equilibrium point: pass x_eq or call set_x_eq() first."
                )
        else:
            assert x_eq.shape[1] == self.d
            assert x_eq.shape[0] == 1

        A = self.get_A_matrix(x)
        # Row-vector form of A @ (x_i - x_eq) for all samples at once: the
        # equilibrium broadcasts over the batch and A is never copied b times.
        # x_eq is moved to x's device, matching ELCD.forward.
        return (x - x_eq.to(x.device)) @ A.T




    


class ELCD_Transform(torch.nn.Module):
    """Run a latent dynamics ``model`` in the coordinates of an invertible ``transform``.

    States are encoded as z = transform(x)[0]. The latent velocity is
    z_dot = model(z, z_eq) with z_eq = transform(x_eq)[0]. It is mapped back to
    x-space with the Jacobian returned by ``transform.inverse``:
    x_dot = J @ z_dot.

    Args:
        x_eq: equilibrium point in x-space (tensor or array-like). The shape is
            presumably (1, d), which is what ELCD.set_x_eq asserts after the
            transform.
        transform: called as ``transform(x, full_jacobian=...)`` and
            ``transform.inverse(z, full_jacobian=...)``. Both must return a tuple
            whose first entry is the mapped tensor. ``inverse`` with
            full_jacobian=True must return a batched Jacobian usable with bmm.
        model: latent dynamics module exposing ``set_x_eq`` and
            ``forward(z, x_eq)`` (e.g. ELCD).

    ``x_eq`` is kept as a non-persistent buffer. It therefore follows
    ``.to()``, ``.cuda()``, ``.double()`` etc. like any buffer, but is not
    written to ``state_dict``. Moving ``transform`` along with the module
    assumes it is an ``nn.Module`` (and so a registered submodule).
    """

    def __init__(self, x_eq, transform, model):
        super().__init__()
        self.transform = transform

        self.model = model
        self.set_x_eq(x_eq)
        print("Eq point:", self.x_eq)

    def forward(self, x):
        """Return the predicted x-space velocity for a batch of states ``x``."""
        # The latent equilibrium depends on the transform, whose parameters may
        # change during training, so it is recomputed on every call.
        z_eq = self.transform(self.x_eq)[0]

        z = self.encode(x)
        z_dot_pred = self.model(z, z_eq)

        # Only the Jacobian of the inverse pass is needed; the reconstruction
        # itself is discarded.
        _, inv_jacobian = self.decode(z)

        return torch.bmm(inv_jacobian, z_dot_pred.unsqueeze(-1)).squeeze(-1)

    def forward_discrete(self, x, dt=1):
        """Advance x by one explicit Euler step of size ``dt``."""
        x_dot = self.forward(x)
        next_x = x + x_dot * dt
        return next_x

    def encode(self, x):
        """Map x to latent coordinates (first output of ``transform``)."""
        h, _ = self.transform(x, full_jacobian=False)
        return h

    def decode(self, z, full_jacobian=True):
        """Return ``transform.inverse(z, full_jacobian=...)`` unchanged."""
        return self.transform.inverse(z, full_jacobian=full_jacobian)

    def set_x_eq(self, x_eq):
        """Set the x-space equilibrium and push its latent image into ``model``."""
        print("setting x_eq")
        # detach() instead of flipping requires_grad in place: the caller's
        # tensor is left untouched and non-leaf tensors are accepted.
        x_eq = torch.as_tensor(x_eq).detach()
        if "x_eq" in self._buffers:
            self.x_eq = x_eq
        else:
            # Non-persistent: moved by .to()/.cuda()/.double(), absent from state_dict.
            self.register_buffer("x_eq", x_eq, persistent=False)
        # The model keeps a grad-free copy, so no autograd graph is needed here.
        with torch.no_grad():
            self.model.set_x_eq(self.transform(self.x_eq)[0])
    


