;; gorilla-repl.fileformat = 1

;; @@

(ns model
  (:require [gorilla-plot.core :as plot])
  (:use [anglican core emit runtime stat
          [state :only [get-predicts get-log-weight get-result]]]))

(defdist geometric
  "Geometric distribution on support {0,1,2....}: the number of failed
  flips (each succeeding with probability p) before the first success."
  [p] []
  (sample* [this]
    ;; Count failed flips until the first success.
    ;; NOTE: this loop does not terminate when p is 0.
    (loop [failures 0]
      (if (sample* (flip p))
        failures
        (recur (inc failures)))))
  (observe* [this value]
    (cond
      ;; Negative or non-integral values lie outside the support, so their
      ;; log-probability is -Infinity rather than a finite number.
      (or (neg? value)
          (not (== value (Math/rint (double value)))))
      Double/NEGATIVE_INFINITY

      ;; With zero failures the (1 - p) factor drops out entirely. Handling it
      ;; separately avoids 0 * log(0) = NaN when p is exactly 1.
      (zero? value)
      (log p)

      ;; log(p * (1 - p)^value)
      :else
      (+ (log p) (* value (log (- 1 p)))))))

(defdist dirac
  "Point mass at x: sampling always yields x, and observing any other value
  has log-probability -Infinity."
  [x]
  (sample* [this] x)
  (observe* [this value]
    ;; Log-probability is 0 (probability 1) for x itself, -Infinity otherwise.
    (if-not (= value x)
      Double/NEGATIVE_INFINITY
      0)))


;; Anglican query over two integer counts, `a` and `b`; the query result is the
;; final value of `b`.
;;
;; After an initial Poisson draw for each count, the same update is applied four
;; times:
;;   1. b gains a binomial(a, 0.1) draw,
;;   2. a is replaced by a binomial(a, 0.23724) draw,
;;   3. b is replaced by a binomial(b, 0.2636) draw,
;;   4. an independent Poisson draw is added to each count,
;;   5. binomial(a, 0.2) and binomial(b, 0.2) are conditioned on observed values.
;;
;; The parameters `method-` and `options-` are not referenced in the query body.
;; `dirac` and `geometric` are made available as primitives but are not called
;; in this query.
(with-primitive-procedures [dirac geometric]
  (defquery model [method- options- ]
    ;; The initial [0 0] binding is shadowed immediately by the next binding.
    (let [[ a b] [ 0 0 ]
          [ a b]
          (let [
                ;; initial counts
                a (sample (poisson 46.260000000000005))
                b (sample (poisson 5.140000000000001))
                ;; update 1, conditioned on observations 35 and 3
                b (+ b (sample (binomial a 0.1)))
                a (sample (binomial a 0.23724))
                b (sample (binomial b 0.2636))
                a (+ a (sample (poisson 209.34)))
                b (+ b (sample (poisson 23.26)))
                _unused (observe (binomial a 0.2) 35)
                _unused (observe (binomial b 0.2) 3)
                ;; update 2, conditioned on observations 83 and 6
                b (+ b (sample (binomial a 0.1)))
                a (sample (binomial a 0.23724))
                b (sample (binomial b 0.2636))
                a (+ a (sample (poisson 378.72)))
                b (+ b (sample (poisson 42.080000000000005)))
                _unused (observe (binomial a 0.2) 83)
                _unused (observe (binomial b 0.2) 6)
                ;; update 3, conditioned on observations 78 and 10
                b (+ b (sample (binomial a 0.1)))
                a (sample (binomial a 0.23724))
                b (sample (binomial b 0.2636))
                a (+ a (sample (poisson 270.72)))
                b (+ b (sample (poisson 30.080000000000002)))
                _unused (observe (binomial a 0.2) 78)
                _unused (observe (binomial b 0.2) 10)
                ;; update 4, conditioned on observations 58 and 4
                b (+ b (sample (binomial a 0.1)))
                a (sample (binomial a 0.23724))
                b (sample (binomial b 0.2636))
                a (+ a (sample (poisson 270.72)))
                ;; NOTE(review): every other Poisson pair in this query has an
                ;; a-rate : b-rate ratio of 9 : 1; 270.72 : 8.56 does not.
                ;; Confirm that 8.56 is intended.
                b (+ b (sample (poisson 8.56)))
                _unused (observe (binomial a 0.2) 58)
                _unused (observe (binomial b 0.2) 4)
               ]
            [ a b ]
          )
         ]
    ;; query result: the final value of b
    b
    )
  )

)


;; Name stored in every result record, and the file the records are written to.
(def model_name "two_populations2000")
(def outfile "output/two_populations2000_anglican.json")

