from pettingzoo.mpe import simple_push_v3,simple_tag_v3
import random
import argparse
import itertools
import torch
import numpy as np
import itertools
import datetime
import time
from sac.sac import SAC
from torch.utils.tensorboard import SummaryWriter
from sac.replay_memory import ReplayMemory

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True`` because
    any non-empty string is truthy, so ``--eval False`` would enable the flag.
    This parser accepts the usual spellings explicitly instead.

    Args:
        value: the raw command-line string (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env_name', default="simple_push_v3",
                    help='environment:simple_push_v3|simple_tag_v3 (default: simple_push_v3)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=False,
                    help='Evaluates a policy every 10 episodes (default: False)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
# Implicit string concatenation keeps the help text free of the run of spaces
# a backslash continuation inside the literal would embed.
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=123456, metavar='N',
                    help='random seed (default: 123456)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
# `--cuda` is store_true with default=True, so on its own it can never turn
# CUDA off; `--no_cuda` writes to the same destination to allow that.
parser.add_argument('--cuda', action="store_true", default=True,
                    help='run on CUDA (default: True; pass --no_cuda to run on CPU)')
parser.add_argument('--no_cuda', dest='cuda', action='store_false',
                    help='run on CPU instead of CUDA')
parser.add_argument('--agent_kernel', default=None,
                    help='kernel: rf | nystrom | None (default: None)')
parser.add_argument('--adv_kernel', default=None,
                    help='kernel: rf | nystrom | None (default: None)')
parser.add_argument('--sigma', type=float, default=0,
                    help='sigma of noise (default: 0)')
parser.add_argument('--m', type=int, default=64,
                    help='number of features (default: 64)')
args = parser.parse_args()

# Seed every RNG this script can reach so evaluation runs are reproducible:
# numpy, Python's `random`, and torch. SAC is defined elsewhere; seeding torch
# presumably covers any sampling select_action does (TODO confirm).
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)

# Not referenced anywhere in this script.
eps = 0.1

# NOTE(review): --env_name is parsed but ignored here; the environment is always
# simple_push_v3 (the player setup below also assumes exactly adversary_0/agent_0).
env = simple_push_v3.env(render_mode="rgb_array", continuous_actions=True)

# The first reset seeds the environment's RNG; the per-episode resets in the
# evaluation loop are unseeded and continue from that RNG state.
env.reset(seed=args.seed)

# One SAC player per agent, each conditioned on the global env.state().
# args.kernel is set immediately before each construction, so each player is
# built with its own kernel choice.
players = dict()

args.kernel = args.adv_kernel
players["adversary_0"] = SAC(env.state().shape[0], env.action_space("adversary_0"), env, args)
args.kernel = args.agent_kernel
players["agent_0"] = SAC(env.state().shape[0], env.action_space("agent_0"), env, args)


def _checkpoint_path(env_name, role, kernel, m, sigma, step="20000"):
    """Return the checkpoint path for one player.

    Without a kernel the path is ``Checkpoints/<env>_<role>_<sigma>_<step>``;
    with a kernel, the kernel name and feature count ``m`` are inserted after
    the role: ``Checkpoints/<env>_<role>_<kernel>_<m>_<sigma>_<step>``.
    """
    if kernel is None:
        return "Checkpoints/{}_{}_{}_{}".format(env_name, role, sigma, step)
    return "Checkpoints/{}_{}_{}_{}_{}_{}".format(env_name, role, kernel, m, sigma, step)


ckpt_path = _checkpoint_path(env.metadata["name"], "adversary_0", args.adv_kernel, args.m, args.sigma)
players["adversary_0"].load_checkpoint(ckpt_path)

ckpt_path = _checkpoint_path(env.metadata["name"], "agent_0", args.agent_kernel, args.m, args.sigma)
players["agent_0"].load_checkpoint(ckpt_path)


# Evaluation: play `episodes` episodes with the loaded policies, then report the
# mean (over episodes) of the per-episode averaged agent_0 reward, and the
# fraction of episodes whose averaged reward is >= 0 ("agent wins").
avg_reward = 0
episodes = 100000
agent_wins = 0
for i_episode in range(episodes):
    episode_reward = 0
    episode_steps = 0  # number of completed agent cycles in this episode
    env.reset()

    for agent in env.agent_iter():
        state = env.state()
        observation, reward, termination, truncation, info = env.last()
        done = termination or truncation

        # PettingZoo requires a finished agent to be stepped with a None action.
        if done:
            action = None
        else:
            action = players[agent].select_action(state)
        # NOTE(review): next_idx() is not part of the standard PettingZoo AEC
        # API; it is presumably the index of the agent after `agent`, so 0
        # marks the end of a cycle. Confirm against the patched environment.
        next_idx = env.next_idx()
        env.step(action)
        # NOTE(review): in the standard AEC API, env.last() after step() reports
        # the *newly selected* agent, not `agent`. Confirm which player's reward
        # agent_reward / adv_reward really hold.
        observation, reward, termination, truncation, info = env.last()
        if agent == "agent_0":
            agent_reward = reward
        else:
            adv_reward = reward
        if next_idx == 0:
            episode_steps += 1
            done = termination or truncation
            if done:
                # The reward of the terminating cycle is not accumulated.
                break
            # Only cycles after the 10th contribute to the episode reward.
            if episode_steps > 10:
                episode_reward += agent_reward

    # NOTE(review): the sum covers only cycles 11+, but the divisor is every
    # completed cycle; confirm this is the intended normalisation.
    # Guard against an episode that ends before completing a single cycle,
    # which previously raised ZeroDivisionError.
    if episode_steps > 0:
        episode_reward /= episode_steps
    if episode_reward >= 0:
        agent_wins += 1
    avg_reward += episode_reward
    if i_episode % 5000 == 0:
        print("i_episode:", i_episode, "agent_wins:", agent_wins)
avg_reward /= episodes
agent_win_rate = agent_wins / episodes

print("avg_reward:", avg_reward)
print("agent_win_rate:", agent_win_rate)
env.close()