# PyTorch CPU-only wheel index, consulted in addition to PyPI; it serves torch
# builds carrying the "+cpu" local version suffix.
--extra-index-url https://download.pytorch.org/whl/cpu

# Page listing JAX CUDA wheels; needed to resolve the jax[cuda12_pip] extra
# requested in the package list below.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

cloudpickle
dm_haiku>=0.0.10
e3nn_jax>=0.20.0
h5py
# Pinned JAX release; the cuda12_pip extra pulls CUDA 12 wheels from the
# --find-links page at the top of this file.
jax[cuda12_pip]==0.4.18
jax_md>=0.2.8
jmp>=0.0.4
jraph>=0.0.6.dev0
matscipy>=0.8.0
optax>=0.1.7
ott-jax>=0.4.2
pyvista
PyYAML
# PEP 440 forbids local version labels such as "+cpu" in ">=" specifiers, and
# recent pip versions reject "torch>=2.1.0+cpu" as an invalid requirement.
# CPU-only builds are expected to come from the PyTorch CPU --extra-index-url at
# the top of this file; use an exact pin (torch==X.Y.Z+cpu) if the CPU build
# must be guaranteed.
torch>=2.1.0
wandb
wget