
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import transition_data as transition_data
import cvxopt
import cvxpy as cp 


def infinite_get_true_solns(policy, env, gamma):
  """Solve the two discounted linear systems built from the policy's matrices.

  With (mu, rewards, transitions) = env._get_matrices(policy), this solves
    (I - gamma * transitions^T) q_pi = rewards
    (I - gamma * transitions)   d_pi = (1 - gamma) * mu

  Args:
    policy: Passed through unchanged to env._get_matrices.
    env: Object exposing _get_matrices(policy) -> (mu, rewards, transitions).
    gamma: Discount factor multiplying the transition matrix.

  Returns:
    Tuple (d_pi, q_pi, target_rewards), where target_rewards is
    np.dot(d_pi, rewards).
  """
  mu, rewards, transitions = env._get_matrices(policy)
  identity = np.eye(len(mu))

  # Both systems share the same identity; only the orientation of the
  # transition matrix differs between them.
  d_pi = np.linalg.solve(identity - gamma * transitions, (1. - gamma) * mu)
  q_pi = np.linalg.solve(identity - gamma * transitions.transpose(), rewards)

  return d_pi, q_pi, np.dot(d_pi, rewards)
  

def finite_get_true_solns(policy, env, gamma, max_trajectory_length):
  """Finite-horizon counterparts of the discounted linear-system solutions.

  With (mu, rewards, transitions) = env._get_matrices(policy), this computes
    finite_q_pi = sum_{t < T} (gamma * transitions^T)^t  rewards
    finite_d_pi = sum_{t < T} (gamma * transitions)^t  mu / sum_{t < T} gamma^t
  where T = max_trajectory_length.

  The sums are accumulated with a vector recursion (each term is the previous
  term multiplied once by the discounted transition matrix) instead of
  materialising T dense matrix powers, so the cost is one matrix-vector
  product per step and no (T, n, n) intermediate array is allocated.

  Args:
    policy: Passed through unchanged to env._get_matrices.
    env: Object exposing _get_matrices(policy) -> (mu, rewards, transitions).
    gamma: Discount factor.
    max_trajectory_length: Number of terms T in the truncated sums.

  Returns:
    Tuple (finite_d_pi, finite_q_pi, finite_target_rewards), where
    finite_target_rewards is np.dot(finite_d_pi, rewards).
  """
  mu, rewards, transitions = env._get_matrices(policy)

  backward = gamma * transitions.transpose()
  forward = gamma * transitions

  q_term = np.asarray(rewards, dtype=float)
  d_term = np.asarray(mu, dtype=float)
  finite_q_pi = np.zeros_like(q_term)
  d_sum = np.zeros_like(d_term)

  for step in range(max_trajectory_length):
    if step:
      # Advance from the (step - 1)-th to the step-th power applied to the
      # starting vector; step 0 uses the vectors themselves.
      q_term = np.dot(backward, q_term)
      d_term = np.dot(forward, d_term)
    finite_q_pi += q_term
    d_sum += d_term

  normalization = np.sum(np.power(gamma, range(max_trajectory_length)))
  finite_d_pi = d_sum / normalization

  finite_target_rewards = np.dot(finite_d_pi, rewards)

  return finite_d_pi, finite_q_pi, finite_target_rewards


def zeros_divide(a, b):
  """Divide a by b elementwise, returning 0 wherever b == 0.

  The output buffer is always floating point: it keeps a's dtype when that is
  already an inexact (float/complex) type and uses float64 otherwise, so
  integer inputs do not trigger a ufunc casting error. The buffer has the
  broadcast shape of a and b.

  Args:
    a: Numerator (array-like).
    b: Denominator (array-like), broadcastable against a.

  Returns:
    Array of a / b with zeros at positions where b is zero.
  """
  a = np.asarray(a)
  b = np.asarray(b)
  out_dtype = a.dtype if np.issubdtype(a.dtype, np.inexact) else np.float64
  # Entries skipped by `where` keep the zero they were initialised with.
  out = np.zeros(np.broadcast(a, b).shape, dtype=out_dtype)
  return np.divide(a, b, out=out, where=b != 0)

