use rand::prelude::*;
use rand_distr::{Binomial, Distribution, WeightedAliasIndex};

/// A finite Markov decision process whose states may each offer a different number of actions.
///
/// Transition kernels and reward means are stored densely per `(state, action)` pair, and the
/// random samplers used for simulation are cached alongside them. A cache slot is `None` until
/// the corresponding sampler is first needed and is reset to `None` whenever the underlying
/// kernel or reward is overwritten.
#[derive(Debug)]
pub struct MDP {
    /// Number of states; states are indexed `0..n_states`.
    n_states: usize,
    /// `n_actions[x]` is the number of actions available in state `x`.
    n_actions: Vec<usize>,
    /// `rewards[x][a]` is the reward mean of action `a` in state `x`.
    rewards: Vec<Vec<f64>>,
    /// `kernels[x][a][y]` is the weight of moving to state `y` after playing `a` in `x`.
    kernels: Vec<Vec<Vec<f64>>>,
    /// Lazily built reward samplers, one slot per `(x, a)`.
    r_samplers: Vec<Vec<Option<Binomial>>>,
    /// Lazily built next-state samplers, one slot per `(x, a)`.
    k_samplers: Vec<Vec<Option<WeightedAliasIndex<f64>>>>,
}

impl MDP {
    /// Default cap on the number of sweeps performed by [`MDP::value_iteration`].
    const MAX_VI_ITERATIONS: usize = 10_000;

    /// Creates an MDP with `n_states` states and the same number of actions, `n_actions`,
    /// in every state. See [`MDP::new`] for the initial rewards and kernels.
    pub fn new_tabular(n_states: usize, n_actions: usize) -> Self {
        MDP::new(n_states, vec![n_actions; n_states])
    }

    /// Creates an MDP with `n_states` states where state `x` offers `n_actions_vec[x]` actions.
    ///
    /// Every reward starts at `0.0`. For each state `x < n_states`, every action's kernel puts
    /// all of its mass on state `0`. Samplers are not built here; [`MDP::sample`] builds them
    /// on first use.
    ///
    /// # Panics
    ///
    /// Panics if `n_actions_vec` has fewer than `n_states` entries.
    pub fn new(n_states: usize, n_actions_vec: Vec<usize>) -> Self {
        let rewards: Vec<Vec<f64>> = n_actions_vec
            .iter()
            .map(|&n_actions| vec![0.0; n_actions])
            .collect();

        let mut kernels: Vec<Vec<Vec<f64>>> = n_actions_vec
            .iter()
            .map(|&n_actions| vec![vec![0.0; n_states]; n_actions])
            .collect();
        // Default kernel: deterministic jump to state 0. Slicing `..n_states` panics (as the
        // original per-index loop did) when fewer than `n_states` action counts are supplied.
        for rows in &mut kernels[..n_states] {
            for row in rows.iter_mut() {
                row[0] = 1.0;
            }
        }

        // Samplers are built lazily; `(0..k).map(|_| None)` avoids requiring `Clone`.
        let r_samplers: Vec<Vec<Option<Binomial>>> = n_actions_vec
            .iter()
            .map(|&n_actions| (0..n_actions).map(|_| None).collect())
            .collect();
        let k_samplers: Vec<Vec<Option<WeightedAliasIndex<f64>>>> = n_actions_vec
            .iter()
            .map(|&n_actions| (0..n_actions).map(|_| None).collect())
            .collect();

        MDP {
            n_states,
            n_actions: n_actions_vec,
            rewards,
            kernels,
            r_samplers,
            k_samplers,
        }
    }

    // Setters

    /// Overwrites the transition kernel of action `a` in state `x`.
    ///
    /// Only the first `n_states` entries of `kernel` are read. The cached next-state sampler
    /// for `(x, a)` is discarded so that the next [`MDP::sample`] call rebuilds it.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `a` is out of range, or if `kernel` has fewer than `n_states` entries.
    pub fn set_kernel(&mut self, x: usize, a: usize, kernel: &[f64]) {
        self.k_samplers[x][a] = None;
        let n_states = self.n_states;
        self.kernels[x][a].copy_from_slice(&kernel[..n_states]);
    }

