absl-py==1.0.0
clu==0.0.6
flax==0.4.1
jax==0.3.4
--find-links https://storage.googleapis.com/jax-releases/jax_releases.html
jaxlib==0.3.2+cuda11.cudnn82  # This build targets CUDA 11 and cuDNN 8.2; make sure both match the base image.
ml-collections==0.1.0
numpy==1.22.0
optax==0.1.0
tensorflow==2.11.1
tensorflow-datasets==4.4.0
