
import ctypes as ct

################################################################################

# Lookup table from the type names accepted in signature specifications to
# the ctypes type handed to restype/argtypes.
REGISTER_CONVERT_CTYPE = {
    "void": None,            # None is what ctypes uses for "no return value"
    "int": ct.c_int,
    "usize": ct.c_size_t,
    "f32": ct.c_float,
    "f64": ct.c_double,
    "pointer": ct.c_void_p,  # every pointer spelling is mapped to this entry
}

def register_ctype_is_pointer(str_type):
    """Return True when *str_type* spells a pointer type.

    A type counts as a pointer when it is exactly ``"pointer"`` or when its
    last non-whitespace character is ``*`` (for instance ``"f64 *"`` or
    ``"usize*"``).  Empty or whitespace-only input is not a pointer and
    yields False instead of raising IndexError.
    """
    if str_type == "pointer":
        return True
    # strip() drops surrounding whitespace, so a blank string simply fails
    # the endswith test rather than indexing into an empty token list.
    return str_type.strip().endswith("*")

def register_convert_ctype(str_type):
    """Translate a type name given as text into its ctypes counterpart.

    Runs of whitespace are collapsed first.  Any spelling recognised by
    register_ctype_is_pointer resolves to the "pointer" entry of
    REGISTER_CONVERT_CTYPE; every other name is looked up directly, so an
    unknown name raises KeyError.
    """
    normalized = " ".join(str_type.split())
    key = "pointer" if register_ctype_is_pointer(normalized) else normalized
    return REGISTER_CONVERT_CTYPE[key]

def register_rfun(libcfun, specification):
    """ Declare the C signature of a foreign function.
    - libcfun: the function object whose restype and argtypes attributes
      are overwritten.
    - specification: a comma-separated string of type names of format:
        return_type, var1_type, var2_type, ...
    """
    # split(",") always yields at least one item, so the first one is the
    # return type and whatever follows (possibly nothing) are the arguments.
    return_name, *argument_names = specification.split(",")
    libcfun.restype = register_convert_ctype(return_name)
    libcfun.argtypes = list(map(register_convert_ctype, argument_names))

################################################################################
# Initializing MDP library

# NOTE: the path is relative, so it is resolved against the current working
# directory of the process, not against the location of this file.
libmdp = ct.cdll.LoadLibrary("../rlib/target/debug/libmdp.so")

# (exported symbol, "return_type, arg1_type, ...") pairs, registered in the
# order listed.
_LIBMDP_SIGNATURES = (
    ("foo", "usize, usize"),
    ("MDP_new_tabular", "pointer, usize, usize"),
    ("MDP_new", "pointer, usize, usize *"),
    ("MDP_free", "void, pointer"),

    ("MDP_set_reward", "void, pointer, usize, usize, f64"),
    ("MDP_set_kernel", "void, pointer, usize, usize, f64 *"),
    ("MDP_get_kernel_into", "void, pointer, usize, usize, f64 *"),
    ("MDP_get_reward_into", "void, pointer, usize, usize, f64 *"),

    ("MDP_sample", "void, pointer, usize, usize, f64 *, usize *"),
    ("MDP_value_iteration", "void, pointer, f64, f64 *, f64 *, usize *"),
    ("MDP_invariant_measure", "void, pointer, usize *, f64, f64 *"),
)

for _symbol, _signature in _LIBMDP_SIGNATURES:
    register_rfun(getattr(libmdp, _symbol), _signature)

if __name__ == "__main__":
    # Small round-trip demo: create a model, write one kernel row and one
    # reward, read the kernel row back, and print it.

    # The two sizes handed to MDP_new_tabular (presumably the number of
    # states and actions -- confirm against the library).
    S, A = 5, 2
    model_data = libmdp.MDP_new_tabular(S, A)

    # Release the model even if one of the calls below raises.
    try:
        # S equal entries of 1/S, copied into a C double array.
        kernel = [1.0 / S for _ in range(S)]
        kernel_data = (ct.c_double * S)(*kernel)
        libmdp.MDP_set_kernel(model_data, 0, 0, kernel_data)

        libmdp.MDP_set_reward(model_data, 0, 0, 0.33)

        # Fresh zero-initialised buffer for the library to fill in.
        kernel_data = (ct.c_double * S)()
        libmdp.MDP_get_kernel_into(model_data, 0, 0, kernel_data)
        kernel = list(kernel_data)
        print("Here is the kernel:", kernel)
    finally:
        libmdp.MDP_free(model_data)

