import torch
from function_transformer_attention import SpGraphTransAttentionLayer
from base_classes import ODEblock
from utils import get_rw_adj
from block_fractional_euler import *
# from block_fractional_euler import caputoEuler
class AttODEblock_FRAC(ODEblock):
  """Attention ODE block whose state is advanced by a fractional-order solver.

  ``forward`` first computes attention weights with ``multihead_att_layer`` and attaches
  them to ``odefunc`` and to ``reg_odefunc.odefunc``. It then calls one of the solvers
  imported from ``block_fractional_euler`` on the time grid
  ``torch.arange(0, opt['time'], opt['step_size'])`` with the fractional order ``alpha``.

  Args:
    odefunc: callable (typically a class) that builds the ODE right-hand side. It is called as
      ``odefunc(in_dim, out_dim, opt, data, device)`` with
      ``in_dim = out_dim = self.aug_dim * opt['hidden_dim']``.
    regularization_fns: forwarded unchanged to ``ODEblock``.
    opt: configuration dict. This class reads ``hidden_dim``, ``self_loop_weight``, ``adjoint``,
      ``function``, ``method``, ``time`` and ``step_size``. Depending on the branch it also reads
      ``num_terms``, ``alpha_ode`` and ``memory_k``.
    data: graph object providing ``edge_index``, ``edge_attr``, ``num_nodes`` and ``x``.
    device: device the adjacency tensors and the attention layer are moved to.
    t: time interval forwarded to ``ODEblock``; ``forward`` does not use it, ``__repr__`` prints it.
    gamma: unused; kept so existing callers keep working.
  """

  # Solver names understood by ``_solve``. The first group is accepted for any fractional
  # order; the second group is only wired up for alpha <= 1, and requesting one of them
  # with alpha > 1 raises ValueError.
  _SOLVERS_ANY_ALPHA = ("ceuler", "ceuler_corrector", "GL", "memory", "implicit", "trap")
  _SOLVERS_ALPHA_LE_1_ONLY = ("PIEX", "PIIM", "PIIM_trap")

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5):
    super(AttODEblock_FRAC, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
    # Adjacency from get_rw_adj: normalised along dim 1, self loops filled with opt['self_loop_weight'].
    edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                         fill_value=opt['self_loop_weight'],
                                         num_nodes=data.num_nodes,
                                         dtype=data.x.dtype)
    self.odefunc.edge_index = edge_index.to(device)
    self.odefunc.edge_weight = edge_weight.to(device)
    # The regularised wrapper shares the very same adjacency tensors.
    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight

    # torchdiffeq integrators are still configured (and tolerances set), but forward()
    # below does not use them; it calls the fractional solvers instead.
    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()
    # Attention layer whose output is handed to odefunc on every forward pass.
    self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
                                                          device, edge_weights=self.odefunc.edge_weight).to(device)
    self.device = device
    self.opt = opt

  def get_attention_weights(self, x):
    """Return the attention tensor produced by ``multihead_att_layer`` for ``x`` on ``odefunc.edge_index``."""
    attention, _ = self.multihead_att_layer(x, self.odefunc.edge_index)
    return attention

  def forward(self, x):
    """Integrate the fractional ODE from initial state ``x`` and return the solver output.

    Args:
      x: node-feature tensor. For ``*graphcon*`` / ``*term*`` functions the first
        ``opt['hidden_dim']`` columns are sliced off for attention (for plain ``graphcon``,
        the remaining columns are used instead) — column layout assumed, TODO confirm.

    Returns:
      Whatever the selected ``block_fractional_euler`` solver returns.

    Raises:
      ValueError: if ``opt['method']`` names no solver available for the current order.
    """
    function = self.opt['function']

    # Choose the slice of x that feeds the attention layer. Branch order matters:
    # 'graphconterm' also contains both 'graphcon' and 'term'.
    if 'graphconterm' in function:
      att_input = x[:, :self.opt['hidden_dim']]
    elif 'graphcon' in function:
      att_input = x[:, self.opt['hidden_dim']:]
    elif 'term' in function:
      att_input = x[:, :self.opt['hidden_dim']]
    else:
      att_input = x
    self.odefunc.attention_weights = self.get_attention_weights(att_input)
    self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights

    alpha = self._fractional_order()

    func = self.odefunc
    state = x
    method = self.opt['method']
    if alpha > 1:
      # NOTE(review): state is x here, so this doubles the initial condition (2 * x).
      # Behaviour kept as-is; confirm this is the intended initialisation for alpha > 1.
      state = state + x
      allowed = self._SOLVERS_ANY_ALPHA
    else:
      allowed = self._SOLVERS_ANY_ALPHA + self._SOLVERS_ALPHA_LE_1_ONLY
    if method not in allowed:
      raise ValueError("Method not implemented")

    tspan = torch.arange(0, self.opt['time'], self.opt['step_size'])
    return self._solve(method, alpha, func, state, tspan)

  def _fractional_order(self):
    """Return the fractional-derivative order alpha (0-d tensor) selected by ``opt['function']``."""
    function = self.opt['function']
    if "graphconterm" in function:
      gamma = 2 / self.opt['num_terms']
      # NOTE: prints on every forward pass.
      print("gamma: ", gamma)
      return torch.tensor(gamma)
    if "term" in function:
      gamma = 1 / self.opt['num_terms']
      print("gamma: ", gamma)
      return torch.tensor(gamma)
    return torch.tensor(self.opt['alpha_ode'])

  def _solve(self, method, alpha, func, state, tspan):
    """Run the ``block_fractional_euler`` solver named ``method`` and return its result."""
    # Solver names are resolved only when their branch is taken, so a solver missing
    # from block_fractional_euler only fails when it is actually requested.
    if method == "ceuler":
      return caputoEuler(alpha, func, state, tspan=tspan, device=self.device)
    if method == "ceuler_corrector":
      return caputoEuler_corrector(alpha, func, state, tspan=tspan, device=self.device)
    if method == "GL":
      return GL_method(alpha, func, state, tspan=tspan, device=self.device)
    if method == "implicit":
      return implicit_l1(alpha, func, state, tspan=tspan, device=self.device)
    if method == "memory":
      return caputoEuler_memory(alpha, func, state, tspan=tspan, device=self.device,
                                memory_k=self.opt['memory_k'])
    if method == "trap":
      return product_trap(alpha, func, state, tspan=tspan, device=self.device)
    if method == "PIEX":
      return PIEX_method(alpha, func, state, tspan=tspan, device=self.device)
    if method == "PIIM":
      return PIIM_method(alpha, func, state, tspan=tspan, device=self.device)
    if method == "PIIM_trap":
      return PIIM_trap_method(alpha, func, state, tspan=tspan, device=self.device)
    raise ValueError("Method not implemented")

  def __repr__(self):
    """Show the class name and the time interval ``self.t``."""
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"


