mod mdp;
use std::slice;
use std::ptr;

/// Returns `2 * x`, wrapping on overflow.
///
/// Exported with the C ABI so it can be called from foreign code.
/// `wrapping_mul` gives the same result an optimised build already produced.
/// It also avoids the debug-mode overflow panic, which cannot unwind across
/// an `extern "C"` boundary.
#[no_mangle]
extern "C" fn foo(x: usize) -> usize {
    x.wrapping_mul(2)
}

/// Allocates an MDP built by `mdp::MDP::new_tabular(n_s, n_a)` and hands
/// ownership to the caller.
///
/// The returned pointer is never null. It must be released exactly once
/// with `MDP_free`.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
extern "C" fn MDP_new_tabular(n_s: usize, n_a: usize) -> *mut mdp::MDP {
    Box::into_raw(Box::new(mdp::MDP::new_tabular(n_s, n_a)))
}

#[no_mangle]
unsafe fn MDP_new(n_s:usize, n_a_ptr: *const usize) -> *mut mdp::MDP {

    let n_states  = n_s;
    let n_actions = slice::from_raw_parts(n_a_ptr, n_states).to_vec();
    let boxed_mdp = Box::new(mdp::MDP::new(n_states, n_actions));
    Box::into_raw(boxed_mdp)
}

/// Releases an MDP previously returned by `MDP_new_tabular` or `MDP_new`.
///
/// A null `data` is accepted and ignored, mirroring C's `free(NULL)`.
///
/// # Safety
///
/// `data` must be null or a pointer obtained from one of this module's
/// constructors that has not already been freed. It must not be used after
/// this call.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
unsafe extern "C" fn MDP_free(data: *mut mdp::MDP) {
    if data.is_null() {
        return;
    }
    // SAFETY: `data` came from `Box::into_raw` and, per the caller contract,
    // ownership is handed back exactly once. Dropping the box frees it.
    drop(unsafe { Box::from_raw(data) });
}

/// Forwards `(x, a, mu)` to `mdp::MDP::set_reward`.
///
/// Presumably this sets the reward (mean) for action `a` in state `x`.
///
/// # Safety
///
/// `data` must be a non-null pointer returned by an `MDP_new*` constructor
/// that has not been freed, with no concurrent access from other threads.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
unsafe extern "C" fn MDP_set_reward(data: *mut mdp::MDP, x: usize, a: usize, mu: f64) {
    // SAFETY: the caller guarantees `data` points to a live, unaliased MDP.
    unsafe { (*data).set_reward(x, a, mu) }
}

/// Copies `get_state_number()` values from `ker_ptr` and passes them to
/// `mdp::MDP::set_kernel` for state `x` and action `a`.
///
/// The values are presumably the transition probabilities over next states.
///
/// # Safety
///
/// * `data` must be a non-null pointer returned by an `MDP_new*` constructor,
///   not yet freed, with no concurrent access.
/// * `ker_ptr` must point to at least `get_state_number()` initialised,
///   properly aligned `f64` values.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
unsafe extern "C" fn MDP_set_kernel(data: *mut mdp::MDP, x: usize, a: usize, ker_ptr: *const f64) {
    // SAFETY: the caller guarantees `data` points to a live, unaliased MDP.
    let model = unsafe { &mut *data };
    let n_states = model.get_state_number();
    // SAFETY: the caller guarantees `ker_ptr` covers `n_states` elements.
    let ker_vec = unsafe { slice::from_raw_parts(ker_ptr, n_states) }.to_vec();
    model.set_kernel(x, a, &ker_vec)
}

