'''functions for policy optimization'''

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

from jax import value_and_grad
from jax import jacobian

import numpy as np

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

from sklearn.metrics import mean_squared_error

import data
import utils
import estimation


def subgradient_descent(num_iter=10, step_size=0.1, test_W_samples=None, 
                       W_samples=None, Z_samples=None, Y_samples=None, method='direct',
                       h=1, S=10, fit_outcome_degree=2, mu_degree=2, bootstrap=True,
                       right_node_num = 100):
    '''Optimise the logistic policy parameters (phi, b) by subgradient descent.

    Estimator sets are fitted once by ``gen_bootstrap_estimators``. Element 0
    is fitted on the full sample; any further elements are fitted on
    subsamples. Each iteration then does the following for every estimator set:
      1. evaluate mu_0 / mu_1 on ``test_W_samples`` (gen_mu_values),
      2. solve the matching problem for the current policy (compute_x_opt),
      3. differentiate ``objective_value`` w.r.t. phi and b at that x.
    When more than one estimator set exists, each per-set quantity v is
    combined as ``v_0 - (1/h) * (v_0 - mean(v_1, ..., v_S))``. Then
    (phi, b) take a step of size ``step_size`` against the combined gradient.

    Parameters
    ----------
    num_iter : number of descent iterations.
    step_size : fixed step size.
    test_W_samples : samples on which the policy and mu estimates are evaluated.
    W_samples, Z_samples, Y_samples : training data for the estimators.
    method : 'direct', 'WDM' or 'GRDR'.
    h, S, bootstrap : forwarded to ``gen_bootstrap_estimators``. ``h`` is
        also used in the combination formula above.
    fit_outcome_degree, mu_degree : model degrees forwarded to the estimators.
    right_node_num : number of right-hand nodes in the matching problem.

    Returns
    -------
    dict with keys
      'phi_values' : num_iter + 1 entries (random initial value first),
      'b_values'   : num_iter + 1 entries,
      'pol_values' : num_iter combined matching costs from compute_x_opt.

    Raises
    ------
    ValueError : if ``method`` is not one of 'direct', 'WDM', 'GRDR'.
    '''
    # Fail fast: an unknown method would otherwise yield empty estimator
    # lists and crash later with an unrelated IndexError.
    if method not in ('direct', 'WDM', 'GRDR'):
        raise ValueError("method must be 'direct', 'WDM' or 'GRDR', got %r" % (method,))

    fitted = gen_bootstrap_estimators(h=h, S=S,
                   W_samples=W_samples, Z_samples=Z_samples, Y_samples=Y_samples, 
                   outcome_degree=fit_outcome_degree, mu_degree=mu_degree, 
                   method=method, bootstrap=bootstrap)

    if method == 'direct':
        estimators_0 = fitted
        estimators_1 = prop_models = None
    else:
        estimators_0, estimators_1, prop_models = fitted

    pol_values = []
    phi_values = []
    b_values = []

    # Initialize random model coefficients (fixed key -> reproducible start).
    key = random.PRNGKey(1)
    key, phi_key, b_key = random.split(key, 3)
    phi = random.normal(phi_key, ())
    b = random.normal(b_key, ())

    phi_values.append(phi)
    b_values.append(b)

    # A single grad transform returns both partial derivatives in one pass.
    grad_phi_b = grad(objective_value, argnums=(0, 1))

    print('initial phi, b: ', phi, b)

    for iters in range(num_iter):

        print('begin iter: ', iters)

        phi_grad_list = []
        b_grad_list = []
        pol_value_list = []

        for i, estimator in enumerate(estimators_0):

            # gen_mu_values expects a different estimator bundle per method.
            if method == 'direct':
                bootstrap_estimators = estimator
            elif method == 'WDM':
                bootstrap_estimators = [estimator, estimators_1[i]]
            else:  # 'GRDR'
                bootstrap_estimators = [estimator, estimators_1[i], prop_models[i]]

            mu_1_list, mu_0_list = gen_mu_values(test_W_samples=test_W_samples, method=method,
                                                 fit_outcome_degree=fit_outcome_degree, mu_degree=mu_degree,
                                                 bootstrap_estimators=bootstrap_estimators)

            # compute OPT x for the current policy
            x_opt, pol_value = compute_x_opt(phi=phi, b=b, test_W_samples=test_W_samples, mu_0_list=mu_0_list,
                                             mu_1_list=mu_1_list, right_node_num=right_node_num)

            # gradient of the objective at the fixed x_opt
            phi_grad, b_grad = grad_phi_b(phi, b, x_array=x_opt, test_W_samples=test_W_samples,
                                          mu_0_list=mu_0_list, mu_1_list=mu_1_list)

            phi_grad_list.append(phi_grad)
            b_grad_list.append(b_grad)
            pol_value_list.append(pol_value)

        if len(phi_grad_list) == 1:

            # No subsample fits: use the full-sample quantities directly.
            phi_grad = phi_grad_list[0]
            b_grad = b_grad_list[0]
            pol_value = pol_value_list[0]

        else:

            # Combine full-sample (index 0) and subsample (1..S) quantities.
            phi_grad = phi_grad_list[0] - 1/float(h) * (phi_grad_list[0] - np.mean(phi_grad_list[1:]))
            b_grad = b_grad_list[0] - 1/float(h) * (b_grad_list[0] - np.mean(b_grad_list[1:]))
            pol_value = pol_value_list[0] - 1/float(h) * (pol_value_list[0] - np.mean(pol_value_list[1:]))

        print('policy value: ', pol_value, 'Iter: ', iters)

        pol_values.append(pol_value)

        # descent
        phi = phi - step_size * phi_grad
        b = b - step_size * b_grad

        print('this iter phi, b: ', phi, b)

        phi_values.append(phi)
        b_values.append(b)

    results_dict = {'phi_values': phi_values, 'b_values': b_values,
                    'pol_values': pol_values}

    print('all pol_values: ', pol_values)

    return results_dict