def compute_l2_error(f, true, dist):
    """Relative weighted L2 error of f against true.

    Computes sqrt(sum_i dist_i * (f_i - true_i)^2) / sqrt(sum_i dist_i * true_i^2).
    All three inputs are flattened first, so row vectors, column vectors and
    1-D arrays are treated alike and the result is always a scalar.

    Args:
        f: Estimated values (array-like).
        true: Reference values (array-like), same number of entries as dist.
        dist: Per-entry weights (array-like).

    Returns:
        The weighted L2 norm of (f - true) divided by the weighted L2 norm of
        true.
    """
    weights = np.ravel(dist)
    reference = np.ravel(true)
    # Flatten f as well: leaving it 2-D would broadcast the difference to an
    # n-by-n matrix and yield a vector instead of a scalar error.
    residual = np.ravel(f) - reference
    normalization = np.sqrt(np.dot(weights, np.square(reference)))
    return np.sqrt(np.dot(weights, np.square(residual))) / normalization

def compute_ope_error(f, true):
    """Return the relative absolute error |f - true| / |true|."""
    gap = np.abs(f - true)
    scale = np.abs(true)
    return gap / scale

def generate_mapping(func, indices):
  """Build a one-hot matrix assigning each entry of func to its nearest index.

  Args:
    func: Sequence of values, one per row of the result.
    indices: Array of candidate values, one per column of the result.

  Returns:
    Array of shape (len(func), len(indices)) whose row i holds a single 1 in
    the column j minimising |func[i] - indices[j]| (first minimum on ties)
    and 0 elsewhere.
  """
  mapping = np.zeros([len(func), len(indices)])

  for row, value in enumerate(func):
    nearest = np.argmin(np.abs(value - indices))
    mapping[row, nearest] = 1

  return mapping



def get_index(state, action, num_actions=4):
    """Flatten a (state, action) pair into a single index.

    Args:
        state: State index (scalar or array).
        action: Action index (scalar or array), expected in [0, num_actions).
        num_actions: Number of actions per state; defaults to 4, the value
            that was previously hard-coded.

    Returns:
        state * num_actions + action.
    """
    return state * num_actions + action

def true_linear_system(env, policy, gamma, d_D, q_phi, g_phi):
  """Assemble the d_D-weighted linear system from two feature matrices.

  With (mu, rewards, transitions) = env._get_matrices(policy) and
  next_q_i = transitions^T[i] . q_phi, this accumulates over every index i
  in range(len(mu)):
    lhs += d_D[i] * (outer(g_phi[i], q_phi[i]) - gamma * outer(g_phi[i], next_q_i))
    rhs += d_D[i] * rewards[i] * g_phi[i]

  Args:
    env: Object exposing _get_matrices(policy) -> (mu, rewards, transitions).
    policy: Passed through unchanged to env._get_matrices.
    gamma: Discount factor.
    d_D: Per-index weights.
    q_phi: Feature matrix with one row per index; its column count sets the
      number of columns of lhs.
    g_phi: Feature matrix with one row per index; its column count sets the
      number of rows of lhs and the length of rhs.

  Returns:
    Tuple (lhs, rhs) with shapes (g_phi.shape[1], q_phi.shape[1]) and
    (g_phi.shape[1],).
  """
  mu_pi, rewards_pi, transitions_pi = env._get_matrices(policy)
  num_cols = q_phi.shape[1]
  num_rows = g_phi.shape[1]
  lhs = np.zeros([num_rows, num_cols])
  rhs = np.zeros([num_rows])

  # Row i of the transposed matrix is what gets dotted with q_phi below.
  backward = transitions_pi.transpose()

  for i in range(len(mu_pi)):
    weight = d_D[i]
    test_feat = g_phi[i]
    expected_next_feat = np.dot(backward[i], q_phi)

    current_term = np.outer(test_feat, q_phi[i])
    next_term = gamma * np.outer(test_feat, expected_next_feat)
    lhs += (current_term - next_term) * weight
    rhs += (rewards_pi[i] * test_feat) * weight

  return lhs, rhs


def null_space(A, rcond=None):
    """Return a basis for the null space of A computed from its SVD.

    Args:
        A: Input matrix.
        rcond: Relative threshold; singular values not greater than
            rcond * max(singular values) are treated as zero. When None,
            machine epsilon of the singular values' dtype times
            max(rows, cols) is used.

    Returns:
        Matrix whose columns are the conjugated right singular vectors
        associated with the singular values treated as zero.
    """
    left, sing_vals, right_h = np.linalg.svd(A, full_matrices=True)
    if rcond is None:
        largest_dim = max(left.shape[0], right_h.shape[1])
        rcond = np.finfo(sing_vals.dtype).eps * largest_dim
    cutoff = np.amax(sing_vals) * rcond
    # Number of singular values above the cutoff; the remaining rows of
    # right_h form the null-space basis.
    rank = np.count_nonzero(sing_vals > cutoff)
    return right_h[rank:].conj().T

