@@ -18,6 +18,7 @@ ARG UBUNTU_MIRROR
1818ARG GITHUB_ARTIFACTORY=github.com
1919ARG INSTALL_FLASHINFER_JIT_CACHE=0
2020ARG FLASHINFER_VERSION=0.5.3
21+ ARG NVSHMEM_VERSION=3.4.5
2122
2223ENV DEBIAN_FRONTEND=noninteractive \
2324 CUDA_HOME=/usr/local/cuda \
@@ -131,21 +132,16 @@ RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install --upgrade
131132 && if [ "$INSTALL_FLASHINFER_JIT_CACHE" = "1" ]; then \
132133 python3 -m pip install flashinfer-jit-cache==${FLASHINFER_VERSION} --index-url https://flashinfer.ai/whl/cu${CUINDEX} ; \
133134 fi \
134- && if [ "${CUDA_VERSION%%.*}" = "12" ]; then \
135- python3 -m pip install nvidia-nccl-cu12==2.28.3 --force-reinstall --no-deps ; \
136- python3 -m pip install nvidia-cudnn-cu12==9.16.0.29 --force-reinstall --no-deps; \
137- elif [ "${CUDA_VERSION%%.*}" = "13" ]; then \
138- python3 -m pip install nvidia-nccl-cu13==2.28.3 --force-reinstall --no-deps ; \
139- else \
140- echo "No NCCL mapping for CUDA_VERSION=${CUDA_VERSION}" && exit 1 ; \
141- fi \
142135 && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin
143136
144- # Download NVSHMEM source files
145137# We use Tom's DeepEP fork for GB200 for now; the 1fd57b0276311d035d16176bb0076426166e52f3 commit is https://github.com/fzyzcjy/DeepEP/tree/gb200_blog_part_2
146138RUN set -eux; \
147- if [ "${CUDA_VERSION%%.*}" != "13" ]; then \
148- pip install nvidia-nvshmem-cu12==3.4.5 ; \
139+ if [ "${CUDA_VERSION%%.*}" = "12" ]; then \
140+ pip install nvidia-nvshmem-cu12==${NVSHMEM_VERSION} ; \
141+ elif [ "${CUDA_VERSION%%.*}" = "13" ]; then \
142+ pip install nvidia-nvshmem-cu13==${NVSHMEM_VERSION} ; \
143+ else \
144+ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 ; \
149145 fi && \
150146 if [ "$GRACE_BLACKWELL" = "1" ]; then \
151147 git clone https://github.com/fzyzcjy/DeepEP.git && \
@@ -198,6 +194,19 @@ RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install \
198194 nixl \
199195 py-spy
200196
197+ # Some patching packages
198+ # TODO: Remove this when torch version covers these packages
199+ RUN --mount=type=cache,target=/root/.cache/pip if [ "${CUDA_VERSION%%.*}" = "12" ]; then \
200+ python3 -m pip install nvidia-nccl-cu12==2.28.3 --force-reinstall --no-deps ; \
201+ python3 -m pip install nvidia-cudnn-cu12==9.16.0.29 --force-reinstall --no-deps; \
202+ python3 -m pip install nvidia-nvshmem-cu12==${NVSHMEM_VERSION} --force-reinstall --no-deps; \
203+ elif [ "${CUDA_VERSION%%.*}" = "13" ]; then \
204+ python3 -m pip install nvidia-nccl-cu13==2.28.3 --force-reinstall --no-deps ; \
205+ python3 -m pip install nvidia-cublas==13.1.0.3 --force-reinstall --no-deps ; \
206+ python3 -m pip install nixl-cu13 ; \
207+ python3 -m pip install nvidia-nvshmem-cu13==${NVSHMEM_VERSION} --force-reinstall --no-deps; \
208+ fi
209+
201210# Install development tools and utilities
202211RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get install -y \
203212 gdb \
0 commit comments