(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

(defmacro bound? [x]
  "Expand to a form that is True when evaluating `x` raises no NameError, else False.
  Any other exception raised while evaluating `x` propagates unchanged."
  `(try ~x
        True
        (except [NameError] False)))

(defmacro default [x d]
  "Expand to `x`'s value, or to `d` when evaluating `x` raises NameError.

  `x` is evaluated exactly once; `d` is evaluated only in the fallback case.
  Other exceptions raised by `x` propagate unchanged."
  ;; A single try/except avoids evaluating `x` a second time after a
  ;; successful `bound?` check.
  `(try ~x
        (except [NameError] ~d)))

(import [sklearn.preprocessing [normalize]])

(defn partial-flatten [x]
  "Reshape `x` to 2-D, keeping its first axis and merging all remaining axes into one."
  (setv leading (get (np.shape x) 0))
  (np.reshape x (, leading -1)))

(defn mnist-data [[train-set "vanilla"] [test-set "vanilla"] [conv False] [data-dir "../../mnist_c"]]
  "Load train/test images and labels from saved .npy files.

  Reads train_images.npy and train_labels.npy from data-dir/train-set, and
  test_images.npy and test_labels.npy from data-dir/test-set (data-dir
  defaults to ../../mnist_c). Images are squeezed and given a new axis at
  position 3. They are then flattened with partial-flatten unless `conv` is
  true, and divided by float32 255. Labels are returned exactly as loaded.

  Returns (train-images test-images train-labels test-labels)."
  (defn load-array [subdir filename]
    ;; Load one .npy array; the `with` form closes the file before returning.
    (with [f (open (os.path.join data-dir subdir filename) "rb")]
      (np.load f)))
  (defn prepare-images [images]
    ;; NOTE(review): assumes the squeezed images are (n, h, w), so axis 3 is
    ;; a trailing channel axis - confirm against the saved arrays.
    (setv images (np.expand-dims (np.squeeze images) 3))
    (/ (if conv images (partial-flatten images)) (np.float32 255)))
  (, (prepare-images (load-array train-set "train_images.npy"))
     (prepare-images (load-array test-set "test_images.npy"))
     (load-array train-set "train_labels.npy")
     (load-array test-set "test_labels.npy")))

(defn mnist-train-net [train-net input-shape [conv False] [optimizer None]]
  "Assemble the training-side pieces for a network given as a list of stax layers.

  Returns the tuple (net-apply calc-loss opt-update opt-get new-opt-state):
  - net-apply: jitted apply function of (stax.serial *train-net)
  - calc-loss: jitted (params x y [rng]) -> nn-utils.ce-with-logits-loss of the
    network output against y
  - opt-update, opt-get: taken from the optimizer triple
  - new-opt-state: rng -> optimizer state built from freshly initialised params

  When `optimizer` is None, optimizers.sgd is used with an
  optimizers.exponential-decay step size (step-size 1e-1, decay-rate 0.99995,
  decay-steps 1). Otherwise `optimizer` is passed to hy.eval, and the result
  is unpacked as (init update get-params)."
  (setv [init-fn raw-apply] (stax.serial (unpack-iterable train-net))
        apply-fn (jax.jit raw-apply))
  (if (is optimizer None)
      ;; A disabled alternative schedule was:
      ;; (optimizers.piecewise-constant [1500] [1e-1 1e-2])
      (setv schedule (optimizers.exponential-decay :step-size 1e-1
                                                   :decay-rate 0.99995
                                                   :decay-steps 1)
            [opt-init opt-update opt-get] (optimizers.sgd :step-size schedule))
      ;; NOTE(review): hy.eval executes arbitrary code; only pass trusted forms.
      (setv [opt-init opt-update opt-get] (hy.eval optimizer)))
  (defn loss-fn [params inputs targets [rng None]]
    (nn-utils.ce-with-logits-loss (apply-fn params inputs :rng rng) targets))
  (setv calc-loss (jax.jit loss-fn))
  (defn new-opt-state [rng]
    ;; The leading -1 leaves the first dimension unspecified. With `conv`, the
    ;; input is treated as input-shape x input-shape with a single channel.
    (setv init-shape (if conv
                         (, -1 input-shape input-shape 1)
                         (, -1 input-shape)))
    ;; Element 1 of the init result holds the parameters handed to opt-init.
    (opt-init (get (init-fn rng init-shape) 1)))
  (, apply-fn calc-loss opt-update opt-get new-opt-state))

(defn mnist-test-net [test-net]
  "Build the evaluation-side pieces for a network given as a list of stax layers.

  Returns (net-apply calc-loss):
  - net-apply: jitted apply function of (stax.serial *test-net)
  - calc-loss: jitted (params x y [rng]) -> nn-utils.ce-with-logits-loss of
    the network output against y
  The layers' init function is not used."
  (setv combined (stax.serial (unpack-iterable test-net))
        apply-fn (jax.jit (get combined 1)))
  (defn loss-fn [params inputs targets [rng None]]
    (nn-utils.ce-with-logits-loss (apply-fn params inputs :rng rng) targets))
  (, apply-fn (jax.jit loss-fn)))

;; Experiment configuration consumed by the script further below. Names bound
;; here take precedence over the `(default name value)` fallbacks used there,
;; e.g. `epochs` stays 100 rather than falling back to 50.
(setv num-outputs 10
      ;; Six Dense(300) + Relu hidden layers, then a Dense(num-outputs) output layer.
      net
      [(stax.Dense 300)
       stax.Relu
       (stax.Dense 300)
       stax.Relu
       (stax.Dense 300)
       stax.Relu
       (stax.Dense 300)
       stax.Relu
       (stax.Dense 300)
       stax.Relu
       (stax.Dense 300)
       stax.Relu
       (stax.Dense num-outputs)]
      ;; The same layer list is used for both the training and the evaluation network.
      train-net net
      test-net net
      ;; False: mnist-data flattens the images, and networks are initialised
      ;; with input shape (-1, input-shape).
      conv False
      ;; Subdirectory of ../../mnist_c that mnist-data reads the training data from.
      train-set "motion_blur"
      ;; Builds the optimisation step for one grid point. calc-loss-train,
      ;; train-apply, opt-update and opt-get are only bound later in the
      ;; script; this works because the fn body is evaluated when called,
      ;; not when defined.
      create-opt-step (fn [params] (nn-utils.create-opt-step calc-loss-train
                                                             (nn-utils.create-mgs-penalty-trace
                                                               train-apply
                                                               (get params "penalty"))
                                                             opt-update opt-get))
      ;; Values tried for the "penalty" argument of nn-utils.create-mgs-penalty-trace.
      param-space {"penalty" [2.5e-2 1e-2 5e-3 1e-3]}
      epochs 100)

(import [sklearn.model_selection [StratifiedShuffleSplit]])

;; Hyper-parameter search: load the data, build the train/test networks, then
;; score every point of `param-space` over 5 stratified shuffle splits.
(setv [train-images test-images train-labels test-labels] (mnist-data :train-set train-set :conv conv)
      ;; Size of axis 1 of the training images: the flattened feature count,
      ;; or (presumably the image height) when `conv` is true.
      input-shape (get (np.shape train-images) 1)
      [train-apply calc-loss-train opt-update opt-get new-opt-state] (mnist-train-net train-net input-shape :conv conv)
      [test-apply calc-loss-test] (mnist-test-net test-net)

      ;; Keep any value already bound above (e.g. `epochs`); otherwise use these.
      batch-size (default batch-size 32)
      epochs (default epochs 50)
      label-noise (default label-noise 0.5)
      train-size (default train-size 6000)

      ;; An int train-size is an absolute sample count. The split therefore
      ;; does not depend on how many training images were loaded, and a bound
      ;; value such as 6e3 is also accepted.
      splitter (StratifiedShuffleSplit :n-splits 5
                                       :train-size (int train-size)
                                       :random-state 62)
      ;; Each fold yields [train-images test-images train-labels test-labels].
      ;; The training part uses the split's training indices, with labels
      ;; passed through nn-utils.add-label-noise and then one-hot encoded.
      ;; The evaluation part is always the full, separate test set; the
      ;; split's own held-out indices are not used.
      create-iterator (fn [] (gfor [train test] (.split splitter train-images train-labels)
                                   [(np.take train-images train :axis 0) ; train images
                                    test-images  ; test-images
                                    (-> (np.take train-labels train :axis 0)
                                        (nn-utils.add-label-noise label-noise)
                                        (nn-utils.one-hot-encode)) ; train labels
                                    (nn-utils.one-hot-encode test-labels)])) ; test labels
      ;; Test loss of the current parameters, using a fixed PRNG key for repeatability.
      metric (fn [opt-state x y] (calc-loss-test (opt-get opt-state) x y :rng (jax.random.PRNGKey 0)))
      perf (np.array (nn-utils.grid-search-cv param-space epochs create-opt-step new-opt-state
                                              batch-size metric create-iterator))
      ;; NOTE(review): assumes each row of `perf` is (params, score); this
      ;; picks the params with the smallest score. Confirm against
      ;; nn-utils.grid-search-cv.
      best-param (get (np.take perf 0 :axis 1) (np.argmin (np.take perf 1 :axis 1))))

(print perf)
(print f"Results of grid search: {best-param}")
