import argparse
import os
import sys
from collections import defaultdict
import pprint

import torch
from torch import nn
import numpy as np

# path will be set in create_envs... or we can call set_path
from create_envs import create_default_env

import tube
import minirts
from pyxrl.data_channel_manager import DataChannelManager

from actor_critic import ActorCritic#, ActorCritic2
import common_utils

from rule_based_model import RuleBasedCountModel
from rule_ai_sampler import RuleAISampler
import global_consts as gc

from train_coach import train, evaluate


def parse_args(argv=None):
    """Parse the command-line options for rule-based coach training.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual
            command-line behavior; passing a list makes the function usable
            from tests or other scripts.

    Returns:
        ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='rule based')
    parser.add_argument('--save_dir', type=str, default='dev/rule_dev')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--num_thread', type=int, default=1)
    parser.add_argument('--batchsize', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--update_per_epoch', type=int, default=200)
    parser.add_argument('--num_epoch', type=int, default=400)

    # Default lua directory is <three directories above this file>/game/game_MC/lua.
    root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    default_lua = os.path.join(root, 'game/game_MC/lua')
    parser.add_argument('--lua_files', type=str, default=default_lua)

    # optim
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--eps', type=float, default=1.5e-4)
    parser.add_argument('--grad_clip', type=float, default=0.5)

    # enemy selection
    parser.add_argument('--adversarial', type=int, default=1)
    parser.add_argument('--win_rate_decay', type=float, default=0.99)

    # ai1 option
    parser.add_argument('--frame_skip', type=int, default=50)
    parser.add_argument('--fow', type=int, default=1)
    parser.add_argument('--t_len', type=int, default=10)
    parser.add_argument('--use_moving_avg', type=int, default=1)
    parser.add_argument('--moving_avg_decay', type=float, default=0.98)
    parser.add_argument('--num_resource_bins', type=int, default=11)
    parser.add_argument('--resource_bin_size', type=int, default=50)

    # game option
    parser.add_argument('--max_tick', type=int, default=int(2e5))
    parser.add_argument('--no_terrain', action='store_true')
    parser.add_argument('--resource', type=int, default=500)
    parser.add_argument('--resource_dist', type=int, default=4)
    parser.add_argument('--fair', type=int, default=0)
    parser.add_argument('--save_replay_freq', type=int, default=50)
    parser.add_argument('--save_replay_per_games', type=int, default=1)
    parser.add_argument('--max_num_units', type=int, default=50)

    # actor critic option
    parser.add_argument('--ppo', type=int, default=0)
    parser.add_argument('--ent_ratio', type=float, default=1e-2)
    parser.add_argument('--min_prob', type=float, default=1e-6)
    parser.add_argument('--max_importance_ratio', type=float, default=2.0)
    parser.add_argument('--ratio_clamp', type=float, default=0.1)
    parser.add_argument('--gamma', type=float, default=0.99)

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    args = parser.parse_args(argv)
    return args


def get_game_option(args):
    """Build a ``minirts.RTSGameOption`` from parsed command-line options.

    Copies the game-level settings (tick limit, terrain flag, resources,
    fairness, replay saving, lua directory, per-player unit cap) from
    ``args`` onto a freshly constructed option object and returns it.
    """
    game_opt = minirts.RTSGameOption()
    # These option attributes share their name with the parsed argument.
    same_named = (
        'max_tick',
        'no_terrain',
        'resource',
        'resource_dist',
        'fair',
        'save_replay_freq',
        'save_replay_per_games',
        'lua_files',
    )
    for field in same_named:
        setattr(game_opt, field, getattr(args, field))
    # The unit cap is exposed under a different attribute name.
    game_opt.max_num_units_per_player = args.max_num_units
    return game_opt


def get_ai_options(args):
    """Create ``minirts.AIOption`` objects for ai1 and ai2 (the enemy).

    ai1 receives the frame skip plus the fow / t_len / moving-average /
    resource-binning settings from ``args``; ai2 only receives the frame
    skip.

    Returns:
        Tuple ``(ai1_option, ai2_option)``.
    """
    learner_opt = minirts.AIOption()
    learner_opt.fs = args.frame_skip
    # Remaining ai1 attributes share their name with the parsed argument.
    for field in ('fow',
                  't_len',
                  'use_moving_avg',
                  'moving_avg_decay',
                  'num_resource_bins',
                  'resource_bin_size'):
        setattr(learner_opt, field, getattr(args, field))

    enemy_opt = minirts.AIOption()
    enemy_opt.fs = args.frame_skip
    return learner_opt, enemy_opt


if __name__ == '__main__':
    args = parse_args()

    # Redirect stdout to the log file *before* printing anything so the run
    # configuration below is captured in train.log too (previously it was
    # printed before the redirect and never logged). The save directory may
    # not exist on a fresh run, so create it before anything is written there.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
        logger_path = os.path.join(args.save_dir, 'train.log')
        sys.stdout = common_utils.Logger(logger_path)

    print('args:')
    pprint.pprint(vars(args))

    # Point LUA_PATH at the lua script directory (presumably read by the game
    # engine's embedded lua interpreter — confirm).
    os.environ['LUA_PATH'] = os.path.join(args.lua_files, '?.lua')
    print('lua path:', os.environ['LUA_PATH'])

    # --seed used to reach only the environments; seed torch as well so model
    # initialization is reproducible (needed for --deterministic to mean much).
    torch.manual_seed(args.seed)
    if args.deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True

    ai1_option, ai2_option = get_ai_options(args)
    game_option = get_game_option(args)
    # Evaluation options are built from the training options (presumably a
    # copy — confirm) with a fixed seed and one game per thread.
    eval_option = minirts.RTSGameOption(game_option)
    eval_option.seed = 999111
    eval_option.num_games_per_thread = 1

    config, context, train_dc, act_dc, rule_dc, games = create_default_env(
        args.num_thread,
        args.batchsize,
        args.seed,
        ai1_option,
        ai2_option,
        game_option,
    )

    device = torch.device('cuda:%d' % args.gpu)
    model = RuleBasedCountModel(
        gc.NUM_STRATEGY,
        len(gc.UnitTypes),
        ai1_option.num_resource_bins).to(device)
    print(model)
    optim = common_utils.Optim(
        model, torch.optim.Adam, {'lr': args.lr, 'eps': args.eps}, args.grad_clip)
    method = ActorCritic(
        ent_ratio=args.ent_ratio,
        min_prob=args.min_prob,
        max_importance_ratio=args.max_importance_ratio,
        ratio_clamp=args.ratio_clamp,
        gamma=args.gamma,
        ppo=args.ppo)
    dc_manager = DataChannelManager([train_dc, act_dc, rule_dc])
    context.start()

    stat = common_utils.MultiCounter()
    train_result = common_utils.ResultStat(
        'reward', os.path.join(args.save_dir, 'train_win'))
    eval_result = common_utils.ResultStat(
        'reward', os.path.join(args.save_dir, 'eval_win'))
    rule_sampler = RuleAISampler(args.adversarial, 1, 1, args.win_rate_decay)

    for epoch in range(args.num_epoch):
        # Evaluate every 10 epochs (including epoch 0) with grads disabled.
        if epoch % 10 == 0:
            print('=============eval==============')
            with common_utils.eval_mode_no_grad():
                # NOTE(review): 500 is presumably the number of eval games;
                # confirm against train_coach.evaluate.
                evaluate('rule',
                         model,
                         device,
                         500,
                         eval_option,
                         ai1_option,
                         epoch,
                         eval_result,
                         args.save_dir)
            print('========= end of eval==========')

        # NOTE(review): the positional None fills a train() slot this script
        # does not use; confirm its meaning in train_coach.train.
        train(epoch,
              args.update_per_epoch,
              dc_manager,
              optim,
              method,
              model,
              None,
              rule_sampler,
              device,
              stat,
              train_result)

        stat.summary(epoch)
        print(train_result.log(epoch))
        rule_sampler.log()
        print('==========')

        # Clear per-epoch statistics before the next epoch.
        train_result.reset()
        rule_sampler.reset()
        stat.reset()
