import numpy as np


def proba_matrix(n, m):
    """Draw a random dense ``n`` x ``m`` row-stochastic matrix.

    Every entry is sampled with ``np.random.random`` and each row is then
    divided by its own total, so each row sums to one.
    """
    raw = np.random.random(size=(n, m))
    row_totals = raw.sum(axis=1, keepdims=True)
    return raw / row_totals


def det_proba_matrix(n, m):
    """Draw a random ``n`` x ``m`` matrix whose rows are one-hot.

    For each row a single column index is drawn with ``np.random.randint(m)``
    and that entry is set to one; all other entries stay zero.
    """
    onehot = np.zeros(shape=(n, m))
    # One scalar draw per row, in row order.
    targets = np.array([np.random.randint(m) for _ in range(n)], dtype=int)
    onehot[np.arange(n), targets] = 1
    return onehot


def block_proba_matrix(n, m):
    """Draw a random ``n`` x ``m`` row-stochastic matrix with block structure.

    The ``m`` columns are cut into ``n`` consecutive, non-empty groups and
    group ``j`` is assigned to row ``j``; every column therefore has exactly
    one non-zero entry. Non-zero entries are drawn with ``np.random.random``
    and each row is normalised to sum to one.
    """
    # n-1 distinct cut points among columns 0..m-2; a column belongs to the
    # row whose index equals the number of cut points strictly before it.
    cuts = np.sort(np.random.choice(m - 1, n - 1, replace=False))
    cols = np.arange(m)
    rows = np.searchsorted(cuts, cols, side="left")

    weights = np.zeros(shape=(n, m))
    # One scalar draw per column, in column order.
    weights[rows, cols] = [np.random.random() for _ in range(m)]
    return weights / np.sum(weights, axis=1, keepdims=True)


def random_dist(n):
    """Draw a random probability vector of length ``n``.

    Entries are sampled with ``np.random.random`` and divided by their total.
    """
    weights = np.random.random(size=(n,))
    total = np.sum(weights)
    return weights / total


def project(v):
    """Return the Euclidean projection of ``v`` onto the probability simplex.

    The input is flattened first, so whatever its shape it is treated as one
    single vector. The result has non-negative entries that sum to one.

    Args:
        v: array-like of numbers (any shape, at least one element).

    Returns:
        1-D numpy array with as many entries as ``v`` has elements.

    Raises:
        ValueError: if ``v`` is empty.
    """
    # Accept lists/tuples as well as arrays; work on a single column.
    v = np.asarray(v).reshape(-1, 1)
    if v.size == 0:
        raise ValueError("cannot project an empty vector onto the simplex")
    z = 1  # total mass the projected vector must sum to
    p, n = v.shape
    # Entries sorted in decreasing order.
    u = np.sort(v, axis=0)[::-1, ...]
    # Excess mass if only the k largest entries were kept.
    pi = np.cumsum(u, axis=0) - z
    ind = np.arange(1, p + 1).reshape(-1, 1)
    mask = (u - pi / ind) > 0
    # Index of the last sorted entry that stays positive after the shift
    # (argmax on the reversed mask finds the last True).
    rho = p - 1 - np.argmax(mask[::-1, ...], axis=0)
    # Uniform shift that makes the clipped vector sum to z.
    theta = pi[rho, np.arange(n)] / (rho + 1)
    w = np.maximum(v - theta, 0)
    return w.flatten()


def sample(dist, argmax=False):
    """Pick an index from the probability vector ``dist``.

    With ``argmax`` set, the index of the largest entry is returned;
    otherwise an index is drawn at random with probabilities ``dist``.
    """
    if argmax:
        return np.argmax(dist)
    n_outcomes = dist.shape[0]
    return np.random.choice(n_outcomes, p=dist)


