import numpy as np
import itertools
from pomdp_env import sample, POMDP


class AsymmetricQ:
    """Tabular asymmetric Q-learning agent for a finite-horizon POMDP.

    For every step ``h`` in ``0..H`` the agent keeps:

    * ``U[h]``: action values keyed by a truncated observation-action history
      ``(o, a, ..., o)`` holding at most ``Z`` past ``(o, a)`` pairs. The
      acting policy reads these.
    * ``Q[h]``: action values keyed by that history with the latent state
      appended, ``(o, a, ..., o, s)``. These are used for bootstrapping
      targets and by the model-based routines.
    * ``pi[h]``: a per-history action distribution, used by ``update_Q`` and
      updated by ``update_pi``.

    Each entry is a length-``a_num`` numpy array. ``run_traj`` only writes
    levels ``0..H-1``, so level ``H`` stays zero and acts as the terminal
    value.
    """

    def __init__(self, p, Z, alpha, epsilon, tau=1.0):
        """Allocate zero-initialised tables over every truncated history.

        Args:
            p: Environment. It must expose ``H``, ``o_num``, ``a_num``,
                ``s_num``, ``reset()`` and ``step(a)``. ``update_Q``
                additionally reads ``p.reward``, and ``update_pi`` calls
                ``p.belief``.
            Z: Maximum number of past (observation, action) pairs kept in a
                history key.
            alpha: Learning rate of the tabular updates in ``run_traj``.
            epsilon: Stored as ``self.epsilon``.
                NOTE(review): ``epsilon_greedy`` uses its own decaying
                schedule and does not read this value.
            tau: Step size of the multiplicative-weights update in
                ``update_pi``.
        """
        self.Q = []
        self.U = []
        self.pi = []
        self.as_space = []
        self.s_space = []
        self.p = p
        self.Z = Z
        self.alpha = alpha
        self.epsilon = epsilon
        self.tau = tau
        for h in range(p.H + 1):
            # History layout: min(h, Z) (o, a) pairs followed by the current o.
            hist_shape = [p.o_num, p.a_num] * min(h, Z) + [p.o_num]
            s_space = list(itertools.product(*(range(n) for n in hist_shape)))
            # Same history with the latent state appended as the last element.
            as_space = list(
                itertools.product(*(range(n) for n in hist_shape + [p.s_num]))
            )
            self.as_space.append(as_space)
            self.s_space.append(s_space)
            self.Q.append({key: np.zeros(p.a_num) for key in as_space})
            self.U.append({key: np.zeros(p.a_num) for key in s_space})
            # Start from the uniform policy. The multiplicative update in
            # update_pi needs strictly positive initial weights.
            self.pi.append(
                {key: np.full(p.a_num, 1.0 / p.a_num) for key in s_space}
            )
        # Counts of observed (next_state, next_obs) pairs for each
        # (state, action). update_Q fills these.
        self.est = {
            (state, a): np.zeros((p.s_num, p.o_num))
            for state, a in itertools.product(range(p.s_num), range(p.a_num))
        }
        # Q snapshots; update_pi reads the latest entry. The initial entry
        # aliases self.Q, so in-place updates to Q are visible here.
        self.Q_list = [self.Q]

    def learn(self, K, *, n_eval=100, verbose=True):
        """Train for ``K`` episodes, evaluating the greedy policy after each.

        Args:
            K: Number of training episodes. The episode index is passed to
                ``run_traj`` to drive the exploration schedule.
            n_eval: Number of greedy episodes averaged per evaluation.
            verbose: Print each evaluation result when True.

        Returns:
            A list of ``K`` average returns, one per training episode.
        """
        r_sum_list = []
        for k in range(K):
            self.run_traj(k)
            r_sum_list.append(self.evaluate(n_eval))
            if verbose:
                print(r_sum_list[-1])
        return r_sum_list

    def update_Q(self, traj):
        """Back up ``Q`` using empirical transitions and the current ``pi``.

        Args:
            traj: Flat trajectory in the format returned by ``run_traj``:
                ``[s0, o0, a0, r0, s1, o1, a1, r1, ..., sH, oH]``.

        NOTE(review): only levels ``H-2 .. 0`` are recomputed. Level ``H-1``
        keeps whatever ``run_traj`` learned. This mirrors ``update_pi``.
        Confirm whether that is intended.
        """
        H = self.p.H
        for h in range(H):
            state, a = traj[4 * h], traj[4 * h + 2]
            next_state, next_o = traj[4 * h + 4], traj[4 * h + 5]
            self.est[(state, a)][next_state, next_o] += 1

        # est is fixed during the backup, so normalise each count table once.
        # Unvisited (state, action) pairs fall back to a uniform model.
        uniform = np.ones((self.p.s_num, self.p.o_num)) / (self.p.s_num * self.p.o_num)
        emp_tran = {}
        for key, counts in self.est.items():
            total = np.sum(counts)
            emp_tran[key] = uniform if total < 0.5 else counts / total

        for h in range(H - 2, -1, -1):
            next_pi, next_Q = self.pi[h + 1], self.Q[h + 1]
            for s in self.as_space[h]:
                hist, state = s[:-1], s[-1]
                for a in range(self.p.a_num):
                    tran = emp_tran[(state, a)]
                    expectation = 0
                    for state_prime, o_prime in itertools.product(
                        range(self.p.s_num), range(self.p.o_num)
                    ):
                        prob = tran[state_prime, o_prime]
                        if prob < 0.01:  # ignore negligible transitions
                            continue
                        new_hist = self.truncate(hist + (a, o_prime))
                        new_s = new_hist + (state_prime,)
                        for a_prime in range(self.p.a_num):
                            expectation += (
                                prob * next_pi[new_hist][a_prime] * next_Q[new_s][a_prime]
                            )
                    self.Q[h][s][a] = self.p.reward[state, a] + expectation

    def truncate(self, his):
        """Return the last ``2 * Z + 1`` entries of ``his`` as a tuple.

        For an alternating ``[o, a, o, a, ..., o]`` history, this keeps ``Z``
        (o, a) pairs plus the current observation. Shorter histories are
        returned whole.
        """
        max_len = 2 * self.Z + 1
        if len(his) > max_len:
            return tuple(his[-max_len:])
        return tuple(his)

    def update_pi(self):
        """Apply a multiplicative-weights step to ``pi`` using belief-weighted Q.

        For each history ``s`` at levels ``H-2 .. 0``, each action weight is
        scaled by ``exp(tau * sum_state belief * Q / (h + 1))`` and then
        renormalised.
        """
        Q = self.Q_list[-1]
        for h in range(self.p.H - 2, -1, -1):
            # The last argument to belief is True while the history at step h
            # has not been truncated yet (length 2h+1 <= 2Z+1).
            untruncated = h <= self.Z
            for s in self.s_space[h]:
                for a in range(self.p.a_num):
                    belief_weighted_q = 0
                    for state in range(self.p.s_num):
                        belief_weighted_q += (
                            self.p.belief(s, state, untruncated) * Q[h][s + (state,)][a]
                        )
                    self.pi[h][s][a] *= np.exp(self.tau * belief_weighted_q / (h + 1))
                self.pi[h][s] /= self.pi[h][s].sum()

    def run_traj(self, epi, update=True):
        """Roll out one episode, optionally updating ``U`` and ``Q`` online.

        Args:
            epi: Episode index, used by the epsilon-greedy schedule.
            update: When True, act epsilon-greedily and apply the learning
                updates. When False, act greedily on ``U`` without learning.

        Returns:
            ``(extended_hist, r_sum)``:
                ``extended_hist`` is ``[s0, o0, a0, r0, s1, o1, ...]``.
                ``r_sum`` is the episode return.
        """
        state, o = self.p.reset()
        his = [o]
        r_sum = 0
        extended_hist = [state, o]
        for h in range(self.p.H):
            old_his = self.truncate(his)
            if update:
                a = self.epsilon_greedy(self.U[h][old_his], epi)
            else:
                a = sample(self.U[h][old_his], argmax=True)
            prev_state = state
            state, o, r, _ = self.p.step(a)
            his += [a, o]
            if update:
                new_his = self.truncate(his)
                # Asymmetric target: the next action is chosen greedily from
                # the history-based U. It is valued with the state-aware Q.
                next_best = np.argmax(self.U[h + 1][new_his])
                target = r + self.Q[h + 1][new_his + (state,)][next_best]
                self.U[h][old_his][a] = (
                    (1 - self.alpha) * self.U[h][old_his][a] + self.alpha * target
                )
                q_key = old_his + (prev_state,)
                self.Q[h][q_key][a] = (
                    (1 - self.alpha) * self.Q[h][q_key][a] + self.alpha * target
                )
            extended_hist += [a, r, state, o]
            r_sum += r
        return extended_hist, r_sum

    def epsilon_greedy(self, dist, epi):
        """Sample a uniformly random action with a decaying probability.

        The exploration rate is ``(H + 1) / (H + epi)``. It is at least 1 for
        ``epi <= 1``, so the first two episodes act uniformly at random.
        Otherwise the action is chosen by ``sample(dist, argmax=True)``.
        NOTE(review): ``self.epsilon`` is not consulted here.
        """
        epsilon = (self.p.H + 1) / (self.p.H + epi)
        if np.random.random() < epsilon:
            return sample(np.ones(shape=(self.p.a_num,)) / self.p.a_num)
        return sample(dist, argmax=True)

    def evaluate(self, K):
        """Return the mean return of ``K`` greedy, non-learning episodes."""
        r_sum = 0
        for _ in range(K):
            r_sum += self.run_traj(0, update=False)[1]
        return r_sum / K
