torch>=2.1.0
# JAX with CUDA 11 GPU support.
# The `cuda11_pip` extra pins jaxlib to a `+cuda11.cudnn86` build that is
# published on Google's JAX wheel index rather than PyPI, so pip needs this
# find-links URL to resolve it (the option applies to the whole file).
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# The `cuda11_pip` extra no longer exists once JAX dropped CUDA 11 support
# (0.4.26). On newer releases pip only warns about the unknown extra and installs
# a CPU-only jaxlib, so cap the version to keep the CUDA 11 build.
# NOTE(review): if the target machines have CUDA 12, switch to
# `jax[cuda12_pip]` / `jax[cuda12]` and drop this upper bound instead.
jax[cuda11_pip]<0.4.26
jaxlib
optax
tqdm>=4.62.3
transformers>=4.37.0
fire>=0.5.0
numpy>=1.21.2
einops>=0.6.1
dm-tree>=0.1.8
gradio>=3.23.0
rich