def compute_x_opt(phi=None, b=None, test_W_samples=None, mu_0_list=None,
                 mu_1_list=None, right_node_num=10):
    '''Solve the matching problem for the current policy.

    Each test sample i is given the cost
        c_i = (1 - pi_1(phi, b, w_i)) * mu_0_list[i] + pi_1(phi, b, w_i) * mu_1_list[i].
    The cost vector is tiled into a (n_test, right_node_num) matrix, so every
    right-hand node sees the same cost for sample i. The matrix is handed to
    ``utils.compute_matching``.

    Returns
    -------
    x_opt : numpy array of length ``test_W_samples.shape[0]``. It holds 1 at
        every row index returned by the matching and 0 elsewhere.
    min_cost : float, the cost returned by ``utils.compute_matching``.
    '''
    # Evaluate the policy once per sample and build the array in one pass.
    # Repeated jnp.append copies the whole array on each call (O(n^2)).
    # ravel reproduces jnp.append's flattening if pi_1 returns a non-scalar.
    pi_1_values = jnp.ravel(jnp.array([pi_1(phi, b, test_w_sample)
                                       for test_w_sample in test_W_samples]))
    pi_0_values = 1 - pi_1_values

    test_sample_costs = jnp.add(jnp.multiply(pi_0_values, mu_0_list),
                                jnp.multiply(pi_1_values, mu_1_list))

    cost_matrix = jnp.transpose(jnp.tile(test_sample_costs, (right_node_num, 1)))

    min_cost, row_ind, _ = utils.compute_matching(cost_matrix)

    x_opt = np.zeros(test_W_samples.shape[0],)
    # dtype=int keeps an empty row_ind usable as an index array.
    x_opt[np.asarray(row_ind, dtype=int)] = 1

    return x_opt, float(min_cost)
    
    

