#%% Import Libraries

import numpy as np
import matplotlib.pyplot as plt




#%% Initialise class

class Elo:
    """Simulate Elo-style rating updates for ``n`` players with known true strengths.

    Each player has a *true* rating (``tru_rat``) that fixes the real
    probability of winning, and an *estimated* rating (``R``) that is updated
    after every game with the logistic Elo rule::

        increment  = K * (1 - P_est(winner beats loser))
        R[winner] += increment
        R[loser]  -= increment

    Every update adds to one player exactly what it removes from the other, so
    the sum of ``R`` is unchanged by play (see :meth:`check_zero_sum`).
    """

    def __init__(self, n=2, tru_rat=None, est_rat=None, K=0.005):
        """Create the simulation.

        Parameters
        ----------
        n : int
            Number of players.
        tru_rat : None, "normal", or array-like of length ``n``
            True ratings. ``None`` gives all zeros. ``"normal"`` draws
            standard-normal ratings, centres them to mean zero, flips the sign
            if needed so Player 0 is non-negative, and also uses them as the
            starting estimates (overriding ``est_rat``).
        est_rat : None, scalar, or array-like
            Initial estimated ratings. ``None`` gives zeros. For ``n == 2`` a
            single value ``e`` (scalar or length-1 sequence) gives ``[e, -e]``.
            Otherwise it must have length ``n``. It is copied as floats, so the
            caller's data is never mutated.
        K : float
            K-factor (step size) of the update.

        Raises
        ------
        ValueError
            If the true or estimated ratings do not have length ``n``.
        """
        self.n = n  # number of players
        self.K = K  # K-factor
        # chess uses K = 10, but scales difference by log(10)/400 ~ 0.006

        if tru_rat is None:
            tru_rat = np.zeros(self.n, dtype=float)
        elif isinstance(tru_rat, str) and tru_rat == "normal":
            tru_rat = np.random.standard_normal(n)
            tru_rat -= tru_rat.mean()
            if tru_rat[0] < 0:  # force first player to be positively rated
                tru_rat = -tru_rat
            est_rat = tru_rat  # start the estimates at the truth
        # Copy as float so later arithmetic never truncates or aliases caller data.
        self.tru_rat = np.array(tru_rat, dtype=float)
        if self.tru_rat.shape != (self.n,):
            raise ValueError(
                f"tru_rat must have length n={self.n}, got shape {self.tru_rat.shape}"
            )

        # tru_prob[i, j] = true P(Player i beats Player j) = 1 / (1 + exp(t_j - t_i))
        self.tru_prob = 1. / (
            1 + np.exp(self.tru_rat[np.newaxis, :] - self.tru_rat[:, np.newaxis])
        )

        if est_rat is None:
            self.R = np.zeros(self.n, dtype=float)
        else:
            est = np.atleast_1d(np.array(est_rat, dtype=float))
            if self.n == 2 and est.size == 1:
                # A single value e describes the two-player state [e, -e].
                e = est.ravel()[0]
                est = np.array([e, -e], dtype=float)
            if est.shape != (self.n,):
                raise ValueError(
                    f"est_rat must have length n={self.n}, got shape {est.shape}"
                )
            self.R = est

        # Histories. hist_R / hist_M / hist_err / hist_err_avg start with the
        # initial state and get one post-game entry per recorded game.
        self.hist_R = [self.R[0]]                # Player 0's estimated rating
        self.R_max = np.abs(self.R).max()        # largest |rating| seen so far
        self.hist_M = [self.R_max]               # largest |rating| in R
        self.hist_P = []                         # Player 0's pre-game est. win prob
        self.hist_err = [self._sq_err(self.R)]   # squared error of R vs truth
        self.R_time_avg = self.R.copy()          # running mean of recorded R vectors
        self.hist_err_avg = [self._sq_err(self.R_time_avg)]

    def _sq_err(self, ratings):
        """Return the squared Euclidean distance between ``ratings`` and the true ratings."""
        return np.linalg.norm(ratings - self.tru_rat) ** 2

    def check_zero_sum(self):
        """Print whether the estimated ratings sum to (approximately) zero."""
        if not np.isclose(self.R.sum(), 0):
            print(f"Ratings = {self.R} not (close to) zero sum")
        else:
            print(f"Ratings = {self.R} are (close to) zero sum")

    def prob_fn(self, d):
        """Return the logistic win probability for a rating difference ``d``.

        This is P(P0 beats P1) when R0 - R1 = d; it exceeds 1/2 when d > 0.
        """
        return 1. / (1 + np.exp(-d))

    def get_ratings(self, i=None, PRINT=False):
        """Return all estimated ratings, or only Player ``i``'s if ``i`` is given."""
        if i is None:
            if PRINT:
                print(f"Ratings: {self.R}")
            return self.R
        if PRINT:
            print(f"Rating of Player {i}: {self.R[i]}")
        return self.R[i]

    def get_tru_prob(self, winner, loser, PRINT=False):
        """Return the true probability that ``winner`` beats ``loser``."""
        p = self.tru_prob[winner, loser]
        if PRINT:
            print(f"True probability (Player {winner} beats Player {loser}) = {p:.2f}")
        return p

    def get_est_prob(self, i, j, PRINT=False):
        """Return the estimated probability that Player ``i`` beats Player ``j``."""
        p = self.prob_fn(self.R[i] - self.R[j])
        if PRINT:
            print(f"Estimated probability (Player {i} beats Player {j}) = {p:.2f}")
        return p

    def _pick_players(self, i, j):
        """Return two distinct players, drawing any missing one at random."""
        players = np.arange(self.n)
        if i is None and j is None:
            i = np.random.choice(players)
            j = np.random.choice(players)
        elif i is None:
            # Keep the caller's j; draw i from everyone else.
            i = np.random.choice(players[players != j])
        elif j is None:
            # Keep the caller's i; draw j from everyone else.
            j = np.random.choice(players[players != i])
        if i == j:
            # Collision of two random draws (or an explicit i == j): draw a fresh pair.
            i, j = np.random.choice(players, size=2, replace=False)
        return int(i), int(j)

    def play(self, i=None, j=None, winner=None, record=None, PRINT=False):
        """Play one game between players ``i`` and ``j`` and update the ratings.

        Parameters
        ----------
        i, j : int or None
            The two players. A missing player is drawn uniformly from the
            others. If both are missing, both are drawn at random. If both are
            given but equal, a fresh random pair is drawn.
        winner : int or None
            Must be ``i`` or ``j`` if given. If ``None``, Player ``i`` wins
            with its true probability ``tru_prob[i, j]``.
        record : None or any value
            If not ``None``, append the post-game state to ``hist_R``,
            ``hist_M``, ``hist_err`` and ``hist_err_avg``, and update
            ``R_time_avg``. Unless ``record == "M"``, also append Player 0's
            pre-game estimated win probability to ``hist_P`` when Player 0
            plays.
        PRINT : bool
            Print a trace of the game.

        Raises
        ------
        ValueError
            If ``winner`` is neither ``i`` nor ``j``, or the true win
            probability is not a valid probability.
        """
        i, j = self._pick_players(i, j)

        p_t = self.get_tru_prob(i, j)  # true probability that Player i beats Player j
        if winner is None:
            if not 0.0 <= p_t <= 1.0:
                raise ValueError(f"invalid probability {p_t} for Player {i} vs Player {j}")
            winner = i if np.random.random() < p_t else j  # i wins w.p. p_t
        elif winner not in (i, j):
            raise ValueError(f"winner {winner} is neither Player {i} nor Player {j}")
        loser = j if winner == i else i
        if PRINT:
            print(f"Player {winner} beat Player {loser}")

        # Player 0's estimated win probability going into this game.
        if record is not None and record != "M" and 0 in (i, j):
            opponent = j if i == 0 else i
            self.hist_P.append(self.get_est_prob(0, opponent))

        p_e = self.get_est_prob(winner, loser)  # estimated P(winner beats loser)
        if PRINT:
            print(f"Estimated probability: {p_e:.2f}")
            print(f"Input ratings:  {self.R}")
        increment = self.K * (1 - p_e)
        self.R[winner] += increment     # winner adds
        self.R[loser] -= increment      # loser subtracts

        if record is not None:
            # Recorded after the update so entry k is the state after k recorded games.
            self.hist_R.append(self.R[0])
            self.hist_M.append(np.abs(self.R).max())
            self.hist_err.append(self._sq_err(self.R))
            count = len(self.hist_err)
            self.R_time_avg = self.R_time_avg * ((count - 1) / count) + self.R / count
            self.hist_err_avg.append(self._sq_err(self.R_time_avg))
        if PRINT:
            print(f"Output ratings: {self.R}")
        # Only i and j changed, so they are the only candidates for a new maximum.
        self.R_max = max(self.R_max, abs(self.R[i]), abs(self.R[j]))
