Skip to content

Commit 5eb3d0b

Browse files
committed
chore: bump pytorch to 2.8
1 parent 3be0755 commit 5eb3d0b

6 files changed

Lines changed: 6 additions & 6 deletions

File tree

.devcontainer/download_libtorch.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ set -ev
44
SCRIPT_PATH=$(dirname $(realpath -s $0))
55
cd ${SCRIPT_PATH}/..
66

7-
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.7.0%2Bcpu.zip -O ~/libtorch.zip
7+
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.8.0%2Bcpu.zip -O ~/libtorch.zip
88
unzip ~/libtorch.zip

.github/workflows/build_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
- uses: lukka/get-cmake@latest
3434
- run: python -m pip install uv
3535
- run: source/install/uv_with_retry.sh pip install --system tensorflow
36-
- run: source/install/uv_with_retry.sh pip install --system 'torch==2.7' --index-url https://download.pytorch.org/whl/cpu
36+
- run: source/install/uv_with_retry.sh pip install --system 'torch==2.8' --index-url https://download.pytorch.org/whl/cpu
3737
- run: |
3838
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb \
3939
&& sudo dpkg -i cuda-keyring_1.0-1_all.deb \

.github/workflows/codeql.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
&& sudo apt-get update \
4242
&& sudo apt-get -y install cuda-cudart-dev-12-2 cuda-nvcc-12-2
4343
python -m pip install tensorflow
44-
python -m pip install 'torch==2.7' --index-url https://download.pytorch.org/whl/cpu
44+
python -m pip install 'torch==2.8' --index-url https://download.pytorch.org/whl/cpu
4545
env:
4646
DEBIAN_FRONTEND: noninteractive
4747
# Initializes the CodeQL tools for scanning.

.github/workflows/test_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
source/install/uv_with_retry.sh pip install --system tensorflow-cpu~=2.18.0 jax==0.5.0
2929
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3030
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
31-
source/install/uv_with_retry.sh pip install --system 'torch==2.7' --index-url https://download.pytorch.org/whl/cpu
31+
source/install/uv_with_retry.sh pip install --system 'torch==2.8' --index-url https://download.pytorch.org/whl/cpu
3232
- name: Convert models
3333
run: source/tests/infer/convert-models.sh
3434
# https://github.com/actions/runner-images/issues/9491

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4444
if: false # skip as we use nvidia image
4545
- run: python -m pip install -U uv
46-
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.7.0" "jax[cuda12]==0.5.0"
46+
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.8.0" "jax[cuda12]==0.5.0"
4747
- run: |
4848
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
4949
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')

backend/find_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
116116
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
117117
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
118118
# CUDA 12.2, cudnn 9
119-
pt_version = "2.7.0"
119+
pt_version = "2.8.0"
120120
elif cuda_version in SpecifierSet(">=11,<12"):
121121
# CUDA 11.8, cudnn 8
122122
pt_version = "2.3.1"

0 commit comments

Comments
 (0)