/// Copies the kernel returned by `mdp::MDP::get_kernel(x, a)` into `dst_ptr`.
///
/// Exactly `get_kernel(x, a).len()` values are written. This is presumably one
/// per state; confirm against `mdp::MDP::get_kernel`.
///
/// # Safety
///
/// * `data` must be a non-null pointer returned by an `MDP_new*` constructor
///   and not yet freed.
/// * `dst_ptr` must be valid for writing that many `f64` values and must not
///   overlap the MDP's own storage.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
unsafe extern "C" fn MDP_get_kernel_into(
    data: *mut mdp::MDP,
    x: usize,
    a: usize,
    dst_ptr: *mut f64,
) {
    // SAFETY: the caller guarantees `data` points to a live MDP.
    let ker = unsafe { (*data).get_kernel(x, a) };
    // SAFETY: the caller guarantees `dst_ptr` can hold `ker.len()` values and
    // does not alias the kernel's storage.
    unsafe { ptr::copy_nonoverlapping(ker.as_ptr(), dst_ptr, ker.len()) }
}

/// Writes `mdp::MDP::get_reward(x, a)` to `*dst_ptr`.
///
/// # Safety
///
/// * `data` must be a non-null pointer returned by an `MDP_new*` constructor
///   and not yet freed.
/// * `dst_ptr` must be non-null, aligned and valid for a single `f64` write.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
unsafe extern "C" fn MDP_get_reward_into(
    data: *mut mdp::MDP,
    x: usize,
    a: usize,
    dst_ptr: *mut f64,
) {
    // SAFETY: the caller guarantees `data` points to a live MDP.
    let rew = unsafe { (*data).get_reward(x, a) };
    // SAFETY: the caller guarantees `dst_ptr` is a valid `f64` destination.
    unsafe { *dst_ptr = rew }
}

/// Calls `mdp::MDP::sample(x, a)` and writes its two results to the caller's
/// buffers.
///
/// The first component (reward) goes to `*rew_ptr`. The second (next-state
/// index) goes to `*sta_ptr`.
///
/// # Safety
///
/// * `data` must be a non-null pointer returned by an `MDP_new*` constructor,
///   not yet freed, with no concurrent access.
/// * `rew_ptr` and `sta_ptr` must be non-null, aligned and valid for a single
///   write each.
#[no_mangle]
#[allow(non_snake_case)] // exported C symbol name, must not change
unsafe extern "C" fn MDP_sample(
    data: *mut mdp::MDP,
    x: usize,
    a: usize,
    rew_ptr: *mut f64,
    sta_ptr: *mut usize,
) {
    // SAFETY: the caller guarantees `data` points to a live, unaliased MDP.
    let (r, y) = unsafe { (*data).sample(x, a) };
    // SAFETY: the caller guarantees both output pointers are valid targets.
    unsafe {
        *rew_ptr = r;
        *sta_ptr = y;
    }
}

#[no_mangle]
unsafe fn MDP_value_iteration(
    data: *const mdp::MDP,
    eps: f64,
    u_data: *mut f64,
    g_data: *mut f64,
    pi_data: *mut usize) {

    let n_states = (*data).get_state_number();
    let u = slice::from_raw_parts(u_data, n_states).to_vec();

    let (u, g, pi) = (*data).value_iteration(eps, Some(u));
    
    let u_ptr  = u.as_ptr();
    let g_ptr  = g.as_ptr();
    let pi_ptr = pi.as_ptr();
    ptr::copy_nonoverlapping( u_ptr,  u_data, n_states);
    ptr::copy_nonoverlapping( g_ptr,  g_data, n_states);
    ptr::copy_nonoverlapping(pi_ptr, pi_data, n_states)
}

#[no_mangle]
unsafe fn MDP_invariant_measure(
    data: *const mdp::MDP,
    pi_data: *const usize,
    eps: f64,
    mu_data: *mut f64) {
    
    let n_states = (*data).get_state_number();
    let pi = slice::from_raw_parts(pi_data, n_states).to_vec();
    let mu = (*data).invariant_measure(pi, eps);
    let mu_ptr = mu.as_ptr();
    ptr::copy_nonoverlapping(mu_ptr, mu_data, n_states)
}
