import numpy as np

def initial_lips_bound(SASR, policy_target, feature_fn, gamma, eta, repeat = 5, discrete_action = False):
    """Compute starting lower/upper bounds on Q for every transition in SASR.

    For transition i the bounds are
        Q_lower[i] = (REW[i] - gamma * eta * D[i]) / (1 - gamma)
        Q_upper[i] = (REW[i] + gamma * eta * D[i]) / (1 - gamma)
    where D[i] is the expected L2 distance between feature_fn(S[i], A[i]) and
    feature_fn(SN[i], a') with a' taken from the target policy.

    Args:
        SASR: tuple (S, A, SN, REW); S, A, SN are (N, dim) arrays, REW is (N,).
        policy_target: continuous case - object exposing ``action_dim`` and
            ``choose_actions(states)``; discrete case - (N, n_actions) array of
            action probabilities for the states in SN.
        feature_fn: callable (states, actions) -> (N, f_dim) feature array.
        gamma: discount factor; must differ from 1.
        eta: multiplier on feature distances (the Lipschitz constant of the bound).
        repeat: number of Monte Carlo action draws (continuous case only).
        discrete_action: if True, enumerate one-hot actions weighted by their
            probabilities instead of sampling.

    Returns:
        (Q_lower, Q_upper), each shaped like REW.

    Raises:
        ValueError: if ``repeat < 1`` in the continuous case.
    """
    S, A, SN, REW = SASR
    F = feature_fn(S, A)  # (N, f_dim)
    # Accumulate distances in a floating dtype even when REW is integer-valued:
    # an integer accumulator (as np.zeros_like(REW) would give) cannot receive
    # the in-place float additions below.
    D = np.zeros_like(REW, dtype=np.result_type(REW.dtype, np.float32))

    if discrete_action:
        # Enumerate every action as a one-hot vector and weight by its probability.
        for j in range(policy_target.shape[-1]):
            AN = np.zeros_like(policy_target)
            AN[:, j] = 1
            FN = feature_fn(SN, AN)
            D += policy_target[:, j] * np.sqrt(np.sum((F - FN) ** 2, axis=-1))
    else:
        if repeat < 1:
            raise ValueError('repeat must be >= 1, got {}'.format(repeat))
        # Continuous actions: Monte Carlo estimate of the expected distance.
        action_dim = policy_target.action_dim
        for _ in range(repeat):
            AN = policy_target.choose_actions(SN).reshape(-1, action_dim)
            FN = feature_fn(SN, AN)
            D += np.sqrt(np.sum((F - FN) ** 2, axis=-1))  # (N)
        D /= repeat

    Q_lower = (REW - gamma * eta * D) / (1 - gamma)
    Q_upper = (REW + gamma * eta * D) / (1 - gamma)
    return Q_lower, Q_upper


