import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from QLearning import QLearningTable
from joblib import load

from arguments.args import get_args
from arguments.utils import make_env

import argparse
import random

import os

args = get_args()

# Seed both NumPy's legacy global RNG and the stdlib `random` module from the
# configured seed so repeated runs draw the same random numbers.
random_seed = args.random_seed
np.random.seed(random_seed)
random.seed(random_seed)

# Directory the baseline model files are read from; one sub-directory per
# configured maximum depth.
model_dir_BASELINES = os.path.join('outputs', 'Step2_BaselineDTModels', str(args.max_depth_baseline))

# Display name of each baseline family -> prefix used in its .joblib file names.
_BASELINE_FILE_PREFIXES = {
    "CART": "tree_traditional",
    "Random Forest": "random_forest",
    "GBDT": "gbdt",
    "Extra Trees": "extra_trees",
}
# Keys under which the per-agent models are stored; they also appear verbatim
# in the file names.
_AGENT_KEYS = ("agent_1", "agent_2", "agent_3")

# models[family][agent_key] -> object loaded from
#   <model_dir_BASELINES>/<prefix>_<agent_key>_MaxDepth_<max_depth_baseline>.joblib
# NOTE: joblib.load unpickles the file, which can execute arbitrary code; only
# point this at model files you produced yourself or otherwise trust.
models = {
    family_name: {
        agent_key: load(os.path.join(
            model_dir_BASELINES,
            f'{file_prefix}_{agent_key}_MaxDepth_{args.max_depth_baseline}.joblib',
        ))
        for agent_key in _AGENT_KEYS
    }
    for family_name, file_prefix in _BASELINE_FILE_PREFIXES.items()
}

def run_maze_with_model(model_name, model_0, model_1, model_2, env, evaluate_episodes, evaluate_episode_len):
    """Run three per-agent models in ``env`` and save the reward statistics.

    Every episode resets the environment and then executes at most
    ``evaluate_episode_len`` steps. At each step every model predicts an
    action from its own agent's observation, and the three predictions are
    passed to ``env.step`` as one list. The per-episode total rewards, the
    running mean of those totals, and the step index at which ``done`` was
    reported are collected and handed to ``save_results``.

    Args:
        model_name: Label of the model family; forwarded to ``save_results``.
        model_0: Object with a ``predict`` method, used for the first agent.
        model_1: Object with a ``predict`` method, used for the second agent.
        model_2: Object with a ``predict`` method, used for the third agent.
        env: Environment exposing ``reset()``, ``render()`` and
            ``step(action_n)``; ``step`` must return
            ``(observation, reward, done)``.
        evaluate_episodes: Number of episodes to run.
        evaluate_episode_len: Maximum number of steps per episode.
    """
    agent_models = (model_0, model_1, model_2)
    episode_total_reward = []
    mean_episode_reward_list = []
    task_completed_step = []
    # Bound up front so the summary print after the loop does not fail with an
    # unbound local when evaluate_episodes is 0.
    mean_episode_reward = 0

    for episode in range(evaluate_episodes):
        current_episode_reward = 0
        observation = env.reset()
        print("Start episode", episode)

        for s in range(evaluate_episode_len):
            env.render()

            # observation[i][0] is taken as agent i's input and reshaped into a
            # single row of shape (1, n) before being passed to predict().
            # NOTE(review): assumes observation[i][0] is a NumPy array.
            action_n = [
                model.predict(observation[agent_idx][0].reshape(1, -1))
                for agent_idx, model in enumerate(agent_models)
            ]

            # Step the environment; the returned observation feeds the next step.
            observation, reward, done = env.step(action_n)
            current_episode_reward += reward
            print('current episode reward:', current_episode_reward)

            if done:
                # `s` is the zero-based index of the step that reported done.
                print("task achieved: YES!!!!!!!!!!!!")
                print("task achieved after ", s, " steps")
                task_completed_step.append(s)
                break

        print('current episode reward when this episode ends:', current_episode_reward)
        episode_total_reward.append(current_episode_reward)
        print('Total reward List after', episode, "episode is:", episode_total_reward)

        # Running mean over all episodes finished so far.
        Total_Rewards = np.sum(episode_total_reward)
        mean_episode_reward = Total_Rewards / (episode + 1)
        mean_episode_reward_list.append(mean_episode_reward)

        print("Mean Episode Rewards after", episode, "episode is:", mean_episode_reward)
        print("Mean Episode Rewards List:", mean_episode_reward_list)
        print("End Episode :", episode)
        print("\n")

    # Mean of the per-episode total rewards over every episode that was run.
    print("training process over mean episode rewards:",
          mean_episode_reward)

    print('game over')
    save_results(model_name, episode_total_reward, mean_episode_reward_list, task_completed_step)