    /// Sets the reward mean of action `a` in state `x` and discards its cached reward sampler.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `a` is out of range.
    pub fn set_reward(&mut self, x: usize, a: usize, mu: f64) {
        self.r_samplers[x][a] = None;
        self.rewards[x][a] = mu;
    }

    // Getters

    /// Returns the number of states.
    pub fn get_state_number(&self) -> usize {
        self.n_states
    }

    /// Returns a copy of the transition kernel of action `a` in state `x`.
    pub fn get_kernel(&self, x: usize, a: usize) -> Vec<f64> {
        self.kernels[x][a].clone()
    }

    /// Returns the reward mean of action `a` in state `x`.
    pub fn get_reward(&self, x: usize, a: usize) -> f64 {
        self.rewards[x][a]
    }

    /// Returns the cached `Binomial(1, mu)` reward sampler for `(x, a)`, building it on first use.
    ///
    /// # Panics
    ///
    /// Panics if the stored reward mean is rejected by `Binomial::new` (not a probability).
    fn get_reward_sampler(&mut self, x: usize, a: usize) -> &Binomial {
        let mu = self.rewards[x][a];
        self.r_samplers[x][a].get_or_insert_with(|| {
            Binomial::new(1, mu)
                .expect("reward mean must lie in [0, 1] to be sampled as a Bernoulli variable")
        })
    }

    /// Returns the cached alias-method sampler over next states for `(x, a)`, building it on
    /// first use from the stored kernel weights.
    ///
    /// # Panics
    ///
    /// Panics if the kernel weights are rejected by `WeightedAliasIndex::new`.
    fn get_kernel_sampler(&mut self, x: usize, a: usize) -> &WeightedAliasIndex<f64> {
        let weights = &self.kernels[x][a];
        self.k_samplers[x][a].get_or_insert_with(|| {
            WeightedAliasIndex::new(weights.clone())
                .expect("transition kernel must be a valid non-negative, non-zero weight vector")
        })
    }

    // Sampling

    /// Simulates one step from state `x` under action `a`.
    ///
    /// Returns `(r, y)`. `r` is `0.0` or `1.0`, drawn from `Binomial(1, mu)` with `mu` the reward
    /// mean of `(x, a)`. `y` is the next state, drawn with the kernel of `(x, a)` as weights.
    ///
    /// # Panics
    ///
    /// Panics if the reward mean or the kernel cannot be turned into a sampler.
    pub fn sample(&mut self, x: usize, a: usize) -> (f64, usize) {
        let mut rng = thread_rng();
        let r = self.get_reward_sampler(x, a).sample(&mut rng) as f64;
        let y = self.get_kernel_sampler(x, a).sample(&mut rng);
        (r, y)
    }

    // Value Iteration

    /// Runs relative value iteration for at most [`MDP::MAX_VI_ITERATIONS`] sweeps.
    ///
    /// Equivalent to [`MDP::value_iteration_with_max_iter`] with that cap.
    pub fn value_iteration(&self, eps: f64, u: Option<Vec<f64>>) -> (Vec<f64>, Vec<f64>, Vec<usize>) {
        self.value_iteration_with_max_iter(eps, u, Self::MAX_VI_ITERATIONS)
    }