def Fitted_update(Q_lower, Q_upper, SASR, policy_target, feature_fn, gamma, eta, subsample_size = -1, double_sample = False, repeat = 5, seed = 43, discrete_action = False):
    """Run one fitted Bellman backup that tightens the Q bounds in place.

    NumPy's global RNG is reseeded with ``seed`` and a random batch of
    transitions is drawn. For each batch row, bounds at the next
    state-action feature x come from a Lipschitz extension over a set of
    anchor transitions k:
        lower(x) = max_k Q_lower[k] - eta * ||x - F[k]||
        upper(x) = min_k Q_upper[k] + eta * ||x - F[k]||
    These are averaged over the target policy's actions. The Bellman targets
    REW + gamma * lower / upper can only raise Q_lower and only lower Q_upper.

    Args:
        Q_lower, Q_upper: (N,) bound arrays; modified in place.
        SASR: tuple (S, A, SN, REW) of (N, dim) arrays and (N,) rewards.
        policy_target: continuous case - object exposing ``action_dim`` and
            ``choose_actions(states)``; discrete case - (N, n_actions) array of
            action probabilities for the states in SN.
        feature_fn: callable (states, actions) -> (rows, f_dim) features.
        gamma: discount factor.
        eta: multiplier on feature distances (Lipschitz constant).
        subsample_size: batch size; -1 or anything larger than N means all N.
        double_sample: if True, the anchors are the batch rows themselves;
            otherwise all N transitions serve as anchors.
        repeat: Monte Carlo action draws per row (continuous case only).
        seed: value passed to ``np.random.seed`` before any sampling.
        discrete_action: enumerate one-hot actions instead of sampling.

    Returns:
        (Q_lower, Q_upper): the same array objects passed in, after updating.

    Note:
        Builds a (batch, K, f_dim) difference tensor, K = number of anchors.
    """
    np.random.seed(seed)
    S, A, SN, REW = SASR
    n_total = REW.shape[0]

    take_all = subsample_size == -1 or subsample_size > n_total
    batch = n_total if take_all else subsample_size
    rows = np.random.permutation(n_total)[:batch]
    anchors = rows if double_sample else np.arange(n_total)

    next_states = SN[rows]
    anchor_feats = feature_fn(S[anchors], A[anchors])

    lower_acc = np.zeros([batch])
    upper_acc = np.zeros([batch])

    # The bounds are only written after the backup, so the anchor values are fixed here.
    lower_anchor = Q_lower[anchors]
    upper_anchor = Q_upper[anchors]

    def _backup(actions):
        # Lipschitz extension of the anchor bounds to (next_states, actions).
        next_feats = feature_fn(next_states, actions)
        diff = anchor_feats[None, :, :] - next_feats[:, None, :]
        gap = eta * np.sqrt(np.sum(diff ** 2, axis = -1))  # (batch, K)
        return np.max(lower_anchor - gap, axis = -1), np.min(upper_anchor + gap, axis = -1)

    sample_actions = discrete_action == False
    if sample_actions:
        action_dim = policy_target.action_dim
        for _ in range(repeat):
            drawn = policy_target.choose_actions(next_states).reshape(-1, action_dim)
            lo, hi = _backup(drawn)
            lower_acc += lo
            upper_acc += hi
        lower_acc /= repeat
        upper_acc /= repeat
    else:
        probs = policy_target[rows]
        for act in range(probs.shape[-1]):
            onehot = np.zeros_like(probs)
            onehot[:, act] = 1
            lo, hi = _backup(onehot)
            weight = probs[:, act]
            lower_acc += weight * lo
            upper_acc += weight * hi

    rew_batch = REW[rows]
    Q_lower[rows] = np.maximum(Q_lower[rows], rew_batch + gamma * lower_acc)
    Q_upper[rows] = np.minimum(Q_upper[rows], rew_batch + gamma * upper_acc)
    return Q_lower, Q_upper

def estimate_rpi(s0, policy_target, feature_fn, S, A, Q_lower, Q_upper, gamma, eta, repeat = 5, seed = 43, discrete_action = False):
    """Bound the (1 - gamma)-scaled value at the initial states s0.

    Per-transition bounds are extended to each (s0, a0) through
        max_k Q_lower[k] - eta * ||F0 - F[k]||  and
        min_k Q_upper[k] + eta * ||F0 - F[k]||,
    where F = feature_fn(S, A) and F0 = feature_fn(s0, a0). They are then
    averaged over s0 and over the target policy's actions.

    Args:
        s0: initial state(s); reshaped to (n0, state_dim).
        policy_target: continuous case - object with ``choose_actions(states)``;
            discrete case - (n0, n_actions) array of action probabilities at s0.
        feature_fn: callable (states, actions) -> (rows, f_dim) features.
        S, A: (N, dim) state/action arrays the bounds are indexed by.
        Q_lower, Q_upper: (N,) per-transition bounds.
        gamma: discount factor; the result is multiplied by (1 - gamma).
        eta: multiplier on feature distances (Lipschitz constant).
        repeat: Monte Carlo action draws (continuous case only).
        seed: value passed to ``np.random.seed`` (global RNG) before sampling.
        discrete_action: enumerate one-hot actions instead of sampling.

    Returns:
        (lower, upper) scalars, each scaled by (1 - gamma).

    Raises:
        ValueError: if ``repeat < 1`` in the continuous case.
    """
    np.random.seed(seed)

    F = feature_fn(S, A)  # (N, f_dim)
    s0 = s0.reshape(-1, S.shape[-1])
    n0 = s0.shape[0]
    Q0_lower = 0
    Q0_upper = 0

    def _extend(a0):
        # Lipschitz extension of the per-transition bounds to the pairs (s0, a0).
        F0 = feature_fn(s0, a0)
        dist = np.sqrt(np.sum((F0[:, None, :] - F[None, :, :]) ** 2, axis=-1))  # (n0, N)
        return np.max(Q_lower - eta * dist, axis=-1), np.min(Q_upper + eta * dist, axis=-1)

    if discrete_action:
        for i in range(policy_target.shape[-1]):
            a0 = np.zeros_like(policy_target)
            a0[:, i] = 1
            lower, upper = _extend(a0)
            Q0_lower += np.mean(policy_target[:, i] * lower)
            Q0_upper += np.mean(policy_target[:, i] * upper)
    else:
        if repeat < 1:
            raise ValueError('repeat must be >= 1, got {}'.format(repeat))
        for _ in range(repeat):
            a0 = policy_target.choose_actions(s0)
            # Match the 2-D (rows, action_dim) layout the other sampling sites
            # enforce: a 1-D result (scalar action per state, or one
            # multi-dim action for a single state) becomes (n0, -1).
            if np.ndim(a0) == 1 and np.size(a0) > 0:
                a0 = np.reshape(a0, (n0, -1))
            lower, upper = _extend(a0)
            Q0_lower += np.mean(lower)
            Q0_upper += np.mean(upper)
        Q0_lower /= repeat
        Q0_upper /= repeat

    return (1 - gamma) * Q0_lower, (1 - gamma) * Q0_upper

