;; Directory into which result files (performance arrays and pickled
;; optimizer states) are written; it is joined with the output file names.
(setv attach-dir ".")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        itertools
        os
        pickle
        queue
        time)

;; Expands to an expression that evaluates `x` and yields True, or yields
;; False when evaluating `x` raises NameError (i.e. the name is not bound).
(defmacro bound? [x]
  `(try
     (do ~x True)
     (except [NameError] False)))

;; Expands to `x` when `x` evaluates without raising NameError (as tested by
;; `bound?`), and to the fallback `d` otherwise.
;; Note: `x` appears twice in the expansion (once inside `bound?`, once as
;; the result), so it is evaluated twice when it is bound.
(defmacro default [x d]
  `(if (bound? ~x) ~x ~d))

(import [sklearn.preprocessing [normalize]]
        [imagecorruptions [corrupt]])

(defn partial-flatten [x]
  "Reshape `x` to two dimensions: the first axis is kept and all remaining
  axes are merged into one."
  (setv leading (get (np.shape x) 0))
  (np.reshape x (, leading -1)))

(defn fk-data [[train-size 1500] [conv False]]
  "Load the facial-keypoints images and labels from disk.

  The image stack is read from face_images.npz and its axes are permuted
  with (2 0 1); the labels are read from facial_keypoints.csv (one header
  row skipped). Rows whose keypoints contain any NaN are removed from both
  arrays and the images are cast to uint8.

  Returns a tuple (images keypoints input-shape). input-shape is the size of
  image axis 1 when `conv` is true, otherwise the product of the sizes of
  image axes 1 and 2. `train-size` is accepted but not used."
  (setv raw-images (get (np.load "../../facial_keypoints/face_images.npz") "face_images")
        stacked (np.transpose raw-images (, 2 0 1))
        raw-keypoints (np.genfromtxt "../../facial_keypoints/facial_keypoints.csv" :skip-header 1 :delimiter ",")
        row-has-nan (np.any (np.isnan raw-keypoints) :axis 1)
        ;; indices of the rows in which every keypoint is present
        complete-rows (np.squeeze (np.where (np.logical-not row-has-nan)))
        images (.astype (np.take stacked complete-rows :axis 0) "uint8")
        keypoints (np.take raw-keypoints complete-rows :axis 0)
        dims (np.shape images))
  (setv input-shape (if conv
                        (get dims 1)
                        (* (get dims 1) (get dims 2))))
  (, images keypoints input-shape))