def gen_bootstrap_estimators(h=1, S=10,
                       W_samples=None, Z_samples=None, Y_samples=None, 
                       outcome_degree=2, mu_degree=2, 
                       method='WDM', bootstrap=True):
    '''Fit the mu-hat estimators on the full sample and on S random subsamples.

    Parameters
    ----------
    h : controls the subsample size, int(N / (1 + h)**2), where N is
        ``Z_samples.shape[0]``.
    S : number of subsample draws (only used when ``bootstrap`` is truthy).
    W_samples, Z_samples, Y_samples : training data, indexed along axis 0
        by sample (they are sliced with an index array below).
    outcome_degree : degree forwarded to ``estimation.fit_outcome`` ('direct').
    mu_degree : degree forwarded to the WDM / GRDR mu fitters.
    method : 'direct', 'WDM' or 'GRDR'.
    bootstrap : if truthy, also fit on S subsamples. Each subsample is drawn
        WITHOUT replacement via ``np.random.choice`` (global NumPy RNG).

    Returns
    -------
    'direct'        : list of outcome models.
    'WDM' / 'GRDR'  : (mu_0 models, mu_1 models, propensity models).
    In every list, element 0 is fitted on the full sample and the remaining
    elements on subsamples. For 'WDM' the propensity-model list is returned
    empty, because only GRDR needs the propensity models downstream.

    Raises
    ------
    ValueError : if ``method`` is not one of 'direct', 'WDM', 'GRDR'.
    '''
    if method not in ('direct', 'WDM', 'GRDR'):
        raise ValueError("method must be 'direct', 'WDM' or 'GRDR', got %r" % (method,))

    num_samples = Z_samples.shape[0]
    draw_size = int(num_samples/((1+h)**2))

    estimators_0_list = []   # outcome models ('direct') or mu_0 models
    estimators_1_list = []   # mu_1 models (WDM / GRDR only)
    prop_model_list = []     # propensity models (GRDR only)

    def fit_and_store(W, Z, Y):
        # Fit one estimator set on (W, Z, Y) and append it to the lists above.
        if method == 'direct':
            estimators_0_list.append(estimation.fit_outcome(W, Z, Y,
                                                            outcome_degree=outcome_degree))
            return

        prop_model = estimation.fit_prop_score(W_samples=W, Z_samples=Z)
        if method == 'WDM':
            fit_mu = estimation.fit_weighted_direct_mu
        else:  # 'GRDR'
            fit_mu = estimation.fit_robust_mu
        mu_0, mu_1 = fit_mu(Z_samples=Z, W_samples=W, Y_samples=Y,
                            mu_degree=mu_degree, prop_model=prop_model)

        estimators_0_list.append(mu_0)
        estimators_1_list.append(mu_1)
        if method == 'GRDR':
            prop_model_list.append(prop_model)

    # Compute mu hat zero (full-sample fit).
    fit_and_store(W_samples, Z_samples, Y_samples)

    # Subsample refits.
    if bootstrap:
        for _ in range(S):
            indices = np.random.choice(num_samples, draw_size, replace=False)
            fit_and_store(W_samples[indices], Z_samples[indices], Y_samples[indices])

    if method == 'direct':
        return estimators_0_list
    return estimators_0_list, estimators_1_list, prop_model_list



