import torch
import torch.nn as nn
import torch.jit as jit
from network.utils import Linear


class QMIX_Net(jit.ScriptModule):
    """QMIX mixing network (TorchScript module).

    Mixes per-agent Q-values into one joint Q-value per row. The mixing
    weights and biases are produced by hypernetworks that take the state
    ``s`` as input. When ``use_abs`` is True the hypernetwork weights go
    through ``torch.abs``, which makes them non-negative. The joint value is
    then non-decreasing in every agent's Q-value, because all supported
    activations are monotonically increasing.

    Args:
        n_agents: number of agents; must equal ``q.size(-1)`` in ``forward``
            (``hyper_w1`` output is viewed as ``(-1, n_agents, h_dim)``).
        st_dim: state feature size, i.e. the last dim of ``s`` in ``forward``.
        h_dim: width of the mixing hidden layer.
        activation: hidden-layer activation, one of ``"elu"``, ``"relu"``,
            ``"tanh"``, ``"sigmoid"``.
        use_abs: apply ``torch.abs`` to the hypernetwork weights. Defaults to
            True, which was previously hard-coded.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, n_agents, st_dim, h_dim, activation="elu", use_abs=True):
        super().__init__()
        activations = {
            "elu": nn.ELU,
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
        }
        # Previously an unknown name silently left ``act_fn`` undefined and
        # only failed later (at scripting time or first use).
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(activations)}"
            )
        self.h_dim = h_dim
        # Plain bool so TorchScript infers the attribute type as ``bool``.
        self.use_abs = bool(use_abs)
        self.act_fn = activations[activation]()
        self.hyper_w1 = Linear(st_dim, n_agents * h_dim)
        self.hyper_w2 = Linear(st_dim, h_dim)
        self.hyper_b1 = Linear(st_dim, h_dim)
        # The final-bias hypernetwork always uses ReLU internally,
        # independent of ``activation``.
        self.hyper_b2 = nn.Sequential(Linear(st_dim, h_dim), nn.ReLU(), Linear(h_dim, 1))

    @jit.script_method
    def forward(self, q, s) -> torch.Tensor:
        """Mix per-agent Q-values into a joint Q-value.

        Args:
            q: per-agent Q-values whose last dim is the number of agents;
                flattened to ``(-1, 1, n_agents)``. Presumably
                ``(batch, T, n_agents)`` — confirm against callers. Its row
                count after flattening must match that of ``s``.
            s: 3-D state tensor ``(batch, T, st_dim)``, flattened to
                ``(-1, st_dim)``.

        Returns:
            Joint Q-values viewed as ``(batch, -1, 1)``.
        """
        batch_size, _, st_dim = s.shape
        n_agents = q.size(-1)
        q = q.reshape(-1, 1, n_agents)
        s = s.reshape(-1, st_dim)

        # First mixing layer: (rows, 1, n_agents) @ (rows, n_agents, h_dim).
        w1 = self.hyper_w1(s)
        if self.use_abs:
            # Non-negative weights keep the mix monotonic in each agent's Q.
            w1 = torch.abs(w1)
        b1 = self.hyper_b1(s)
        w1 = w1.view(-1, n_agents, self.h_dim)
        b1 = b1.view(-1, 1, self.h_dim)
        q_hidden = self.act_fn(torch.bmm(q, w1) + b1)

        # Second mixing layer: (rows, 1, h_dim) @ (rows, h_dim, 1).
        w2 = self.hyper_w2(s)
        if self.use_abs:
            w2 = torch.abs(w2)
        b2 = self.hyper_b2(s)
        w2 = w2.view(-1, self.h_dim, 1)
        b2 = b2.view(-1, 1, 1)
        q_total = torch.bmm(q_hidden, w2) + b2
        q_total = q_total.view(batch_size, -1, 1)
        return q_total