class AttODEblock_PLOT(ODEblock):
  """Attention ODE block integrated with torchdiffeq that also exposes attention for plotting.

  ``forward`` first computes attention weights from the input and attaches them to
  ``odefunc`` and ``reg_odefunc.odefunc``. It then integrates over ``self.t``. When
  the regularised training path is not taken, it also returns the attention weights
  averaged over ``dim=1`` together with the edge index they refer to.

  Args:
    odefunc: callable (typically a class) building the ODE right-hand side, called as
      ``odefunc(width, width, opt, data, device)`` with ``width = self.aug_dim * opt['hidden_dim']``.
    regularization_fns: forwarded unchanged to ``ODEblock``.
    opt: configuration dict (``hidden_dim``, ``self_loop_weight``, ``adjoint``, ``method``,
      ``step_size`` and, for adjoint training, ``adjoint_method`` / ``adjoint_step_size``).
    data: graph object providing ``edge_index``, ``edge_attr``, ``num_nodes`` and ``x``.
    device: device the adjacency tensors and the attention layer are moved to.
    t: integration interval forwarded to ``ODEblock``.
    gamma: unused; kept for interface compatibility.
  """

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5):
    super().__init__(odefunc, regularization_fns, opt, data, device, t)

    width = self.aug_dim * opt['hidden_dim']
    self.odefunc = odefunc(width, width, opt, data, device)

    # Adjacency from get_rw_adj: normalised along dim 1, self loops weighted by opt['self_loop_weight'].
    rw_index, rw_weight = get_rw_adj(
      data.edge_index,
      edge_weight=data.edge_attr,
      norm_dim=1,
      fill_value=opt['self_loop_weight'],
      num_nodes=data.num_nodes,
      dtype=data.x.dtype,
    )
    self.odefunc.edge_index = rw_index.to(device)
    self.odefunc.edge_weight = rw_weight.to(device)

    # The regularised wrapper points at the very same adjacency tensors.
    shared_index = self.odefunc.edge_index
    shared_weight = self.odefunc.edge_weight
    self.reg_odefunc.odefunc.edge_index = shared_index
    self.reg_odefunc.odefunc.edge_weight = shared_weight

    # Adjoint mode swaps in torchdiffeq's memory-saving adjoint integrator.
    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as solver
    else:
      from torchdiffeq import odeint as solver
    self.train_integrator = solver
    self.test_integrator = solver
    self.set_tol()

    self.multihead_att_layer = SpGraphTransAttentionLayer(
      opt['hidden_dim'], opt['hidden_dim'], opt, device,
      edge_weights=self.odefunc.edge_weight,
    ).to(device)

  def get_attention_weights(self, x):
    """Return only the attention tensor from ``multihead_att_layer`` for ``x``."""
    att, _ = self.multihead_att_layer(x, self.odefunc.edge_index)
    return att

  def forward(self, x):
    """Integrate from ``x`` over ``self.t``.

    Returns:
      ``(z, reg_states)`` when training with ``self.nreg > 0``. Otherwise
      ``(z, mean_att, edge_index)``, where ``mean_att`` is the attention averaged over
      ``dim=1`` (presumably the head axis — TODO confirm) and ``edge_index`` is
      ``self.odefunc.edge_index``. ``z`` is the integrator output at index 1 of the time axis.
    """
    t = self.t.type_as(x)
    self.odefunc.attention_weights = self.get_attention_weights(x)
    mean_att = self.odefunc.attention_weights.mean(dim=1, keepdim=False)
    self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights

    integrator = self.train_integrator if self.training else self.test_integrator

    use_reg = self.training and self.nreg > 0
    if use_reg:
      # One zero accumulator per regularisation term, integrated alongside x.
      zero_regs = tuple(torch.zeros(x.size(0)).to(x) for _ in range(self.nreg))
      func, state = self.reg_odefunc, (x,) + zero_regs
    else:
      func, state = self.odefunc, x

    solver_kwargs = {
      'method': self.opt['method'],
      'options': {'step_size': self.opt['step_size']},
      'atol': self.atol,
      'rtol': self.rtol,
    }
    if self.opt["adjoint"] and self.training:
      solver_kwargs.update(
        adjoint_method=self.opt['adjoint_method'],
        adjoint_options={'step_size': self.opt['adjoint_step_size']},
        adjoint_atol=self.atol_adjoint,
        adjoint_rtol=self.rtol_adjoint,
      )
    trajectory = integrator(func, state, t, **solver_kwargs)

    if not use_reg:
      return trajectory[1], mean_att, self.odefunc.edge_index
    return trajectory[0][1], tuple(series[1] for series in trajectory[1:])

  def __repr__(self):
    """Show the class name and the time interval ``self.t``."""
    return f"{self.__class__.__name__}( Time Interval {self.t[0].item()} -> {self.t[1].item()})"
