from strategy import Strategy
from game import Game
import numpy as np
from numpy.typing import NDArray
import random
from stochastic_game import GameWithTransitions

class FictitiousPlay(Strategy[None]):
    """Classic fictitious play for a stateless (normal-form) game.

    The strategy records, for every player, the empirical frequency with
    which each action has been observed (``self.x``) and plays a pure best
    response to that empirical profile.
    """

    def __init__(self, g: Game[None], player: int) -> None:
        """
        Args:
            g: the stateless game being played (its state is always ``None``).
            player: index of the player controlled by this strategy.
        """
        self.g = g
        self.player = player
        # Rows of ``self.x`` are padded to the largest action count over all
        # players so the whole profile fits in one matrix.
        self.max_action = max(g.maxAction(None, i) for i in range(g.maxPlayer))
        # x[i, a]: empirical frequency with which player i has played action a.
        self.x = np.zeros((g.maxPlayer, self.max_action))
        # Number of observed rounds; denominator of the running average.
        self.t = 0
        super().__init__()

    def myeye(self, n: int) -> list[float]:
        """Return the one-hot vector of length ``max_action`` selecting action ``n``."""
        distr = [0.] * self.max_action
        distr[n] = 1.
        return distr

    def choice(self, state: None) -> int:
        """Return a best response to the empirical profile of play.

        Each candidate action of this player is evaluated with
        ``g.mixedRewards`` against the other players' empirical mixed
        strategies; ties go to the highest-indexed action.
        """
        expected: list[list[float]] = self.x.tolist()
        # Start at -inf (not a fixed constant) so that games whose rewards
        # are all very negative still yield the true best response.
        best_r = -np.inf
        best_a = 0
        # NOTE(review): iterates over max_action, which can exceed this
        # player's own action count if players have different action sets.
        for a in range(self.max_action):
            expected[self.player] = self.myeye(a)
            r = self.g.mixedRewards(None, expected)[self.player]
            if r >= best_r:
                best_a = a
                best_r = r
        return best_a

    def mixedAction(self, state: None) -> list[float]:
        """Return the best response as a degenerate (one-hot) mixed strategy."""
        action = self.choice(state)
        distr = self.myeye(action)
        return distr

    def getReward(self, state: None, next_state: None, action: int, reward: float) -> None:
        """Ignored: this strategy learns only from ``getInformedReward``."""
        pass

    def getInformedReward(self, state: None, next_state: None, player: int, actions: list[int], reward: float) -> None:
        """Fold one observed joint action into the empirical frequencies.

        Args:
            actions: the action of every player; assumed to be indexed by
                player, matching the rows of ``self.x``.
            state, next_state, player, reward: unused.
        """
        distr = np.array([self.myeye(a) for a in actions])
        t = self.t
        # Incremental running mean: x_{t+1} = (t * x_t + e_t) / (t + 1).
        self.x = t * self.x / (t + 1) + distr / (t + 1)
        # Advance the round counter so x stays the average over all rounds
        # rather than collapsing onto the most recent profile.
        self.t = t + 1

    def description(self) -> str:
        """Short label for this strategy."""
        return "FP"


class SmoothFictitiousPlay(FictitiousPlay):
    """Smooth (logit / softmax) fictitious play.

    Instead of a pure best response, the player mixes over its actions with
    probabilities proportional to ``exp(smness * expected_reward)``.
    ``smness == 0`` gives the uniform distribution; larger values concentrate
    the mass on the best responses.
    """

    def __init__(self, g: Game[None], player: int, smness: float) -> None:
        """
        Args:
            g: the stateless game being played.
            player: index of the player controlled by this strategy.
            smness: inverse temperature of the softmax response.
        """
        super().__init__(g, player)
        self.smness = smness

    def rwd(self) -> NDArray[np.float64]:
        """Expected reward of each own pure action against the empirical profile."""
        expected: list[list[float]] = self.x.tolist()
        r = np.zeros(self.max_action)
        for a in range(self.max_action):
            expected[self.player] = self.myeye(a)
            r[a] = self.g.mixedRewards(None, expected)[self.player]
        return r

    def mixedAction(self, state: None) -> list[float]:
        """Return the softmax (logit) response to the empirical profile."""
        z = self.rwd() * self.smness
        # Shifting by the max leaves the softmax unchanged but keeps exp()
        # from overflowing to inf (which would turn the weights into nan).
        w = np.exp(z - np.max(z))
        return list(w / np.sum(w))

    def choice(self, state: None) -> int:
        """Sample an action from the softmax response."""
        ma = self.mixedAction(state)
        return random.choices(range(self.max_action), weights=ma)[0]