def lips_bound_evaluation(s0, replay_buffer, policy_target, feature_fn, gamma, eta, subsample_size = -1, double_sample = False, max_iteration = 50, pi0 = None, discrete_action = False, repeat = 5):
    """Iteratively tighten Lipschitz Q bounds on a replay buffer.

    Starts from ``initial_lips_bound`` and then applies ``max_iteration``
    rounds of ``Fitted_update``, seeding round ``it`` with ``it``. Every 10th
    round (starting at round 0) it prints the bounds at s0 from ``estimate_rpi``.

    Args:
        s0: initial state(s) used for the progress estimate.
        replay_buffer: tuple (S, A, SN, REW) of flat arrays:
            S (N, state_dim), A (N, action_dim), SN (N, state_dim), REW (N,).
        policy_target: continuous case - policy object with ``action_dim`` and
            ``choose_actions``; discrete case - (N, n_actions) probabilities
            for the next states SN.
        feature_fn: callable (states, actions) -> features used for distances,
            e.g. ``lambda S, A: np.hstack([S, A])`` for plain L2.
        gamma: discount factor.
        eta: multiplier on feature distances (Lipschitz constant).
        subsample_size: batch size per fitted update (-1 means all).
        double_sample: forwarded to ``Fitted_update``.
        max_iteration: number of fitted-update rounds.
        pi0: discrete case only - (n0, n_actions) action probabilities at s0.
        discrete_action: enumerate one-hot actions instead of sampling.
        repeat: Monte Carlo action draws per sampling step (continuous case).

    Returns:
        (Q_lower, Q_upper): final (N,) per-transition bounds.

    Raises:
        ValueError: if ``discrete_action`` is set, ``pi0`` is None and at least
            one round (hence one progress estimate) will run.
    """
    S, A, _, _ = replay_buffer

    # In the discrete case policy_target describes the next states SN, so the
    # initial-state estimate needs its own probability table pi0. Fail fast
    # rather than crash inside the first progress estimate.
    if discrete_action and pi0 is None and max_iteration > 0:
        raise ValueError('pi0 is required when discrete_action is True')
    rpi_policy = pi0 if discrete_action else policy_target

    Q_lower, Q_upper = initial_lips_bound(replay_buffer, policy_target, feature_fn, gamma, eta, repeat = repeat, discrete_action = discrete_action)
    for it in range(max_iteration):
        Q_lower, Q_upper = Fitted_update(Q_lower, Q_upper, replay_buffer, policy_target, feature_fn, gamma, eta, subsample_size = subsample_size, double_sample = double_sample, repeat = repeat, seed = it, discrete_action = discrete_action)
        if it % 10 == 0:
            Q0_lower, Q0_upper = estimate_rpi(s0, rpi_policy, feature_fn, S, A, Q_lower, Q_upper, gamma, eta, repeat = repeat, discrete_action = discrete_action)
            print('iter = {}, lower = {}, upper = {}'.format(it, Q0_lower, Q0_upper))
    return Q_lower, Q_upper
