from BaseAgent import BaseAgent
import numpy as np

class RegQ(BaseAgent):
    """Q-learning agent with a linear value function and a weight-decay term.

    Each update moves the weight vector along ``td_error * phi`` and
    simultaneously shrinks it by ``eta * weights`` (see ``update_weight``).

    The feature table (``self.features``), its size (``self.num_features``)
    and the helpers ``action_value`` / ``greedy_policy`` are not defined in
    this class; they are expected to be provided by ``BaseAgent``.
    """

    def __init__(self, env, config):
        """Store hyper-parameters and initialise the weight vector.

        Args:
            env: Environment object. Its ``env_name`` attribute selects the
                weight initializer (see ``init_weight``).
            config: Mapping that must contain the keys ``"gamma"``,
                ``"alpha"`` and ``"eta"``.
        """
        super().__init__(env)

        self.env = env
        self.gamma = config["gamma"]  # multiplies the bootstrapped next value
        self.alpha = config["alpha"]  # step size of the weight update
        self.eta = config["eta"]  # coefficient of the weight-decay term
        self.updates = 0  # number of calls to ``update`` so far
        self.weights = None

        self.init_weight(self.env.env_name)

    def primal_weight(self):
        """Return the current weight vector."""
        return self.weights

    def dual_weight(self):
        """Return the current weight vector.

        This agent keeps a single weight vector, so this returns the same
        object as ``primal_weight``.
        """
        return self.weights

    def init_weight(self, weight_initializer):
        """Initialise ``self.weights`` using the named scheme.

        Args:
            weight_initializer: ``"Baird"`` or ``"ThetaTwoTheta"``.

        Raises:
            KeyError: If ``weight_initializer`` is not a known scheme.
        """
        initializers = {
            "Baird": self.baird_weight,
            "ThetaTwoTheta": self.theta_two_theta,
        }
        try:
            initializer = initializers[weight_initializer]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the caller which names are accepted.
            raise KeyError(
                f"Unknown weight initializer {weight_initializer!r}; "
                f"expected one of {sorted(initializers)}"
            ) from None
        initializer()

    def theta_two_theta(self):
        """Set every weight to 1."""
        self.weights = np.ones(self.num_features)

    def baird_weight(self):
        """Set every weight to 1, except index ``env.SEVENTH_STATE`` set to 10."""
        self.weights = np.ones(self.num_features)
        self.weights[self.env.SEVENTH_STATE] = 10

    def td_error(self, state, action, next_state, reward, done_mask):
        """Return the TD error of one transition under the current weights.

        Computes ``reward + done_mask * gamma * Q(next_state, a') - Q(state,
        action)``, where ``a'`` is the action returned by ``greedy_policy``
        for ``next_state``.

        Args:
            state: Current state.
            action: Action taken in ``state``.
            next_state: State reached after taking ``action``.
            reward: Reward received for the transition.
            done_mask: Multiplier applied to the bootstrapped term
                (presumably 0 on terminal transitions and 1 otherwise;
                confirm against the caller).
        """
        q_sa_theta = self.action_value(state, action, self.weights)
        next_action = self.greedy_policy(next_state, self.weights)
        next_q_sa_theta = self.action_value(
            next_state, next_action, self.weights
        )

        return reward + done_mask * self.gamma * next_q_sa_theta - q_sa_theta

    def update_weight(self, state, action, next_state, reward, done_mask):
        """Apply one regularized update step to ``self.weights``.

        The new weights are
        ``weights + alpha * (td_error * phi - eta * weights)`` with
        ``phi = self.features[action, state]``. A new array is assigned to
        ``self.weights``; the previous array is not modified in place.
        """
        td_error = self.td_error(state, action, next_state, reward, done_mask)
        phi = self.features[action, state]

        gradient = td_error * phi
        self.weights = self.weights + self.alpha * (
            gradient - self.eta * self.weights
        )

    def update(self, state, next_state, action, reward, done_mask):
        """Perform one learning step and increment the update counter.

        Note that the parameter order here is ``(state, next_state, action,
        ...)``, which differs from ``update_weight``'s ``(state, action,
        next_state, ...)``; the arguments are reordered when forwarded.
        """
        self.update_weight(state, action, next_state, reward, done_mask)

        self.updates += 1