"""
Reinforcement learning maze example.

Red rectangle:          explorer.
Black rectangles:       hells       [reward = -1].
Yellow bin circle:      paradise    [reward = +1].
All other states:       ground      [reward = 0].

This script is the main entry point: it runs the training loop and exports the results.
The Q-learning agent is in Centralized_Qlearning_new.py; the environment is in maze_denseR.py.

View more on my tutorial page: https://morvanzhou.github.io/tutorials/
"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans


from maze_denseR import Maze
from Centralized_Qlearning_new import QLearningTable

from scipy.special import softmax


# Run configuration: number of training episodes and the step cap per episode.
total_episode = 1000
total_step = 6

# Training logs, appended to by run_maze() and plotted in __main__.
episode_total_reward = []      # summed reward of each episode
mean_episode_reward_list = []  # running mean of episode_total_reward after each episode
task_completed_step = []       # step index at which an episode reported `done`
Total_Rewards = 0              # module-level placeholder; never updated by run_maze()

def run_maze():
    mean_episode_reward = 0
    e_a_1_list = []
    e_a_2_list = []
    e_a_3_list = []

    for episode in range(total_episode):
        # initial observation
        current_episode_reward = 0
        observation = env.reset(0)
        print("Start epsiode", episode)

        s_a_1_list = []
        s_a_2_list = []
        s_a_3_list = []

        for s in range(total_step):
            # fresh env
            env.render()

            # RL choose action based on observation
            action = RL.choose_action(observation)

            s_a_1_list.append(action[0])
            s_a_2_list.append(action[1])
            s_a_3_list.append(action[2])


            # RL take action and get next observation and reward
            #TODO change
            observation_, reward, done = env.step(action , s + 1)
            current_episode_reward += reward
            print('current rewards after', s, "steps is:", current_episode_reward)

            # RL learn from this transition
            RL.learn(observation, action, reward, observation_)

            # swap observation
            observation = observation_

            # break while loop when end of this episode
            if done:
                print("task achieved: YES!!!!!!!!!!!!")
                print("task achieved after ", s, " steps")
                task_completed_step.append(s)
                break
            s += 1
        e_a_1_list.append(s_a_1_list)
        e_a_2_list.append(s_a_2_list)
        e_a_3_list.append(s_a_2_list)

        if episode != 0 and episode % 200 == 0 and RL.epsilon <= 0.8:
            RL.update_epsilon()
        print('current episode reward when this episode ends:', current_episode_reward)
        episode_total_reward.append(current_episode_reward)
        #print('Total reward List after', episode,"episode is:", episode_total_reward)

        Total_Rewards = np.sum(episode_total_reward)
        mean_episode_reward = Total_Rewards / (episode + 1)
        mean_episode_reward_list.append(mean_episode_reward)

        print("Mean Episode Rewards after", episode, "episode is:", mean_episode_reward)
        #print("Mean Episode Rewards List:", mean_episode_reward_list)
        print("End Episode :", episode)
        print("\n")



        episode += 1
    print("training process over mean episode rewards:",
          mean_episode_reward)  # average rewards over 100 episodes without noise

    # end of game
    print('game over')

    pd.DataFrame(episode_total_reward).to_csv(
        "./baseline1_centralized_TotalReward_{episode}_{step}_test.csv".format(episode = total_episode, step = total_step))
    pd.DataFrame(mean_episode_reward_list).to_csv(
        "./baseline1_centralized_MeanEpisodeReward_{episode}_{step}_test.csv".format(episode=total_episode, step=total_step))
    pd.DataFrame(task_completed_step).to_csv(
        "./baseline1_centralized_task_completed_step_{episode}_{step}_test.csv".format(episode=total_episode, step=total_step))
    pd.DataFrame(e_a_1_list).to_csv(
        "./baseline1_centralized_task_completed_action_1_step_{episode}_{step}_test.csv".format(
            episode=total_episode, step=total_step))

    pd.DataFrame(e_a_2_list).to_csv(
        "./baseline1_centralized_task_completed_action_2_step_{episode}_{step}_test.csv".format(
            episode=total_episode, step=total_step))
    pd.DataFrame(e_a_3_list).to_csv(
        "./baseline1_centralized_task_completed_action_3_step_{episode}_{step}_test.csv".format(
            episode=total_episode, step=total_step))
    #print(RL.q_table)
    #pd.DataFrame(RL.q_table).to_csv(
    #    "./baseline1_centralized_task_completed_action_3_step_{episode}_{step}_test.csv".format(
    #        episode=total_episode, step=total_step))

    fix_agent_1 = []
    q_collection_1 = []
    group_1 = -1
    total_prob_count = 0
    prob_count_index_1 = []
    q_collection_1_softmax = []
    q_collection_1_tanh = []
    for s1 in RL.q_table.keys():
        for a1 in RL.q_table.get(s1).keys():
            q_fix_s1_a1 = []
            group_1 += 1
            prob_count = 0
            for s2 in RL.q_table.get(s1).get(a1).keys():
                for a2 in RL.q_table.get(s1).get(a1).get(s2).keys():
                    for s3 in RL.q_table.get(s1).get(a1).get(s2).get(a2).keys():
                        for a3 in RL.q_table.get(s1).get(a1).get(s2).get(a2).get(s3).keys():
                            fix_agent_1.append([s1, a1, s2, a2, s3, a3, RL.q_table[s1][a1][s2][a2][s3][a3], group_1, group_1, group_1]) #TODO: what group_1 means here
                            q_fix_s1_a1.append(RL.q_table[s1][a1][s2][a2][s3][a3])
                            prob_count += 1
                            total_prob_count += 1
                    q_collection_1.append(q_fix_s1_a1)
                    soft = softmax(q_fix_s1_a1)
                    soft_list = []
                    for s in soft:
                        soft_list.append(s)
                    q_collection_1_softmax.append(soft_list)

                    tanh = np.tanh(q_fix_s1_a1)
                    tanh_list = []
                    for t in tanh:
                        tanh_list.append(t)
                    q_collection_1_tanh.append(tanh_list)
                    max_value = max(q_fix_s1_a1)
                    max_index = q_fix_s1_a1.index(max_value)
                    prob_count_index_1.append([total_prob_count, prob_count, max_index])


    pd_q_1 = pd.DataFrame(q_collection_1)
    pd_q_1_fill_0 = pd_q_1.fillna(0.0)
    state_1 = KMeans(n_clusters=4, random_state=0).fit(pd_q_1_fill_0)
    label_1 = state_1.labels_

    pd_q_1_softmax = pd.DataFrame(q_collection_1_softmax)
    pd_q_1_fill_0_softmax = pd_q_1_softmax.fillna(0.0)
    state_1_softmax = KMeans(n_clusters=4, random_state=0).fit(pd_q_1_fill_0_softmax)
    label_1_softmax = state_1_softmax.labels_

    pd_q_1_tanh = pd.DataFrame(q_collection_1_tanh)
    pd_q_1_fill_0_tanh = pd_q_1_tanh.fillna(0.0)
    state_1_tanh = KMeans(n_clusters=4, random_state=0).fit(pd_q_1_fill_0_tanh)
    label_1_tanh = state_1_tanh.labels_

    fix_agent_2 = []
    q_collection_2 = []
    group_2 = -1
    total_prob_count = 0
    prob_count_index_2 = []
    q_collection_2_softmax = []
    q_collection_2_tanh = []
    for s2 in RL.q_table_2.keys():
        for a2 in RL.q_table_2.get(s2).keys():
            q_fix_s2_a2 = []
            group_2 += 1
            prob_count = 0
            for s1 in RL.q_table_2.get(s2).get(a2).keys():
                for a1 in RL.q_table_2.get(s2).get(a2).get(s1).keys():
                    for s3 in RL.q_table_2.get(s2).get(a2).get(s1).get(a1).keys():
                        for a3 in RL.q_table_2.get(s2).get(a2).get(s1).get(a1).get(s3).keys():
                            fix_agent_2.append(
                                [s2, a2, s1, a1, s3, a3, RL.q_table_2[s2][a2][s1][a1][s3][a3], group_2, group_2, group_2])
                            q_fix_s2_a2.append(RL.q_table_2[s2][a2][s1][a1][s3][a3])
                            prob_count += 1
                            total_prob_count += 1
                    q_collection_2.append(q_fix_s2_a2)
                    soft = softmax(q_fix_s2_a2)
                    soft_list = []
                    for s in soft:
                        soft_list.append(s)
                    q_collection_2_softmax.append(soft_list)

                    tanh = np.tanh(q_fix_s2_a2)
                    tanh_list = []
                    for t in tanh:
                        tanh_list.append(t)
                    q_collection_2_tanh.append(tanh_list)
                    max_value = max(q_fix_s2_a2)
                    max_index = q_fix_s2_a2.index(max_value)
                    prob_count_index_2.append([total_prob_count, prob_count, max_index])

    pd_q_2 = pd.DataFrame(q_collection_2)
    pd_q_2_fill_0 = pd_q_2.fillna(0.0)
    state_2 = KMeans(n_clusters=4, random_state=0).fit(pd_q_2_fill_0)
    label_2 = state_2.labels_


    pd_q_2_softmax = pd.DataFrame(q_collection_2_softmax)
    pd_q_2_fill_0_softmax = pd_q_2_softmax.fillna(0.0)
    state_2_softmax = KMeans(n_clusters=4, random_state=0).fit(pd_q_2_fill_0_softmax)
    label_2_softmax = state_2_softmax.labels_

    pd_q_2_tanh = pd.DataFrame(q_collection_2_tanh)
    pd_q_2_fill_0_tanh = pd_q_2_tanh.fillna(0.0)
    state_2_tanh = KMeans(n_clusters=4, random_state=0).fit(pd_q_2_fill_0_tanh)
    label_2_tanh = state_2_tanh.labels_

    fix_agent_3 = []
    q_collection_3 = []
    group_3 = -1
    total_prob_count = 0
    prob_count_index_3 = []
    q_collection_3_softmax = []
    q_collection_3_tanh = []
    for s3 in RL.q_table.keys():
        for a3 in RL.q_table.get(s3).keys():
            q_fix_s3_a3 = []
            group_3 += 1
            prob_count = 0
            for s1 in RL.q_table.get(s3).get(a3).keys():
                for a1 in RL.q_table.get(s3).get(a3).get(s1).keys():
                    for s2 in RL.q_table.get(s3).get(a3).get(s1).get(a1).keys():
                        for a2 in RL.q_table.get(s3).get(a3).get(s1).get(a1).get(s2).keys():
                            fix_agent_3.append(
                                [s3, a3, s1, a1, s2, a2, RL.q_table[s3][a3][s1][a1][s2][a2], group_3, group_3,
                                 group_3])  # TODO: what group_1 means here
                            q_fix_s3_a3.append(RL.q_table[s3][a3][s1][a1][s2][a2])
                            prob_count += 1
                            total_prob_count += 1
                    q_collection_3.append(q_fix_s3_a3)
                    soft = softmax(q_fix_s3_a3)
                    soft_list = []
                    for s in soft:
                        soft_list.append(s)
                    q_collection_3_softmax.append(soft_list)

                    tanh = np.tanh(q_fix_s3_a3)
                    tanh_list = []
                    for t in tanh:
                        tanh_list.append(t)
                    q_collection_3_tanh.append(tanh_list)
                    max_value = max(q_fix_s3_a3)
                    max_index = q_fix_s3_a3.index(max_value)
                    prob_count_index_3.append([total_prob_count, prob_count, max_index])

    pd_q_3 = pd.DataFrame(q_collection_3)
    pd_q_3_fill_0 = pd_q_3.fillna(0.0)
    state_3 = KMeans(n_clusters=4, random_state=0).fit(pd_q_3_fill_0)
    label_3 = state_3.labels_

    pd_q_3_softmax = pd.DataFrame(q_collection_3_softmax)
    pd_q_3_fill_0_softmax = pd_q_3_softmax.fillna(0.0)
    state_3_softmax = KMeans(n_clusters=4, random_state=0).fit(pd_q_3_fill_0_softmax)
    label_3_softmax = state_3_softmax.labels_

    pd_q_3_tanh = pd.DataFrame(q_collection_3_tanh)
    pd_q_3_fill_0_tanh = pd_q_3_tanh.fillna(0.0)
    state_3_tanh = KMeans(n_clusters=4, random_state=0).fit(pd_q_3_fill_0_tanh)
    label_3_tanh = state_3_tanh.labels_

    index = 0
    for i, row in enumerate(fix_agent_1):
        row[-3] = label_1[row[-3]]
        row[-2] = label_1_softmax[row[-2]]
        row[-1] = label_1_tanh[row[-1]]
        if i == prob_count_index_1[index][0] and index <= len(prob_count_index_1):
            index += 1
        if i < prob_count_index_1[index][0] and index <= len(prob_count_index_1):
            row.append(prob_count_index_1[index][1] / len(fix_agent_1))
            row.append(prob_count_index_1[index][2])
        row.append(env.location_count_1[row[0]] / (total_episode * total_step))
        row.append(env.location_count_2[row[2]] / (total_episode * total_step))
        row.append(env.location_count_3[row[4]] / (total_episode * total_step))


    index = 0
    for i, row in enumerate(fix_agent_2):
        row[-3] = label_2[row[-3]]
        row[-2] = label_2_softmax[row[-2]]
        row[-1] = label_2_tanh[row[-1]]
        if i == prob_count_index_2[index][0] and index <= len(prob_count_index_2):
            index += 1
        if i < prob_count_index_2[index][0] and index <= len(prob_count_index_2):
            row.append(prob_count_index_2[index][1] / len(fix_agent_2))
            row.append(prob_count_index_2[index][2])
        row.append(env.location_count_1[row[0]] / (total_episode * total_step))
        row.append(env.location_count_2[row[2]] / (total_episode * total_step))
        row.append(env.location_count_3[row[4]] / (total_episode * total_step))


    index = 0
    for i, row in enumerate(fix_agent_3):
        row[-3] = label_3[row[-3]]
        row[-2] = label_3_softmax[row[-2]]
        row[-1] = label_3_tanh[row[-1]]
        if i == prob_count_index_3[index][0] and index <= len(prob_count_index_3):
            index += 1
        if i < prob_count_index_3[index][0] and index <= len(prob_count_index_3):
            row.append(prob_count_index_3[index][1] / len(fix_agent_3))
            row.append(prob_count_index_3[index][2])
        row.append(env.location_count_1[row[0]] / (total_episode * total_step))
        row.append(env.location_count_2[row[2]] / (total_episode * total_step))
        row.append(env.location_count_3[row[4]] / (total_episode * total_step))


    pd1 = pd.DataFrame(fix_agent_1)
    pd1.columns = ['s1', 'a1', 's2', 'a2', 's3', 'a3', 'q', 'label', 'softmax_label', 'tanh_label', 'percentage','max_index', 's1_prob', 's2_prob', 's3_prob']
    pd1.to_csv("./baseline1_central_agent_1_q_{episode}_{step}_test.csv".format(episode = total_episode, step = total_step))

    pd2 = pd.DataFrame(fix_agent_2)
    pd2.columns = ['s2', 'a2', 's1', 'a1', 's3', 'a3','q', 'label','softmax_label', 'tanh_label', 'percentage','max_index', 's1_prob', 's2_prob','s3_prob']
    pd2.to_csv("./baseline1_central_agent_2_q_{episode}_{step}_test.csv".format(episode = total_episode, step = total_step))

    pd3 = pd.DataFrame(fix_agent_3)
    pd3.columns = ['s3', 'a3', 's1', 'a1', 's2', 'a2', 'q', 'label', 'softmax_label', 'tanh_label', 'percentage', 'max_index','s1_prob', 's2_prob', 's3_prob']
    pd3.to_csv("./baseline1_central_agent_3_q_{episode}_{step}_test.csv".format(episode=total_episode, step=total_step))



    env.destroy()

if __name__ == "__main__":
    env = Maze()
    # Actions are numbered 1..n_actions (inclusive).
    RL = QLearningTable(actions=[a for a in range(1, env.n_actions + 1)])

    # Schedule training 100 ms after the environment's event loop starts.
    # mainloop() presumably blocks until run_maze() calls env.destroy().
    env.after(100, run_maze)
    env.mainloop()

    # One figure per training log, shown one after another.
    training_curves = (
        (episode_total_reward, 'Total reward'),
        (task_completed_step, 'Task Completed Steps in current episode'),
        (mean_episode_reward_list, 'Mean Episode reward'),
    )
    for series, y_label in training_curves:
        plt.plot(np.arange(len(series)), series)
        plt.xlabel('Episode')
        plt.ylabel(y_label)
        plt.show()


