Skip to content

Commit 02580c2

Browse files
authored
ci: bump TF to 2.18, PT to 2.5 (#4228)
This prepares for the upcoming TF 2.18 release, which requires cuDNN 9. In the future, I may move all version pins into pyproject.toml...

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Enhanced dependency management for CUDA and Python workflows.
  - Introduced new jobs for better organization of test duration handling.
- **Bug Fixes**
  - Updated TensorFlow and Torch versions for improved compatibility and performance.
  - Refined version requirements for TensorFlow based on detected CUDA versions.
- **Documentation**
  - Adjusted testing commands and linting configurations for clarity and compliance.
- **Chores**
  - Streamlined caching mechanisms to optimize test duration tracking.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 4d50048 commit 02580c2

5 files changed

Lines changed: 25 additions & 9 deletions

File tree

.github/workflows/test_cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4848
if: false # skip as we use nvidia image
4949
- run: python -m pip install -U uv
50-
- run: source/install/uv_with_retry.sh pip install --system "tensorflow>=2.15.0rc0" "torch==2.3.1.*"
50+
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0"
5151
- run: |
5252
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
5353
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
@@ -63,7 +63,7 @@ jobs:
6363
CUDA_VISIBLE_DEVICES: 0
6464
- name: Download libtorch
6565
run: |
66-
wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip -O libtorch.zip
66+
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
6767
unzip libtorch.zip
6868
- run: |
6969
export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
- run: python -m pip install -U uv
2727
- run: |
2828
source/install/uv_with_retry.sh pip install --system mpich
29-
source/install/uv_with_retry.sh pip install --system "torch==2.3.0+cpu.cxx11.abi" -i https://download.pytorch.org/whl/
29+
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
3030
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
3131
source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
3232
env:

backend/find_pytorch.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import importlib
33
import os
4+
import platform
45
import site
56
from functools import (
67
lru_cache,
@@ -22,6 +23,9 @@
2223
Union,
2324
)
2425

26+
from packaging.specifiers import (
27+
SpecifierSet,
28+
)
2529
from packaging.version import (
2630
Version,
2731
)
@@ -104,6 +108,20 @@ def get_pt_requirement(pt_version: str = "") -> dict:
104108
"""
105109
if pt_version is None:
106110
return {"torch": []}
111+
if (
112+
os.environ.get("CIBUILDWHEEL", "0") == "1"
113+
and platform.system() == "Linux"
114+
and platform.machine() == "x86_64"
115+
):
116+
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
117+
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
118+
# CUDA 12.2, cudnn 9
119+
pt_version = "2.5.0"
120+
elif cuda_version in SpecifierSet(">=11,<12"):
121+
# CUDA 11.8, cudnn 8
122+
pt_version = "2.3.1"
123+
else:
124+
raise RuntimeError("Unsupported CUDA version") from None
107125
if pt_version == "":
108126
pt_version = os.environ.get("PYTORCH_VERSION", "")
109127

backend/find_tensorflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]:
8585
if os.environ.get("CIBUILDWHEEL", "0") == "1":
8686
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
8787
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
88-
# CUDA 12.2
88+
# CUDA 12.2, cudnn 9
8989
requires.extend(
9090
[
91-
"tensorflow-cpu>=2.15.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'",
91+
"tensorflow-cpu>=2.18.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'",
9292
]
9393
)
9494
elif cuda_version in SpecifierSet(">=11,<12"):
95-
# CUDA 11.8
95+
# CUDA 11.8, cudnn 8
9696
requires.extend(
9797
[
9898
"tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'",

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ cu12 = [
132132
"nvidia-curand-cu12",
133133
"nvidia-cusolver-cu12",
134134
"nvidia-cusparse-cu12",
135-
"nvidia-cudnn-cu12<9",
135+
"nvidia-cudnn-cu12",
136136
"nvidia-cuda-nvcc-cu12",
137137
]
138138
jax = [
@@ -279,8 +279,6 @@ PATH = "/usr/lib64/mpich/bin:$PATH"
279279
UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu"
280280
# trick to find the correct version of mpich
281281
CMAKE_PREFIX_PATH="/opt/python/cp311-cp311/"
282-
# PT 2.4.0 requires cudnn 9, incompatible with TF with cudnn 8
283-
PYTORCH_VERSION = "2.3.1"
284282

285283
[tool.cibuildwheel.windows]
286284
test-extras = ["cpu", "torch"]

0 commit comments

Comments
 (0)