class POMDP:
    """Randomly generated finite-horizon tabular POMDP.

    The model consists of:
      * ``trans``  - list with one (s_num, s_num) transition matrix per action,
      * ``emiss``  - (s_num, o_num) emission matrix,
      * ``reward`` - list with one (s_num, a_num) reward table per time step,
      * ``mu``     - initial state distribution of length s_num.

    The environment state is kept in ``state`` (current hidden state),
    ``o`` (latest observation) and ``curr_step`` (steps taken this episode).
    """

    def __init__(self, s_num, a_num, o_num, H, det=False, block=False):
        """Sample a random model and start the first episode.

        Args:
            s_num: number of states.
            a_num: number of actions.
            o_num: number of observations.
            H: episode length (number of steps until ``done``).
            det: if true, transition rows and the initial distribution are
                one-hot instead of dense.
            block: if true, the emission matrix has block structure (each
                observation is emitted by exactly one state, which needs
                ``o_num >= s_num``); otherwise it is a dense random
                row-stochastic matrix.
        """
        self.s_space = np.arange(s_num)
        self.a_space = np.arange(a_num)
        self.o_space = np.arange(o_num)
        self.s_num = s_num
        self.a_num = a_num
        self.o_num = o_num
        self.H = H

        make_trans = det_proba_matrix if det else proba_matrix
        self.trans = [make_trans(s_num, s_num) for _ in range(a_num)]

        # The two branches used to be identical, so ``block`` was ignored.
        make_emiss = block_proba_matrix if block else proba_matrix
        self.emiss = make_emiss(s_num, o_num)

        self.reward = [np.random.random(size=(s_num, a_num)) for _ in range(H)]

        if det:
            self.mu = np.zeros(shape=(s_num, ))
            self.mu[np.random.randint(s_num)] = 1
        else:
            self.mu = random_dist(s_num)

        # Draws the initial state and observation and zeroes the step counter.
        self.reset()

    def reset(self):
        """Start a new episode and return ``(state, observation)``."""
        self.state = sample(self.mu)
        self.o = sample(self.emiss[self.state])
        self.curr_step = 0
        return self.state, self.o

    def step(self, a):
        """Apply action ``a`` and return ``(state, observation, reward, done)``.

        The reward is read for the state the action is taken in, before the
        transition. ``done`` becomes true once ``H`` steps have been taken.

        Raises:
            IndexError: if called after the episode has finished without an
                intervening ``reset()``.
        """
        if self.curr_step >= self.H:
            # Same exception type the reward lookup would raise, but explicit.
            raise IndexError(
                "episode is over (horizon H reached); call reset() first"
            )
        r = self.reward[self.curr_step][self.state, a]
        self.curr_step += 1
        self.state = sample(self.trans[a][self.state])
        # Keep the stored observation in sync, as reset() does.
        self.o = sample(self.emiss[self.state])
        done = self.curr_step == self.H
        return self.state, self.o, r, done

    def posterior_o(self, b, o):
        """Condition the belief ``b`` on observation ``o``.

        Args:
            b: 1-D belief over states (length ``s_num``); it is not modified.
            o: index of the observation.

        Returns:
            The normalised posterior belief. If the unnormalised posterior
            mass is below 0.01 the uniform distribution is returned instead.
        """
        weighted = np.asarray(b, dtype=float) * self.emiss[:, o]
        total = weighted.sum()
        if total < 0.01:
            return np.ones(shape=(self.s_num, )) / self.s_num
        return weighted / total

    def posterior_a(self, b, a):
        """Push the belief ``b`` through the transition matrix of action ``a``.

        Returns a flat array of length ``s_num``.
        """
        column = np.asarray(b).reshape(-1, 1)
        return (self.trans[a].T @ column).flatten()


if __name__ == '__main__':
    # Smoke-test the matrix generators, then roll out a fixed action.
    det_proba_matrix(10, 10)
    block_proba_matrix(10, 20)
    p = POMDP(3, 3, 3, 10)
    for _ in range(100):
        transition = p.step(2)
        print(transition)
        # The reward table only has H entries, so stepping past the horizon
        # fails; start a new episode whenever the current one is done.
        if transition[-1]:
            p.reset()
