# TensorFlow, CPU-only build.
# `~=2.18` is a compatible-release pin: any 2.x release at or above 2.18.
tensorflow-cpu~=2.18
tensorflow-text~=2.18

# Torch, CPU-only build.
# The extra index below serves the CPU wheels; pip consults it in addition to
# the default package index.
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.1.0
torchvision>=0.16.0

# JAX with CUDA support.
# Keep this pin at the same version the Keras repo uses.
# The find-links page below is an additional place for pip to look for wheels.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]==0.4.28

# Also install everything listed in requirements-common.txt.
-r requirements-common.txt