    /// Runs relative value iteration on the model with every backup averaged with the current
    /// value (an aperiodicity transformation with self-loop probability 1/2).
    ///
    /// Each sweep computes, for every state `x` and action `a`:
    /// `q(x, a) = 0.5 * (r(x, a) + u(x) + sum_y p(y | x, a) * u(y))`.
    /// It then sets `v(x) = max_a q(x, a)`, breaking ties towards the lowest action index; a
    /// state with no actions keeps `v(x) = u(x)`. The stopping rule is the span test
    /// `max(v - u) - min(v - u) < eps`, or `max_iter` sweeps.
    ///
    /// `u` is the starting value vector; it defaults to all zeros.
    ///
    /// Returns `(u, g, pi)`, all from the last sweep performed:
    /// - `u`: relative values, renormalised so that `u[0] == 0`;
    /// - `g`: per-state gain estimate `2 * (v - u)`. The factor 2 undoes the 1/2 averaging;
    /// - `pi`: greedy action per state. A state with no actions reports `0`.
    ///
    /// If `n_states == 0` or `max_iter == 0`, the inputs are returned unchanged with zero
    /// gains and action `0` everywhere.
    ///
    /// # Panics
    ///
    /// Panics if `u` is given with a length different from `n_states`.
    pub fn value_iteration_with_max_iter(
        &self,
        eps: f64,
        u: Option<Vec<f64>>,
        max_iter: usize,
    ) -> (Vec<f64>, Vec<f64>, Vec<usize>) {
        let n_states = self.n_states;
        let mut u = u.unwrap_or_else(|| vec![0.0; n_states]);
        assert_eq!(
            u.len(),
            n_states,
            "initial value vector must have one entry per state"
        );

        let mut g = vec![0.0; n_states];
        let mut pi_opt = vec![0usize; n_states];
        if n_states == 0 {
            // Nothing to iterate over; also avoids indexing `v[0]` and taking max/min of nothing.
            return (u, g, pi_opt);
        }

        for _ in 0..max_iter {
            let mut v = u.clone();
            for x in 0..n_states {
                for a in 0..self.n_actions[x] {
                    let expected_next: f64 = self.kernels[x][a]
                        .iter()
                        .zip(&u)
                        .map(|(p, u_y)| p * u_y)
                        .sum();
                    let q_xa = 0.5 * (self.rewards[x][a] + u[x] + expected_next);
                    if a == 0 || q_xa > v[x] {
                        v[x] = q_xa;
                        pi_opt[x] = a;
                    }
                }
            }

            let du: Vec<f64> = v.iter().zip(&u).map(|(v_x, u_x)| v_x - u_x).collect();
            // `du` is non-empty because `n_states > 0` was checked above.
            let max_du = du.iter().copied().max_by(f64::total_cmp).expect("n_states > 0");
            let min_du = du.iter().copied().min_by(f64::total_cmp).expect("n_states > 0");
            let span = max_du - min_du;

            for x in 0..n_states {
                g[x] = 2.0 * du[x];
                // Anchor state 0 at zero so the relative values stay bounded.
                u[x] = v[x] - v[0];
            }

            if span < eps {
                break;
            }
        }

        (u, g, pi_opt)
    }

    /// Estimates, for every state `x_0`, the long-run fraction of time the chain induced by
    /// the deterministic policy `pi` spends in `x_0`.
    ///
    /// For each `x_0`, this builds the single-action chain with kernels `kernels[x][pi[x]]`.
    /// It gives that chain reward `1` in `x_0` and `0` elsewhere, runs [`MDP::value_iteration`]
    /// with tolerance `eps`, and records the gain at `x_0`. The gain of such an indicator reward
    /// equals the stationary probability of `x_0`. This presumably assumes the induced chain is
    /// unichain; that is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if `pi` has fewer than `n_states` entries or names an unavailable action.
    pub fn invariant_measure(&self, pi: Vec<usize>, eps: f64) -> Vec<f64> {
        let n_states = self.n_states;

        // `new_tabular` starts every reward at 0.0, so only the target state needs toggling.
        let mut pi_model = MDP::new_tabular(n_states, 1);
        for x in 0..n_states {
            pi_model.set_kernel(x, 0, &self.kernels[x][pi[x]]);
        }

        let mut measure = Vec::with_capacity(n_states);
        for x_0 in 0..n_states {
            pi_model.set_reward(x_0, 0, 1.0);
            let (_, g, _) = pi_model.value_iteration(eps, None);
            pi_model.set_reward(x_0, 0, 0.0);
            measure.push(g[x_0]);
        }
        measure
    }
}