def objective_value(phi, b, x_array=None, test_W_samples=None, 
                    mu_0_list=None, mu_1_list=None):
    '''Objective for policy optimization at fixed x and fixed estimators.

    Computes sum_i x_array[i] * c_i, where
        c_i = (1 - pi_1(phi, b, w_i)) * mu_0_list[i] + pi_1(phi, b, w_i) * mu_1_list[i].
    The body uses only jax.numpy operations, so it can be differentiated
    with ``jax.grad`` w.r.t. ``phi`` and ``b``.

    Returns
    -------
    The scalar result of ``jnp.dot(costs, x_array)``.
    '''
    # Evaluate the policy once per sample and build the array in one pass.
    # The original appended with jnp.append inside the loop, which copies
    # the array every time. ravel mirrors jnp.append's flattening if pi_1
    # returns a non-scalar.
    pi_1_values = jnp.ravel(jnp.array([pi_1(phi, b, test_w_sample)
                                       for test_w_sample in test_W_samples]))
    pi_0_values = 1 - pi_1_values

    test_sample_costs = jnp.add(jnp.multiply(pi_0_values, mu_0_list),
                                jnp.multiply(pi_1_values, mu_1_list))

    obj_value = jnp.dot(test_sample_costs, x_array)

    return obj_value


def gen_mu_values(test_W_samples=None, method='WDM',
               fit_outcome_degree=2, mu_degree=2,
              bootstrap_estimators=None):
    '''Evaluate the fitted mu_0 / mu_1 estimators on every test sample.

    Parameters
    ----------
    test_W_samples : iterable of test samples.
    method : 'direct', 'WDM' or 'GRDR'.
    fit_outcome_degree : degree passed to ``estimation.direct_mu`` ('direct').
    mu_degree : degree passed to the WDM / GRDR evaluators.
    bootstrap_estimators :
        'direct' : a single outcome model,
        'WDM'    : [mu_0_model, mu_1_model],
        'GRDR'   : [mu_0_model, mu_1_model, prop_model].

    Returns
    -------
    (mu_1_values, mu_0_values) as numpy arrays. Note the order: mu_1 first.

    Raises
    ------
    ValueError : if ``method`` is not one of 'direct', 'WDM', 'GRDR'.
    '''
    if method not in ('direct', 'WDM', 'GRDR'):
        raise ValueError("method must be 'direct', 'WDM' or 'GRDR', got %r" % (method,))

    mu_1_list = []
    mu_0_list = []

    for test_w_sample in test_W_samples:

        if method == 'direct':

            mu_1_list.append(estimation.direct_mu(Z=1, W=test_w_sample, theta_reg_model=bootstrap_estimators,
                                                  degree=fit_outcome_degree))
            mu_0_list.append(estimation.direct_mu(Z=0, W=test_w_sample, theta_reg_model=bootstrap_estimators,
                                                  degree=fit_outcome_degree))

        elif method == 'WDM':

            mu_0_list.append(estimation.weighted_direct_mu(W=test_w_sample, Z=0,
                                                           mu_0_model=bootstrap_estimators[0], mu_degree=mu_degree))
            mu_1_list.append(estimation.weighted_direct_mu(W=test_w_sample, Z=1,
                                                           mu_1_model=bootstrap_estimators[1], mu_degree=mu_degree))

        else:  # 'GRDR'

            mu_0_list.append(estimation.robust_mu(W=test_w_sample, Z=0, mu_0_model=bootstrap_estimators[0],
                                                  mu_degree=mu_degree, prop_model=bootstrap_estimators[2]))
            # Fix: mu_1 is evaluated for treatment Z=1 (previously passed Z=0).
            mu_1_list.append(estimation.robust_mu(W=test_w_sample, Z=1, mu_1_model=bootstrap_estimators[1],
                                                  mu_degree=mu_degree, prop_model=bootstrap_estimators[2]))

    return np.array(mu_1_list), np.array(mu_0_list)


def sigmoid(x):
    '''Logistic function, via the identity sigmoid(x) = (tanh(x / 2) + 1) / 2.'''
    half = x / 2
    return 0.5 * (jnp.tanh(half) + 1)


def pi_1(phi, b, inputs):
    '''Logistic policy: return sigmoid(inputs . phi + b).

    In this file the result weights mu_1, and ``1 - pi_1`` weights mu_0.
    '''
    score = jnp.dot(inputs, phi) + b
    return sigmoid(score)