class SFPStochasticGames(Strategy[int]):
    """Smooth fictitious play for two-player stochastic games.

    From observed play the strategy learns:
      * ``tr`` / ``r``: empirical transition frequencies and mean rewards per
        state and joint action,
      * ``x1`` / ``x2``: empirical mixed strategies per state (own / opponent),
      * ``u``: a state-value estimate.
    It plays a softmax response to one-step lookahead values built from them.

    States are used directly as array indices, so they are assumed to be the
    integers ``0 .. len(g.all_states()) - 1`` — TODO confirm for every game.
    """

    def __init__(self, g: GameWithTransitions[int], pl: int, smness: float, disc: float = 0.5) -> None:
        """
        Args:
            g: the stochastic game; must have exactly two players.
            pl: index (0 or 1) of the player controlled by this strategy.
            smness: inverse temperature of the softmax response.
            disc: weight of the next-state value in the lookahead (the
                immediate reward gets ``1 - disc``).

        Raises:
            ValueError: if the game does not have exactly two players.
        """
        super().__init__()
        self.g = g
        if g.maxPlayer != 2:
            raise ValueError(f"SFPStochasticGames needs a 2-player game, got {g.maxPlayer} players")
        states = g.all_states()
        self.max_states = len(states)
        # Action counts are read from one arbitrary state and assumed equal in
        # every state. next(iter(...)) avoids mutating the returned collection
        # (which .pop() would do).
        s0 = next(iter(states))
        self.max_action_p1 = g.maxAction(s0, pl)       # own action count
        self.max_action_p2 = g.maxAction(s0, 1 - pl)   # opponent action count
        # Per-state empirical mixed strategies: x1 own, x2 opponent.
        self.x1 = np.zeros((self.max_states, self.max_action_p1))
        self.x2 = np.zeros((self.max_states, self.max_action_p2))
        # State values, initialised to 20 (presumably an optimistic prior —
        # confirm against the game's reward scale).
        self.u = np.ones(self.max_states) * 20
        # tr[s, a1, a2, s']: empirical frequency of s -> s' under (a1, a2).
        self.tr = np.zeros((self.max_states, self.max_action_p1, self.max_action_p2, self.max_states))
        # r[s, a1, a2]: running mean of observed rewards, initialised to 20.
        self.r = np.ones((self.max_states, self.max_action_p1, self.max_action_p2)) * 20
        # ta[s, a1, a2]: number of observations of (a1, a2) in state s.
        self.ta = np.zeros((self.max_states, self.max_action_p1, self.max_action_p2))
        self.smness = smness
        self.pl = pl
        self.disc = disc
        # Total number of updates; step size 1/(t+1) of the value averaging.
        self.t = 0
        # ts[s]: number of visits to s; step size of the strategy averaging.
        self.ts = np.zeros(self.max_states)

    def myeye(self, n: int, t: int = -1) -> list[float]:
        """Return a one-hot vector selecting ``n``.

        Args:
            n: index to set to 1.
            t: vector length; -1 means ``max_action_p1``.
        """
        if t == -1:
            t = self.max_action_p1
        distr = [0.] * t
        distr[n] = 1.
        return distr

    def _action_values(self, s: int) -> NDArray[np.float64]:
        """One-step lookahead value of each own action in state ``s``.

        Each own action is scored against the opponent's empirical strategy
        in ``s`` as ``(1 - disc) * expected reward + disc * expected u`` of
        the next state under the learned transition model.
        """
        opp_strat = self.x2[s, :]
        # cont[a1, a2] = sum_{s'} tr[s, a1, a2, s'] * u[s']
        cont = self.tr[s] @ self.u
        return (1 - self.disc) * (self.r[s] @ opp_strat) + self.disc * (cont @ opp_strat)

    def mixedAction(self, state: int) -> list[float]:
        """Softmax response to the lookahead action values in ``state``."""
        z = self.smness * self._action_values(state)
        # Max-shift keeps exp() finite without changing the distribution.
        w = np.exp(z - np.max(z))
        return (w / np.sum(w)).tolist()

    def choice(self, state: int) -> int:
        """Sample an own action from the softmax response in ``state``."""
        ma = self.mixedAction(state)
        return random.choices(range(self.max_action_p1), weights=ma)[0]

    def getInformedReward(self, state: int, next_state: int, player: int, actions: list[int], reward: float) -> None:
        """Update model, empirical strategies and values after one step.

        Args:
            state: state in which the joint action was played.
            next_state: state the game moved to.
            player: when 1, the two entries of ``actions`` are swapped so the
                first one indexes the own-action axis (presumably ``player``
                is this strategy's index ``pl`` — confirm with the caller).
            actions: joint action, indexed by player.
            reward: observed reward for the step.
        """
        a1, a2 = (actions[1], actions[0]) if player == 1 else (actions[0], actions[1])

        # Model: running means over the previous observations of (a1, a2).
        n_sa = self.ta[state, a1, a2]
        observed = np.array(self.myeye(next_state, t=self.max_states))
        self.tr[state, a1, a2, :] += (observed - self.tr[state, a1, a2, :]) / (n_sa + 1)
        self.r[state, a1, a2] += (reward - self.r[state, a1, a2]) / (n_sa + 1)
        self.ta[state, a1, a2] += 1

        # Empirical strategies in `state`: running means over prior visits.
        n_s = self.ts[state]
        self.x1[state, :] += (np.array(self.myeye(a1, t=self.max_action_p1)) - self.x1[state, :]) / (n_s + 1)
        self.x2[state, :] += (np.array(self.myeye(a2, t=self.max_action_p2)) - self.x2[state, :]) / (n_s + 1)
        # Count the visit exactly once so the step stays 1/(visits + 1).
        self.ts[state] = n_s + 1

        # Value sweep over all states with step 1/(t+1). u is updated in
        # place, so later states in the sweep see earlier states' new values.
        for s in range(self.max_states):
            f = self._action_values(s)
            self.u[s] += (np.dot(self.x1[s, :], f) - self.u[s]) / (self.t + 1)
        self.t += 1

    def getReward(self, state: int, next_state: int, action: int, reward: float) -> None:
        """Not supported: this strategy needs the full joint action."""
        raise Exception("unreachable")

    def description(self) -> str:
        """Short label for this strategy."""
        return "SFPStochastic"
