diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index 45c08847e858..e8ad37fa3fe7 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -1,4 +1,4 @@ -#### Historical builds for releases before r2.7 +#### Historical builds for releases before r2.7. manual_nightly_builds = [ ] diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 4388f6c18418..6709b380bcfd 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -1,7 +1,7 @@ ########## Begin section for release and nightly ######## # Define common configuration parameters for 2.7 release and nightly locals { - tpu_python_versions = ["3.9", "3.10", "3.11"] + tpu_python_versions = ["3.9", "3.10", "3.11", "3.12"] release_git_tag = "v2.7.0-rc5" release_package_version = "2.7.0-rc5" release_pytorch_git_rev = "v2.7.0-rc10" diff --git a/setup.py b/setup.py index cb78389c65a8..161505efcb49 100644 --- a/setup.py +++ b/setup.py @@ -396,7 +396,7 @@ def _get_jax_install_requirements(): # Install nightly JAX libraries from the JAX package registries. jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{_jax_version}-py3-none-any.whl' jaxlib = [] - for python_minor_version in [9, 10, 11]: + for python_minor_version in [9, 10, 11, 12]: jaxlib.append( f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' )