(defn fk-train-net [train-net input-shape [conv False] [optimizer None]]
  "Build the training-side functions for a network given as a list of stax
  layers.

  Parameters:
    train-net   -- list of layers combined with stax.serial.
    input-shape -- size used to build the shape handed to the network
                   initialiser: (-1 s s 1) when `conv` is true, else (-1 s).
    conv        -- selects the 4-D (convolutional) input shape.
    optimizer   -- None for SGD with an exponentially decaying step size
                   (initial 1e-4, decay rate 0.99995, decay steps 1);
                   otherwise a Hy form that is evaluated with hy.eval and
                   must yield an (init update get-params) triple.

  Returns (net-apply calc-loss opt-update opt-get new-opt-state), where
  net-apply and calc-loss are jitted and new-opt-state maps a PRNG key to a
  freshly initialised optimizer state."
  (setv [net-init raw-apply] (stax.serial #* train-net))
  (setv net-apply (jax.jit raw-apply))
  ;; NOTE: `optimizer` is evaluated as code; never pass untrusted input here.
  (setv opt-triple (if (is optimizer None)
                       (optimizers.sgd
                         :step-size (optimizers.exponential-decay
                                      :step-size 1e-4
                                      :decay-rate 0.99995
                                      :decay-steps 1))
                       (hy.eval optimizer)))
  (setv [opt-init opt-update opt-get] opt-triple)
  (setv calc-loss (jax.jit (fn [p x y [rng None]]
                             (nn-utils.mse-loss (net-apply p x :rng rng) y))))
  ;; Shape passed to the initialiser; the leading -1 stands for the batch axis.
  (setv init-shape (if conv
                       (, -1 input-shape input-shape 1)
                       (, -1 input-shape)))
  (setv new-opt-state (fn [rng]
                        (opt-init (get (net-init rng init-shape) 1))))
  (, net-apply calc-loss opt-update opt-get new-opt-state))

(defn fk-test-net [test-net]
  "Build the evaluation-side functions for a network given as a list of stax
  layers.

  Returns (net-apply calc-loss): the jitted apply function of the serial
  network and a jitted function (p x y [rng]) computing nn-utils.mse-loss
  between the network output and y."
  (setv [_ raw-apply] (stax.serial #* test-net))
  (setv net-apply (jax.jit raw-apply))
  (setv calc-loss (jax.jit (fn [p x y [rng None]]
                             (nn-utils.mse-loss (net-apply p x :rng rng) y))))
  (, net-apply calc-loss))
(defn fk-augment [ims conv np-rng corruption [severity 1] [corruption-prob 1]]
  "Optionally corrupt a batch of images, then shape and scale it for the
  network.

  Parameters:
    ims             -- batch of images indexed along axis 0.
    conv            -- if true a new axis is appended at position 3,
                       otherwise each image is flattened (partial-flatten).
    np-rng          -- generator providing `.binomial`, used to choose which
                       images are corrupted.
    corruption      -- corruption name passed to `corrupt`, or None to leave
                       every image untouched.
    severity        -- severity passed to `corrupt`.
    corruption-prob -- per-image probability of being corrupted.

  Returns the batch divided by 255.0."
  ;; The mask is drawn even when `corruption` is None, so the generator
  ;; advances the same way in both cases.
  (setv n-images (get (np.shape ims) 0)
        mask (.binomial np-rng 1 corruption-prob n-images))
  (setv batch (if (is corruption None)
                  ims
                  (jnp.array
                    (lfor [idx im] (enumerate ims)
                          (if (= (get mask idx) 1)
                              ;; average the corrupted result over axis 2
                              (np.mean (corrupt im
                                                :corruption-name corruption
                                                :severity severity)
                                       :axis 2)
                              im)))))
  (setv batch (if conv
                  (jnp.expand-dims batch 3)
                  (partial-flatten batch)))
  (/ batch 255.0))

(defn fk-eval [images keypoints train-apply test-apply
               calc-loss-test epochs opt-get opt-step new-opt-state jax-rng batch-size attach-dir
               [fname "perf.npy"] [fname2 "opt-states.obj"] [splitter None]
               [corruption None] [severity 1] [corruption-prob 1] [noise-scale 0.0]
               [show-progress True] [conv False] [np-rng-seed 33]]
  "Train and evaluate one model per train/test split and store the results.

  For every split produced by `splitter` the training images are augmented
  with fk-augment (using `corruption`, `severity` and `corruption-prob`),
  the test images are only reshaped and scaled, Gaussian noise with standard
  deviation `noise-scale` is added to the training keypoints (then clipped
  to [0, 96]), and nn-utils.train-model is run while tracking the loss, the
  trace metric and the determinant metric every 100 steps.

  Parameters:
    images, keypoints -- full data set; split indices select along axis 0.
    train-apply       -- apply function given to the trace and determinant
                         trackers.
    test-apply        -- currently unused.
    calc-loss-test    -- loss function given to the loss tracker.
    epochs, batch-size, opt-step, opt-get -- forwarded to the training loop
                         and trackers.
    new-opt-state     -- maps a PRNG key to an initial optimizer state.
    jax-rng           -- PRNG key; split into one sub-key per split (the
                         first key of the split is discarded).
    attach-dir        -- directory for the two output files.
    fname, fname2     -- file names for the metrics and optimizer states.
    splitter          -- object providing `get-n-splits` and `split`
                         (required despite the None default).
    noise-scale       -- standard deviation of the label noise.
    show-progress     -- forwarded to the training loop.
    conv              -- forwarded to fk-augment.
    np-rng-seed       -- split k uses a NumPy generator seeded with
                         np-rng-seed + k.

  Side effects: writes train loss, test loss, trace and determinant arrays
  (in that order, via consecutive np.save calls) to attach-dir/fname and
  pickles the unpacked optimizer states to attach-dir/fname2.

  Returns the list of final optimizer states, one per split."
  ;; One fresh, deterministic NumPy seed per split.
  (setv np-rng-count (itertools.count np-rng-seed))
  (setv [_ #* subrng] (jax.random.split jax-rng (inc (.get-n-splits splitter)))
        perf (lfor [i data] (enumerate (.split splitter :X images :y keypoints))
                   (let [[train test] data
                         np-rng (np.random.default-rng :seed (next np-rng-count))
                         ;; Forward severity / corruption-prob so that the
                         ;; caller's settings actually reach the augmentation.
                         x-train (fk-augment (jnp.take images train :axis 0) conv np-rng corruption
                                             :severity severity
                                             :corruption-prob corruption-prob)
                         ;; Test images are never corrupted.
                         x-test (fk-augment (jnp.take images test :axis 0) conv np-rng None)
                         y-train (jnp.take keypoints train :axis 0)
                         ;; Gaussian label noise, clipped to [0, 96].
                         y-train (np.clip (+ y-train
                                             (.normal np-rng :scale noise-scale :size (jnp.shape y-train)))
                                          0.0 96.0)
                         y-test (jnp.take keypoints test :axis 0)
                         n (np.size train)
                         [opt-state metrics]
                         (nn-utils.train-model
                           epochs (new-opt-state (get subrng i))
                           opt-step x-train y-train
                           :batch-size batch-size
                           :metrics [(nn-utils.setup-loss-tracker
                                       calc-loss-test x-test y-test
                                       opt-get 100)
                                     (nn-utils.setup-trace-tracker
                                       train-apply opt-get
                                       100)
                                     (nn-utils.setup-determinant-tracker
                                       train-apply opt-get
                                       100)]
                           :show-progress show-progress
                           :progress-pos 2
                           :jax-rng (get subrng i))]
                     [(:metric (:state (get metrics 0)))
                      (:metric (:state (get metrics 1)))
                      (:metric (:state (get metrics 2)))
                      opt-state]))
        loss (np.array (lfor p perf (get p 0)))
        trace (np.array (lfor p perf (get p 1)))
        det (np.array (lfor p perf (get p 2)))
        opt-states (lfor p perf (get p 3))
        ;; Presumably index 0 / 1 along axis 2 of the loss tracker output are
        ;; the train / test losses (per the variable names) -- confirm in
        ;; nn-utils.setup-loss-tracker.
        train-loss (np.take loss 0 :axis 2)
        test-loss (np.take loss 1 :axis 2)
        ;; Statistics across splits (axis 0).
        test-loss-mean (np.mean test-loss :axis 0)
        test-loss-std (np.std test-loss :axis 0)
        perf-file (os.path.join attach-dir fname)
        opts-file (os.path.join attach-dir fname2))
  (print f"Test loss {(get test-loss-mean -1) :.4f} \pm ({(get test-loss-std -1) :.4f})")
  ;; Several arrays are appended to the same file; read them back with
  ;; repeated np.load calls on one open file handle.
  (with [f (open perf-file "wb")]
    (np.save f train-loss)
    (np.save f test-loss)
    (np.save f trace)
    (np.save f det))
  (with [f (open opts-file "wb")]
    (pickle.dump (lfor o opt-states (optimizers.unpack-optimizer-state o)) f))
  opt-states)

;; Experiment configuration.
;;
;; `process-params` is a quoted form that is evaluated later (inside each
;; worker) to define the network architectures and training settings as
;; globals. `noise-scale` and `train-size` span the grid of runs.
(setv process-params `(do (setv num-outputs 30
                                train-net
                                [(stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
                                 stax.Relu
                                 (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                                 (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
                                 stax.Relu
                                 (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                                 stax.Flatten
                                 (stax.Dense 120)
                                 stax.Relu
                                 (stax.Dense 84)
                                 stax.Relu
                                 (stax.Dense num-outputs)]
                                test-net
                                [(stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
                                 stax.Relu
                                 (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                                 (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
                                 stax.Relu
                                 (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                                 stax.Flatten
                                 (stax.Dense 120)
                                 stax.Relu
                                 (stax.Dense 84)
                                 stax.Relu
                                 (stax.Dense num-outputs)]
                                ;; No regularisation penalty: always 0.0.
                                create-penalty (fn [net-apply calc-loss] (constantly 0.0))
                                conv True
                                epochs 50
                                batch-size 128
                                jax-rng (jax.random.PRNGKey 0))
                          ;; Insert dropout layers at indices 2, 5, 8, ... of
                          ;; each (growing) layer list. The iteration count is
                          ;; fixed from the list length before any insertion.
                          ;; The length must come from the list being edited;
                          ;; the former `(len net)` referred to an unbound name.
                          ;; NOTE(review): these are neural_tangents layers
                          ;; placed in a jax stax.serial network -- confirm the
                          ;; two layer formats are compatible.
                          (for [i (range (dec (// (len train-net) 2)))]
                            (.insert train-net (+ 2 i (* i 2)) (nt-stax.Dropout :rate 0.9)))
                          (for [i (range (dec (// (len test-net) 2)))]
                            (.insert test-net (+ 2 i (* i 2)) (nt-stax.Dropout :rate 0.9 :mode "test"))))
      ;; Label-noise levels to run.
      noise-scale [0 10 20 30]
      ;; Training-set fractions handed to the splitter.
      train-size [0.3])

(import [sklearn.model_selection [ShuffleSplit]]
        [multiprocessing :as mp])

(defn run-testbench [d q]
  "Worker entry point: run queued experiments on one device.

  Parameters:
    d -- device index, exported as CUDA_VISIBLE_DEVICES for this process.
    q -- queue of (noise-scale train-size) pairs shared by all workers.

  The configuration in `process-params` is evaluated into the module
  globals, the data and networks are built once, and then jobs are pulled
  from `q` until it is exhausted. Each job runs fk-eval with a 10-split
  ShuffleSplit and writes files named after the job parameters."
  (assoc os.environ "CUDA_VISIBLE_DEVICES" (str d))

  ;; Defines train-net, test-net, conv, epochs, batch-size, jax-rng and
  ;; create-penalty as globals of this module.
  (hy.eval process-params (globals))

  (print (jax.devices "gpu"))
  (setv [images keypoints input-shape] (fk-data :conv conv)
        [train-apply calc-loss-train opt-update opt-get new-opt-state] (fk-train-net train-net input-shape
                                                                                     :conv conv)
        opt-step (nn-utils.create-opt-step calc-loss-train (create-penalty train-apply calc-loss-train) opt-update opt-get)
        [test-apply calc-loss-test] (fk-test-net test-net))

  ;; Checking `.empty` and then calling a blocking `.get` is racy when
  ;; several workers share the queue: a worker that loses the race for the
  ;; last job would block forever. A timed get that stops on queue.Empty
  ;; avoids this.
  (while True
    (try
      (setv [n t] (.get q :timeout 1))
      (except [queue.Empty]
        (break)))
    (setv splitter (ShuffleSplit :n-splits 10
                                 :train-size t
                                 :random-state 2))
    (fk-eval images keypoints train-apply test-apply
             calc-loss-test epochs opt-get opt-step new-opt-state jax-rng
             batch-size attach-dir f"perf_{n}_{t :.1f}.npy" f"opt-states_{n}_{t :.1f}.obj"
             splitter :noise-scale n :conv conv :show-progress True
             :corruption "defocus_blur")))

;; Script entry point.
;;
;; Builds the grid of (noise-scale, train-size) jobs, puts them on a shared
;; queue and starts one worker process per device. The number of devices is
;; read from the first command-line argument and defaults to 1 when the
;; argument is absent. A progress bar tracks how many jobs have been taken
;; off the queue; the function returns once all workers have exited.
(defmain [#* args]
  (mp.set-start-method "spawn")
  (setv [g1 g2] (np.meshgrid noise-scale train-size)
        ;; 2 x n-jobs array: row 0 holds noise scales, row 1 train sizes.
        grid (np.vstack [(np.reshape g1 (, -1)) (np.reshape g2 (, -1))])
        q (mp.Queue)
        procs (list)
        ;; args[0] is the program name.
        n-workers (if (> (len args) 1) (int (get args 1)) 1))
  (for [i (range (get (np.shape grid) 1))]
    (.put q (np.take grid i :axis 1)))

  ;; Number of jobs not yet taken by a worker.
  (setv s (.qsize q))

  (for [d (range n-workers)]
    (setv p (mp.Process :target run-testbench :args (, d q)))
    (.start p)
    (.append procs p))

  (with [t (trange s)]
    (while (not (.empty q))
      (setv ss (.qsize q))
      (when (> s ss)
        (.update t (- s ss))
        (setv s ss))
      ;; Poll instead of spinning at full speed on one CPU core.
      (time.sleep 0.1))
    ;; The queue is empty now, so the jobs still counted in `s` have been
    ;; taken as well; without this the bar would never reach its total.
    (.update t s))

  ;; Wait for every worker to finish its last job.
  (for [p procs]
    (.join p)))