def cp_optimize(A, b, dist, model):
    """Minimise a quadratic form subject to the equality constraint A x == b.

    The objective is quad_form(x, dist) when model is None and
    quad_form(x - model, dist) otherwise. The problem is solved with the SCS
    solver.

    Args:
        A: Two-dimensional constraint matrix; its second dimension is the
            number of variables.
        b: Right-hand side of the equality constraint.
        dist: Matrix defining the quadratic form.
        model: Optional offset subtracted from x inside the objective.

    Returns:
        The value of x reported by the solver (x.value).
    """
    _, num_vars = A.shape
    x = cp.Variable(num_vars)
    centered = x if model is None else x - model
    problem = cp.Problem(
        cp.Minimize(cp.quad_form(centered, dist)),
        [A @ x == b],
    )
    problem.solve(solver=cp.SCS)
    return x.value

 
def estimate_value(initial_states, target_policy, gamma, q_fn):
  """Estimate the normalised policy value from initial states and a Q table.

  Computes
    (1 - gamma) / len(initial_states) * sum_s sum_a prob(s, a) * q_fn[index(s, a)]
  over every initial state s and every one of the 4 actions a, where
  prob comes from target_policy.get_probability and index from get_index.

  Args:
    initial_states: One-dimensional sequence of initial state indices.
    target_policy: Object exposing get_probability(states, actions).
    gamma: Discount factor.
    q_fn: Array of Q-values indexed by get_index(state, action).

  Returns:
    The scalar value estimate.
  """
  num_actions = 4
  num_initial = len(initial_states)
  # Repeat each state once per action while cycling through the actions, so
  # every (state, action) pair occurs exactly once. Tiling both sequences
  # only pairs them correctly when len(initial_states) is coprime with
  # num_actions; for even lengths some pairs are duplicated and others missed.
  states = np.repeat(initial_states, num_actions)
  actions = np.tile(np.arange(num_actions), num_initial)
  inp = np.concatenate((states[:, None], actions[:, None]), axis=-1)
  probs = target_policy.get_probability(inp[:, 0], inp[:, 1])
  qvalues = q_fn[get_index(inp[:, 0], inp[:, 1])] * probs
  return (1 - gamma) * np.sum(qvalues) / num_initial

def get_wf(q_pi, d_D, policy, env, gamma, nu, model=0.):
  """Solve (I - gamma * transitions) x = nu * (q_pi - model) and divide by d_D.

  Args:
    q_pi: Vector of values entering the right-hand side.
    d_D: Vector the solution is divided by elementwise.
    policy: Passed through unchanged to env._get_matrices.
    env: Object exposing _get_matrices(policy) -> (mu, rewards, transitions).
    gamma: Discount factor.
    nu: Elementwise multiplier applied to (q_pi - model).
    model: Offset subtracted from q_pi; defaults to 0.

  Returns:
    x / d_D elementwise, with 0 wherever d_D is 0 (via zeros_divide).
  """
  mu, _, transitions = env._get_matrices(policy)
  wmatrix = np.eye(len(mu)) - gamma * transitions

  # Solve the system directly rather than forming the explicit inverse; this
  # is cheaper and numerically more stable.
  d_D_wf_star = np.linalg.solve(wmatrix, nu * (q_pi - model))

  return zeros_divide(d_D_wf_star, d_D)

def get_d_s(dist, grid_length, num_actions=4):
  """Sum consecutive groups of num_actions entries of dist.

  Args:
    dist: One-dimensional array laid out in consecutive groups of
      num_actions entries.
    grid_length: The result has grid_length ** 2 entries.
    num_actions: Size of each group; defaults to 4, the value that was
      previously hard-coded.

  Returns:
    Array of length grid_length ** 2 whose i-th entry is
    sum(dist[i * num_actions:(i + 1) * num_actions]).
  """
  num_groups = grid_length ** 2
  group_starts = range(0, num_groups * num_actions, num_actions)
  d_s = np.array([np.sum(dist[start:start + num_actions]) for start in group_starts])
  return d_s