;; Inference configurations to run. Each entry is a pair
;; [method option-list]; the option list is spliced into the `doquery` call as
;; keyword arguments.
;; For a quick single-method run use instead:
;;   (def configurations [[:rmh []]])
(def configurations
  [[:importance []]
   [:lmh []]
   [:rmh []]
   [:rmh [:alpha 0.8 :sigma 2]]
   [:smc []]
   [:smc [:number-of-particles 1000]]
   [:pgibbs []]
   [:pgibbs [:number-of-particles 1000]]
   [:ipmcmc []]
   [:ipmcmc [:number-of-particles 1000 :number-of-nodes 32]]])

;; Sample counts to run each configuration with.
;; For a quick run use instead:
;;   (def num_samples_options [1000])
(def num_samples_options [1000 10000])

;; Keep every `thinning`-th sample (1 keeps all of them).
(def thinning 1)

;; Number of independent runs per (sample count, configuration) pair.
(def num-chains 20)

;; `spit` does not create missing directories, so make sure the parent
;; directory of the output file exists before starting the file.
(clojure.java.io/make-parents outfile)

;; Start (or truncate) the output file with the opening bracket of a JSON array.
(spit outfile "[\n" :append false)

;; Benchmark driver: for every sample count, inference configuration and chain
;; number, run the query, print summary statistics, and append one JSON object
;; describing the run to `outfile`.
;;
;; Records are appended as soon as each run finishes. The ",\n" separator is
;; written *before* every record except the first, so the array never ends in a
;; trailing comma (which JSON does not allow).
(doseq [[run-index [num_samples method options chain]]
        (map-indexed vector
                     (for [num_samples num_samples_options
                           [method options] configurations
                           chain (range 0 num-chains)]
                       [num_samples method options chain]))]
  (println (format "\nMethod %s with %s samples and options %s" method num_samples options))
  (println (format "Chain no. %s" chain))
  (let [start (System/nanoTime)
        ;; Discard the first fifth of the requested sample count as warm-up.
        ;; `quot` keeps the count an integer even when num_samples is not a
        ;; multiple of 5.
        warmup (quot num_samples 5)
        ;; Lazy pipeline: skip warm-up, then keep every `thinning`-th sample
        ;; until num_samples samples have been kept.
        samples (->> (apply doquery method model [method options] options)
                     (drop warmup)
                     (take (* num_samples thinning))
                     (take-nth thinning))
        results (collect-results samples)
        ;; Largest sampled value; bounds the range of probability masses below.
        max-value (apply max (map get-result samples))
        mean (empirical-mean results)
        variance (empirical-variance results)
        std (empirical-std results)
        ;; With zero spread the standardized moments are reported as NaN rather
        ;; than computed.
        ;; NOTE: NaN is printed verbatim into the record; strict JSON parsers
        ;; reject it.
        skewness (if (zero? std) Double/NaN (empirical-skew results))
        kurtosis (if (zero? std) Double/NaN (empirical-kurtosis results))
        distribution (empirical-distribution results)
        ;; Mass for every integer 0..max-value; values never sampled get 0.0.
        masses (for [n (range 0 (inc max-value))] (get distribution n 0.0))
        end (System/nanoTime)
        elapsed_ms (/ (- end start) 1e6)]
    (println (format "Elapsed time: %s ms" elapsed_ms))
    (println (format "Empirical mean: %s" mean))
    (println (format "Empirical variance: %s" variance))
    (println (format "Empirical std: %s" std))
    (println (format "Empirical skewness: %s" skewness))
    (println (format "Empirical kurtosis: %s" kurtosis))
    (spit outfile
          (str (when (pos? run-index) ",\n")
               (format
                 "{\"model\": \"%s\", \"system\": \"anglican\", \"method\": \"%s\", \"options\": \"%s\", \"num_samples\": %s, \"time_ms\": %s, \"total\": 1.0, \"mean\": %s, \"variance\": %s, \"stddev\": %s, \"skewness\": %s, \"kurtosis\": %s, \"masses\": [%s] }"
                 model_name method options num_samples elapsed_ms mean variance std skewness kurtosis
                 (clojure.string/join ", " masses)))
          :append true)
    ;; Debugging aids (disabled):
    ; (doseq [n (range 0 (inc max-value))]
    ;   (println (format "p(%s) = %s" n (get distribution n 0.0))))
    ; (println "List of samples (format: sample log-weight):")
    ; (doall (map (fn [s] (println (format "%s %s" (get-result s) (get-log-weight s)))) samples))
    ; values need to be adjusted if they are weighted!
    ; (plot/histogram (map get-result samples) :normalize :probability)
    ))

;; Terminate the last record's line and close the JSON array.
(spit outfile "\n]\n" :append true)




;; @@
