;; Directory that result files are written into (passed as `attach-dir` to
;; two-circles-eval); "." is the current working directory.
(setv attach-dir ".")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; (bound? x) expands to an expression that is True when evaluating `x`
;; raises no NameError (e.g. `x` is a bound name) and False otherwise.
;; `x` is evaluated once, including any side effects it has.
(defmacro bound? [x]
  `(try (do ~x True)
        (except [NameError] False)))

;; (default x d) expands to the value of `x`, or to `d` when evaluating `x`
;; raises NameError (typically because `x` is an unbound name).
;; `x` is evaluated exactly once: the previous expansion,
;; (if (bound? x) x d), evaluated it twice and so duplicated side effects.
;; `d` is only evaluated when the fallback is needed.
(defmacro default [x d]
  `(try ~x
        (except [NameError] ~d)))

(defmain [[#** args]]
  (setv attach-dir ".")
(import [sklearn.datasets [make_circles]])

(defn two-circles-setup [net step-size [class-sizes (, 150 75)] [label-noise 0.0]
                         [data-noise 0.15] [data-seed 2] [np-rng-seed 3]]
  "Build a (label-noised) two-circles dataset and an SGD setup for `net`.

  Arguments:
    net: sequence of neural-tangents stax layers, combined with `nt-stax.serial`.
    step-size: step size (constant or schedule) passed to `optimizers.sgd`.
    class-sizes: `n_samples` for `make_circles`; a pair gives the sizes of
      the outer and inner circle. A list is converted to a tuple.
    label-noise: fraction of samples (rounded down) whose label is redrawn
      uniformly from [0, max(y)]; a redrawn label may equal the old one.
    data-noise: `noise` argument of `make_circles`.
    data-seed: `random_state` argument of `make_circles`.
    np-rng-seed: seed of the numpy Generator used for the label noise.

  Returns the tuple (x y net-apply calc-loss opt-update opt-get new-opt-state):
    x, y: inputs and labels, with y reshaped to shape (n, 1);
    net-apply: jitted apply function of the serial network;
    calc-loss: (fn [params x y [rng None]]) binary cross-entropy on logits;
    opt-update, opt-get: SGD update / get-params functions;
    new-opt-state: (fn [rng]) initialising params with the JAX key `rng`
      and returning a fresh optimizer state."
  ;; Recent scikit-learn releases validate `n_samples` as an int or a tuple
  ;; and reject lists, so list input is normalised to a tuple.
  (setv class-sizes (if (isinstance class-sizes list)
                        (tuple class-sizes)
                        class-sizes))
  (setv [x y] (make-circles :n-samples class-sizes
                            :noise data-noise
                            :random-state data-seed)
        np-rng (np.random.default-rng np-rng-seed)
        n (len y)
        ;; floor(label-noise * n) distinct sample indices to relabel.
        flipped-idx (.choice np-rng (range 0 n) (int (np.floor (* label-noise n)))
                             :replace False)
        ;; Replacement labels drawn uniformly from [0, max(y)] (upper bound of
        ;; `integers` is exclusive), computed from the un-noised labels.
        flipped-labels (.integers np-rng 0 (inc (np.max y)) (len flipped-idx)))
  ;; In-place relabelling of the selected samples.
  (assoc y flipped-idx flipped-labels)
  (setv y (np.reshape y (, -1 1)))

  (setv [net-init net-apply _] (nt-stax.serial (unpack-iterable net))
        net-apply (jax.jit net-apply)
        [opt-init opt-update opt-get] (optimizers.sgd :step-size step-size)
        calc-loss (fn [p x y [rng None]]
                    (nn-utils.bce-with-logits-loss (net-apply p x :rng rng) y))
        input-dim (get (np.shape x) 1)
        ;; The stax init function returns (output-shape, params); only the
        ;; params (element 1) seed the optimizer state.
        new-opt-state (fn [rng] (opt-init (get (net-init rng (, -1 input-dim)) 1))))
  (, x y net-apply calc-loss opt-update opt-get new-opt-state))
(defn two-circles-eval [x-train x-test y-train y-test net-apply calc-loss epochs
                        opt-get opt-step new-opt-state batch-size attach-dir
                        [rng (jax.random.PRNGKey 0)]]
  "Train a freshly initialised network and write its results to `attach-dir`.

  Training calls `nn-utils.train-model` for `epochs` epochs, starting from
  `(new-opt-state rng)` and using `opt-step` on (x-train, y-train) with
  `batch-size` and `rng`. Four trackers are attached: test-set loss,
  test-set accuracy, NTK trace and NTK determinant. Each is built with the
  integer argument 100 (presumably an evaluation interval; confirm in
  nn_utils).

  Files written to `attach-dir` (which must already exist):
    opt-state.obj        pickled `optimizers.unpack-optimizer-state` of the
                         final optimizer state
    perf.npy             loss tracker metric
    accuracy.npy         accuracy tracker metric
    ntk-trace.npy        NTK trace tracker metric
    ntk-determinant.npy  NTK determinant tracker metric"
  ;; NOTE(review): the same `rng` seeds both parameter init and train-model;
  ;; consider splitting the key if independent randomness is intended.
  (setv init-state (new-opt-state rng))
  ;; Output filename -> tracker. Dicts keep insertion order, so the tracker
  ;; list given to train-model and the filenames used when saving are aligned
  ;; by construction instead of by hand-maintained indices.
  (setv trackers {"perf.npy" (nn-utils.setup-loss-tracker
                               calc-loss x-test y-test opt-get 100)
                  "accuracy.npy" (nn-utils.setup-accuracy-tracker
                                   x-test y-test net-apply opt-get False 100)
                  "ntk-trace.npy" (nn-utils.setup-trace-tracker
                                    net-apply opt-get 100)
                  "ntk-determinant.npy" (nn-utils.setup-determinant-tracker
                                          net-apply opt-get 100)})
  ;; Disabled eigenvalue tracker; to re-enable it, add it to `trackers`
  ;; under an output filename.
  #_(nn-utils.setup-eigval-tracker
      net-apply opt-get 200
      (if (= batch-size -1)
          (np.size y-train)
          batch-size)
      :normalise False
      :rng rng)
  (setv [opt-state metrics]
        (nn-utils.train-model epochs init-state opt-step x-train y-train
                              :jax-rng rng
                              :batch-size batch-size
                              :metrics (list (.values trackers))))
  (with [f (open (os.path.join attach-dir "opt-state.obj") "wb")]
    (pickle.dump (optimizers.unpack-optimizer-state opt-state) f))
  ;; `metrics` is assumed to follow the order of the trackers passed in.
  (for [[filename result] (zip trackers metrics)]
    (with [f (open (os.path.join attach-dir filename) "wb")]
      (np.save f (:metric (:state result))))))

;; Network: three hidden blocks of Dense(300) + Relu, then a Dense(1) output
;; (a single logit, matching the BCE-with-logits loss). All Dense layers use
;; the "standard" parameterization. Fresh layer objects are created for each
;; block, in the same order as writing them out by hand.
(setv net (+ (lfor _ (range 3)
                   layer [(nt-stax.Dense 300 :parameterization "standard")
                          (nt-stax.Relu)]
                   layer)
             [(nt-stax.Dense 1 :parameterization "standard")]))

;; SGD step-size schedule: 1e-2 decayed by a factor 0.99995 every step.
(setv step-size (optimizers.exponential-decay :step-size 1e-2
                                              :decay-rate 0.99995
                                              :decay-steps 1))
(setv batch-size 32)
(setv epochs 4000)

(setv [x y net-apply calc-loss opt-update opt-get new-opt-state]
      (two-circles-setup net step-size))
(setv opt-step (nn-utils.create-opt-step calc-loss (constantly 0.0) opt-update opt-get))
;; NOTE(review): calc-accuracy is not used in the visible code.
(setv calc-accuracy (nn-utils.create-accuracy-metric net-apply True))

;; Even train/test split with a fixed seed.
(setv [x-train x-test y-train y-test]
      (train-test-split x y :test-size 0.5 :random-state 3))

(two-circles-eval :x-train x-train :x-test x-test
                  :y-train y-train :y-test y-test
                  :net-apply net-apply :calc-loss calc-loss
                  :epochs epochs
                  :opt-get opt-get :opt-step opt-step
                  :new-opt-state new-opt-state
                  :batch-size batch-size
                  :attach-dir attach-dir)
