import torch


class DQN_SMAC:
    """Value-based multi-agent learner holding hyper-parameters and a QMIX-style model.

    The model itself is expected on ``self.qmix``, which this class reads but never
    assigns in ``__init__``.
    NOTE(review): presumably a subclass or the caller sets ``self.qmix``; confirm.
    ``self.qmix`` must expose ``mode``/``sample`` plus
    ``eval_Q_net``/``target_Q_net``/``eval_mix_net``/``target_mix_net``.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args, h_dim=256):
        """Store dimensions and training hyper-parameters.

        Args:
            n_agents: number of agents; also the width of the one-hot agent-id block.
            ob_dim: per-agent observation size.
            st_dim: global state size.
            ac_dim: number of discrete actions (width of the last-action one-hot).
            args: namespace providing ``max_train_steps``, ``lr``, ``gamma``,
                ``batch_size``, ``target_update_freq``, ``tau`` and ``device``.
            h_dim: hidden width; stored only.
        """
        self.ob_dim = ob_dim
        self.st_dim = st_dim
        self.ac_dim = ac_dim
        self.h_dim = h_dim
        self.n_agents = n_agents
        self.max_train_steps = args.max_train_steps
        self.lr = args.lr
        self.gamma = args.gamma
        self.batch_size = args.batch_size
        self.target_update_freq = args.target_update_freq
        self.tau = args.tau
        self.device = args.device
        # Per-agent network input = observation ++ last-action one-hot ++ agent-id one-hot;
        # this matches the concatenation built in choose_action.
        self.input_dim = self.ob_dim + self.ac_dim + self.n_agents
        # Read by update_targets; not incremented anywhere in this class.
        self.train_step = 0
        # Not read anywhere in this class; presumably used by the training loop.
        self.max_grad_norm = 0.5

    def choose_action(self, obs, last_onehot, avails, evaluate=False):
        """Build per-agent network inputs and pick actions via ``self.qmix``.

        Args:
            obs: observations; must be 2-D with ``n_agents`` rows so it can be
                concatenated with the ``(n_agents, n_agents)`` identity.
            last_onehot: last actions, same leading size as ``obs``.
            avails: availability mask, passed through to ``self.qmix`` unchanged.
            evaluate: if True call ``self.qmix.mode`` (greedy), else ``self.qmix.sample``.

        Returns:
            Whatever ``self.qmix.mode`` / ``self.qmix.sample`` returns.
        """
        # Create the agent-id one-hots on obs's device/dtype; a default CPU tensor
        # makes torch.cat fail when obs lives on a GPU.
        agent_ids = torch.eye(self.n_agents, device=obs.device, dtype=obs.dtype)
        inputs = torch.cat([obs, last_onehot, agent_ids], -1)
        if evaluate:
            return self.qmix.mode(inputs, avails)
        return self.qmix.sample(inputs, avails)

    @staticmethod
    def _soft_update(source, target, tau):
        """In place, set every target parameter to ``tau * source + (1 - tau) * target``."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def update_targets(self, soft_update=False):
        """Sync the target networks, but only when ``train_step`` is a multiple of
        ``target_update_freq``.

        Args:
            soft_update: if True, Polyak-average with coefficient ``self.tau``;
                otherwise copy the eval networks' state dicts into the targets.
        """
        if self.train_step % self.target_update_freq != 0:
            return
        qmix = self.qmix
        if soft_update:
            self._soft_update(qmix.eval_Q_net, qmix.target_Q_net, self.tau)
            self._soft_update(qmix.eval_mix_net, qmix.target_mix_net, self.tau)
        else:
            qmix.target_Q_net.load_state_dict(qmix.eval_Q_net.state_dict())
            qmix.target_mix_net.load_state_dict(qmix.eval_mix_net.state_dict())