def save_results(model_name, episode_total_reward, mean_episode_reward_list, task_completed_step,
                 evaluate_episodes=None, evaluate_episode_len=None):
    """Write the evaluation statistics of one model family to CSV files.

    Files are written to
    ``outputs/Step3_Evaluate_BaselinesDT/Max_Depth_<d>/MeanEpisodeRewardList_RandomSeed_<seed>/<model_name>/``
    where ``<d>`` and ``<seed>`` come from the module-level ``args``. The
    directory is created if it does not exist. Each list is wrapped in a
    ``pandas.DataFrame`` and written with ``to_csv``.

    Args:
        model_name: Name of the sub-directory the files are written to.
        episode_total_reward: Values written to ``EpisodeTotalReward_*.csv``.
        mean_episode_reward_list: Values written to ``MeanEpisodeReward_*.csv``.
        task_completed_step: Values written to ``TaskCompleteStep_*.csv``.
        evaluate_episodes: Episode count used in the file names. Defaults to
            ``args.evaluate_episodes``.
        evaluate_episode_len: Episode length used in the file names. Defaults
            to ``args.evaluate_episode_len``.
    """
    # Fall back to the parsed arguments rather than to globals that only exist
    # when this file is executed as a script, so the function also works when
    # the module is imported.
    if evaluate_episodes is None:
        evaluate_episodes = args.evaluate_episodes
    if evaluate_episode_len is None:
        evaluate_episode_len = args.evaluate_episode_len

    save_dir = os.path.join('outputs', 'Step3_Evaluate_BaselinesDT',
                            f'Max_Depth_{args.max_depth_baseline}', f'MeanEpisodeRewardList_RandomSeed_{args.random_seed}',
                            model_name)
    os.makedirs(save_dir, exist_ok=True)

    # All three files share the "<episodes>_<episode_len>.csv" suffix.
    file_suffix = f"{evaluate_episodes}_{evaluate_episode_len}.csv"
    outputs = (
        ("EpisodeTotalReward", episode_total_reward),
        ("MeanEpisodeReward", mean_episode_reward_list),
        ("TaskCompleteStep", task_completed_step),
    )
    for file_stem, values in outputs:
        pd.DataFrame(values).to_csv(os.path.join(save_dir, f"{file_stem}_{file_suffix}"))

if __name__ == "__main__":
    # Parse the arguments and build the environment; make_env hands back the
    # environment together with the args object used from here on.
    env, args = make_env(get_args())

    # Kept as module-level names: other code in this file may read them.
    evaluate_episodes = args.evaluate_episodes
    evaluate_episode_len = args.evaluate_episode_len

    # Evaluate every loaded model family in turn, one model per agent.
    for model_name in models:
        print(f"Evaluating model: {model_name}")
        per_agent_models = models[model_name]
        run_maze_with_model(
            model_name,
            per_agent_models['agent_1'],
            per_agent_models['agent_2'],
            per_agent_models['agent_3'],
            env,
            evaluate_episodes,
            evaluate_episode_len,
        )
