"""
This module contains the Q-learning brain of the agents: all action
decisions are made here.

View more on my tutorial page: https://morvanzhou.github.io/tutorials/

Centralized Q-learning: a single learner selects a joint action for three agents.
"""
from collections import defaultdict

import numpy as np
import pandas as pd


class QLearningTable:
    """Tabular, centralized Q-learning over the joint action of three agents.

    The joint state is ``(s0, s1, s2)`` and the joint action ``(a0, a1, a2)``;
    element ``i`` of each belongs to agent ``i``.  Q-values live in three
    nested ``defaultdict`` tables keyed by ``str()`` of every state/action,
    each nesting the agents in a different order:

    * ``q_table``   -- ``[s0][a0][s1][a1][s2][a2]``
    * ``q_table_2`` -- ``[s1][a1][s0][a0][s2][a2]``
    * ``q_table_3`` -- ``[s2][a2][s0][a0][s1][a1]``

    Only ``q_table`` drives action selection and the bootstrap target; all
    three tables receive the same TD target in :meth:`learn`.  Unseen entries
    are lazily initialised to a random float in ``[-0.08, 0)``.

    Note: ``epsilon`` is the probability of acting *greedily* (and
    :meth:`update_epsilon` increases it), the opposite of the usual convention.
    """

    # Agent nesting order of q_table, q_table_2 and q_table_3 respectively.
    _ORDERS = ((0, 1, 2), (1, 0, 2), (2, 0, 1))
    # Nested key levels per table: one state key and one action key per agent.
    _DEPTH = 6

    def __init__(self, actions, learning_rate=0.1, reward_decay=0.9, e_greedy=0.2):
        """Create the learner.

        Args:
            actions: sequence of per-agent actions (shared by all three agents).
            learning_rate: step size (alpha) of the TD update.
            reward_decay: discount factor (gamma).
            e_greedy: initial probability of choosing the greedy joint action.
        """
        self.actions = actions  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = self._new_table()
        self.q_table_2 = self._new_table()
        self.q_table_3 = self._new_table()

    @classmethod
    def _new_table(cls):
        """Return an empty table of ``_DEPTH`` nested ``defaultdict`` levels.

        Reading a missing leaf stores and returns a fresh random scalar float
        drawn uniformly from ``[-0.08, 0)``.
        """
        def make(levels):
            if levels == 0:
                return float(np.random.uniform(-0.08, 0))
            return defaultdict(lambda: make(levels - 1))

        return make(cls._DEPTH)

    @staticmethod
    def _cell(table, order, s, a):
        """Return ``(leaf_dict, action_key)`` addressing Q(s, a) in ``table``.

        ``order`` is the agent nesting order of ``table`` (see ``_ORDERS``).
        Missing intermediate levels are created on the way down; the leaf value
        itself is only created when ``leaf_dict[action_key]`` is read.
        """
        node = table
        for agent in order[:-1]:
            node = node[str(s[agent])][str(a[agent])]
        last = order[-1]
        return node[str(s[last])], str(a[last])

    def _q(self, s, a):
        """Return Q(s, a) from ``q_table``, initialising the entry if unseen."""
        leaf, key = self._cell(self.q_table, self._ORDERS[0], s, a)
        return leaf[key]

    def _joint_actions(self):
        """Yield every joint action ``(a0, a1, a2)`` in nested-loop order."""
        for a0 in self.actions:
            for a1 in self.actions:
                for a2 in self.actions:
                    yield a0, a1, a2

    def update_epsilon(self):
        """Raise the greedy probability ``epsilon`` by 0.1 (no upper clamp)."""
        self.epsilon += 0.1
        print('updated epsilon value after every 200 episodes to value:', self.epsilon)

    def choose_action(self, observation):
        """Pick a joint action ``[a0, a1, a2]`` for the joint ``observation``.

        With probability ``epsilon`` the greedy joint action under ``q_table``
        is returned; ties go to the first maximum in ``self.actions`` order.
        Otherwise each agent's action is drawn uniformly from ``self.actions``.
        """
        print('current epsilon value is:', self.epsilon)

        if np.random.uniform() < self.epsilon:
            # NOTE(review): kept from the original code, which required all
            # actions to be positive; the reason is not visible here -- confirm.
            assert all(act > 0 for act in self.actions)
            greedy = max(
                self._joint_actions(),
                key=lambda joint: self._q(observation, joint),
                default=None,
            )
            if greedy is not None:
                return list(greedy)
            # No actions at all: fall through, np.random.choice raises ValueError.

        # Explore: independent uniform choice per agent.
        return [np.random.choice(self.actions) for _ in range(3)]

    def learn(self, s, a, r, s_):
        """Apply one Q-learning update for the transition ``(s, a, r, s_)``.

        Args:
            s: joint state ``(s0, s1, s2)`` before the step.
            a: joint action ``(a0, a1, a2)`` that was taken.
            r: reward received.
            s_: next joint state, or the string ``'terminal'``.

        The TD target is ``r + gamma * max_a' Q(s_, a')`` computed from
        ``q_table`` (just ``r`` when ``s_`` is terminal).  It is applied to the
        matching entry of all three tables, each relative to its own prediction.
        """
        tables = (self.q_table, self.q_table_2, self.q_table_3)
        cells = [self._cell(t, order, s, a) for t, order in zip(tables, self._ORDERS)]
        # Read each table's own prediction (creating the entry if unseen).
        q_predicts = [leaf[key] for leaf, key in cells]

        # Check the type first so array-like states never hit an elementwise compare.
        if isinstance(s_, str) and s_ == 'terminal':
            q_target = r  # next state is terminal
        else:
            max_next = max(
                (self._q(s_, joint) for joint in self._joint_actions()),
                default=float("-inf"),
            )
            q_target = r + self.gamma * max_next  # next state is not terminal

        for (leaf, key), q_predict in zip(cells, q_predicts):
            leaf[key] += self.lr * (q_target - q_predict)



