# Pinned Python dependencies for this project (pip requirements format).
absl-py==1.0.0
clu==0.0.6
flax==0.6.0
# Extra wheel indexes pip searches for the libtpu and CUDA-enabled jaxlib builds.
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# NOTE(review): jax is only lower-bounded here; the cuda11_cudnn805 extra may
# not exist in newer jax releases -- confirm the resolved version still ships it.
jax[cuda11_cudnn805]>=0.3.16  # change to jax[tpu] if running on TPUs
ml-collections==0.1.0
numpy==1.22.0
optax==0.1.0
sentencepiece==0.1.96
six==1.15.0
tensorflow==2.11.1
tensorflow-datasets==4.4.0
# tensorflow-text must track the TensorFlow minor version: the 2.8.x releases
# require tensorflow<2.9 and cannot be installed together with tensorflow 2.11.
tensorflow-text==2.11.0
