import torch
import torch.nn as nn


class CoachExecutorModel(nn.Module):
    """Pair a trainable coach with a frozen executor.

    The coach formats the observation batch and samples an instruction; the
    executor turns that instruction into its own reply (``compute_prob``).
    Only the coach takes part in ``forward`` (via ``coach.rl_forward``) and
    in ``train``; the executor is put in eval mode at construction and is
    required to stay there.

    Args:
        coach: module providing ``inst_dict``, ``sampler``,
            ``format_coach_input``, ``sample`` and ``rl_forward``.
        executor: module providing ``inst_dict``, ``format_executor_input``
            and ``compute_prob``.
        cheat: if true, coach input is formatted with the ``'nofow_'``
            prefix (presumably no-fog-of-war observations -- confirm against
            ``format_coach_input``).
        inst_mode: passed through unchanged to ``coach.sample`` and
            ``coach.rl_forward``.
    """

    def __init__(self, coach, executor, cheat, inst_mode):
        super().__init__()
        self.coach = coach
        self.executor = executor
        self.cheat = cheat
        self.inst_mode = inst_mode
        # Both models must index instructions identically, otherwise an
        # instruction id sampled by the coach would mean something else to
        # the executor.
        assert self.executor.inst_dict._idx2inst == self.coach.inst_dict._idx2inst
        # TODO: make sure that max_num_prev_cmds are the same for
        # coach and executor during training
        self.executor.train(False)

    @property
    def sampler(self):
        """The coach's sampler."""
        return self.coach.sampler

    def train(self, training=True):
        """Set training mode on this wrapper and the coach only.

        Overrides ``nn.Module.train`` so the executor is never switched back
        into training mode; it must already be in eval mode.

        Args:
            training: ``True`` for training mode, ``False`` for eval mode.
                Defaults to ``True``, matching ``nn.Module.train``.

        Returns:
            ``self``, matching ``nn.Module.train`` so calls can be chained.
        """
        assert not self.executor.training
        # Keep the wrapper's own flag in sync (nn.Module.train would do this).
        self.training = training
        self.coach.train(training)
        return self

    def _format_coach_input(self, batch):
        """Format ``batch`` for the coach, honouring the ``cheat`` flag."""
        if self.cheat:
            return self.coach.format_coach_input(batch, 'nofow_')
        return self.coach.format_coach_input(batch)

    def get_action(self, batch):
        """Return the ``'cont'`` and ``'inst'`` entries of ``batch``."""
        return {
            'cont': batch['cont'],
            'inst': batch['inst'],
        }

    def get_policy(self, batch):
        """Return the ``'cont_pi'`` and ``'inst_pi'`` entries of ``batch``."""
        return {
            'cont_pi': batch['cont_pi'],
            'inst_pi': batch['inst_pi'],
        }

    def get_reward(self, batch):
        """Return ``batch['reward']``."""
        return batch['reward']

    def get_value(self, batch):
        """Return ``batch['v']``."""
        return batch['v']

    def get_terminal(self, batch):
        """Return ``batch['terminal']``."""
        return batch['terminal']

    def act(self, batch):
        """Sample an instruction with the coach and run the executor on it.

        Args:
            batch: dict-like observation batch; must contain ``'num_army'``
                plus whatever the coach/executor formatters read.

        Returns:
            The coach's reply dict, updated in place with the executor's
            reply (executor values win on key collisions) and with
            ``'num_unit'`` set to ``batch['num_army']``.
        """
        # Only the executor's mode is checked; the coach may be in either.
        assert not self.executor.training

        coach_input = self._format_coach_input(batch)
        inst, inst_len, inst_cont, coach_reply = self.coach.sample(
            coach_input, self.inst_mode)

        executor_input = self.executor.format_executor_input(
            batch, inst, inst_len, inst_cont)
        executor_reply = self.executor.compute_prob(executor_input)

        coach_reply.update(executor_reply)
        coach_reply['num_unit'] = batch['num_army']
        return coach_reply

    def forward(self, batch):
        """Run the coach's RL forward pass on ``batch`` (executor unused)."""
        coach_input = self._format_coach_input(batch)
        return self.coach.rl_forward(coach_input, self.inst_mode)
