Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 109 additions & 4 deletions .github/actions/build-container/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ inputs:
description: "URL of the Bazel remote cache to use for building the image"
required: true
default: ""
ENABLE_BAZEL_DISK_CACHE:
description: "Enable Bazel disk cache via actions/cache"
required: false
default: "false"
ENABLE_BAZEL_REPO_CACHE:
description: "Enable Bazel repository cache via actions/cache"
required: false
default: "false"

outputs:
DOCKER_TAG_MEALKIT:
Expand All @@ -72,6 +80,15 @@ runs:
run: |
echo 'UPLD_IMAGE=ghcr.io/nvidia/jax-toolbox-internal' >> $GITHUB_ENV
echo "BADGE_FILENAME_FULL=${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json" >> $GITHUB_ENV
# Cap Docker client API version to match the daemon on NVKS runners
echo 'DOCKER_API_VERSION=1.43' >> $GITHUB_ENV
# When disk cache is enabled use the BuildKit cache mount path;
# otherwise fall back to the remote cache URL (internal infra runners).
if [[ "${{ inputs.ENABLE_BAZEL_DISK_CACHE }}" == "true" ]]; then
echo 'BAZEL_CACHE_ARG=BAZEL_CACHE=/cache/bazel-disk' >> $GITHUB_ENV
else
echo 'BAZEL_CACHE_ARG=BAZEL_CACHE=${{ inputs.bazel-remote-cache-url }}' >> $GITHUB_ENV
fi

- name: Setup SSH
id: setup-ssh
Expand All @@ -91,8 +108,7 @@ runs:
uses: docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:v0.12.1
version: v0.30.1
image=moby/buildkit:v0.19.0

- name: Download nsys-jax version.py
uses: actions/download-artifact@v4
Expand All @@ -106,6 +122,38 @@ runs:
mv version.py .github/container/nsys_jax/nsys_jax/
cat .github/container/nsys_jax/nsys_jax/version.py

# BAZEL CACHE RESTORE
- name: Restore Bazel disk cache
if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true'
uses: actions/cache/restore@v4
with:
path: /tmp/bazel-disk.tar
key: bazel-disk-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }}
restore-keys: |
bazel-disk-cache-${{ inputs.ARCHITECTURE }}-

- name: Restore Bazel repo cache
if: inputs.ENABLE_BAZEL_REPO_CACHE == 'true'
uses: actions/cache/restore@v4
with:
path: /tmp/bazel-repo.tar
key: bazel-repo-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }}
restore-keys: |
bazel-repo-cache-${{ inputs.ARCHITECTURE }}-

# Extract restored tars into seed dirs; create empty dirs on first run
- name: Prepare Bazel cache seed directories
shell: bash
run: |
mkdir -p /tmp/bazel-disk-cache
if [[ -f /tmp/bazel-disk.tar ]]; then
tar -xf /tmp/bazel-disk.tar -C /tmp/bazel-disk-cache
fi
mkdir -p /tmp/bazel-repo-cache
if [[ -f /tmp/bazel-repo.tar ]]; then
tar -xf /tmp/bazel-repo.tar -C /tmp/bazel-repo-cache
fi

# MEALKIT BUILD
- name: Set docker metadata - mealkit
id: mealkit-metadata
Expand Down Expand Up @@ -134,9 +182,11 @@ runs:
ssh: default
secret-files: |
"SSH_KNOWN_HOSTS=${{ steps.setup-ssh.outputs.known-hosts-file }}"
build-contexts: |
bazel-disk-seed=/tmp/bazel-disk-cache
build-args: |
BASE_IMAGE=${{ inputs.BASE_IMAGE }}
BAZEL_CACHE=${{ inputs.bazel-remote-cache-url }}
${{ env.BAZEL_CACHE_ARG }}
BUILD_DATE=${{ inputs.BUILD_DATE }}
${{ inputs.EXTRA_BUILD_ARGS }}
# FINAL IMAGE BUILD
Expand Down Expand Up @@ -169,10 +219,65 @@ runs:
"SSH_KNOWN_HOSTS=${{ steps.setup-ssh.outputs.known-hosts-file }}"
build-args: |
BASE_IMAGE=${{ inputs.BASE_IMAGE }}
BAZEL_CACHE=${{ inputs.bazel-remote-cache-url }}
${{ env.BAZEL_CACHE_ARG }}
BUILD_DATE=${{ inputs.BUILD_DATE }}
${{ inputs.EXTRA_BUILD_ARGS }}

# BAZEL CACHE EXPORT
# Snapshots are captured first; prune runs after to free space before upload.
# type=tar streams a single archive instead of per-file copies
- name: Export Bazel disk cache
if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true'
uses: docker/build-push-action@v5
with:
context: ${{ inputs.DOCKER_CONTEXT }}
push: false
file: ${{ inputs.DOCKERFILE }}
platforms: linux/${{ inputs.ARCHITECTURE }}
target: bazel-disk-export
outputs: type=tar,dest=/tmp/bazel-disk.tar
build-contexts: |
bazel-disk-seed=/tmp/bazel-disk-cache
build-args: |
BASE_IMAGE=${{ inputs.BASE_IMAGE }}
BUILD_DATE=${{ inputs.BUILD_DATE }}
${{ inputs.EXTRA_BUILD_ARGS }}

- name: Export Bazel repo cache
if: inputs.ENABLE_BAZEL_REPO_CACHE == 'true'
uses: docker/build-push-action@v5
with:
context: ${{ inputs.DOCKER_CONTEXT }}
push: false
file: ${{ inputs.DOCKERFILE }}
platforms: linux/${{ inputs.ARCHITECTURE }}
target: bazel-repo-export
outputs: type=tar,dest=/tmp/bazel-repo.tar
build-args: |
BASE_IMAGE=${{ inputs.BASE_IMAGE }}
BUILD_DATE=${{ inputs.BUILD_DATE }}
${{ inputs.EXTRA_BUILD_ARGS }}

# Prune layer cache after snapshots are captured to free disk space before upload
- name: Prune BuildKit layer cache before upload
if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true' || inputs.ENABLE_BAZEL_REPO_CACHE == 'true'
shell: bash
run: docker buildx prune --force

- name: Save Bazel disk cache
if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true'
uses: actions/cache/save@v4
with:
path: /tmp/bazel-disk.tar
key: bazel-disk-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }}

- name: Save Bazel repo cache
if: inputs.ENABLE_BAZEL_REPO_CACHE == 'true'
uses: actions/cache/save@v4
with:
path: /tmp/bazel-repo.tar
key: bazel-repo-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }}

# SITREP GENERATION
- name: Generate sitrep
if: "!cancelled()"
Expand Down
45 changes: 43 additions & 2 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,24 @@ ARG SRC_PATH_TRANSFORMER_ENGINE=/opt/transformer-engine
ARG GIT_USER_NAME="JAX Toolbox"
ARG GIT_USER_EMAIL=jax@nvidia.com

ARG BAZEL_CACHE=/tmp
ARG BAZEL_CACHE=/cache/bazel-disk
ARG BUILD_DATE

###############################################################################
## Bazel disk cache seed (overridden via --build-context on cache hit)
###############################################################################

# On first run this is empty (FROM scratch). When actions/cache restores a
# previous disk cache to /tmp/bazel-disk-cache on the runner, the caller passes
# --build-context bazel-disk-seed=/tmp/bazel-disk-cache to inject it.
FROM scratch AS bazel-disk-seed

###############################################################################
## Build JAX
###############################################################################

FROM ${BASE_IMAGE} AS builder
ARG TARGETARCH
ARG URLREF_JAX
ARG URLREF_TRANSFORMER_ENGINE
ARG URLREF_XLA
Expand Down Expand Up @@ -54,9 +64,14 @@ RUN ARCH="$(dpkg --print-architecture)" && \
chmod +x /usr/local/bin/bazel
# Populate ${BUILD_PATH_JAXLIB} with editable wheels; --no-install because
# (a) this is the builder stage, and (b) pip-finalize.sh does the install
RUN mkdir -p /builder/extra-targets/{bin,python} && \
RUN --mount=type=cache,id=bazel-disk-${TARGETARCH},target=/cache/bazel-disk,sharing=locked \
--mount=type=cache,id=bazel-repo-${TARGETARCH},target=/cache/bazel-repo,sharing=locked \
--mount=type=bind,from=bazel-disk-seed,source=.,target=/tmp/bazel-disk-seed,readonly \
cp -a /tmp/bazel-disk-seed/. /cache/bazel-disk/ 2>/dev/null || true && \
mkdir -p /builder/extra-targets/{bin,python} && \
build-jax.sh \
--bazel-cache ${BAZEL_CACHE} \
--build-param --bazel_options=--repository_cache=/cache/bazel-repo \
--build-path-jaxlib ${BUILD_PATH_JAXLIB} \
--extra-targets "${EXTRA_BAZEL_TARGETS}" \
--extra-target-dest /builder/extra-targets \
Expand Down Expand Up @@ -148,3 +163,29 @@ RUN install-nsys-jax.sh ${SRC_PATH_NSYS_JAX}

FROM mealkit AS final
RUN pip-finalize.sh

###############################################################################
## Bazel cache export stages (used by CI to persist caches via actions/cache)
###############################################################################

# ARG BUILD_DATE ensures this always re-executes (never a registry cache hit),
# so the snapshot always reflects the current run's cache mount content.
FROM ${BASE_IMAGE} AS bazel-disk-snapshot
ARG TARGETARCH
ARG BUILD_DATE
RUN --mount=type=cache,id=bazel-disk-${TARGETARCH},target=/cache/bazel-disk,sharing=locked,readonly \
mkdir -p /bazel-disk-snapshot && \
cp -rp /cache/bazel-disk/. /bazel-disk-snapshot/

FROM scratch AS bazel-disk-export
COPY --from=bazel-disk-snapshot /bazel-disk-snapshot /

FROM ${BASE_IMAGE} AS bazel-repo-snapshot
ARG TARGETARCH
ARG BUILD_DATE
RUN --mount=type=cache,id=bazel-repo-${TARGETARCH},target=/cache/bazel-repo,sharing=locked,readonly \
mkdir -p /bazel-repo-snapshot && \
cp -rp /cache/bazel-repo/. /bazel-repo-snapshot/

FROM scratch AS bazel-repo-export
COPY --from=bazel-repo-snapshot /bazel-repo-snapshot /
48 changes: 48 additions & 0 deletions .github/eks-workflow-files/jax/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
apiVersion: batch/v1
kind: Job
metadata:
name: PLACEHOLDER
labels:
kueue.x-k8s.io/queue-name: p5-queue
kueue.x-k8s.io/max-exec-time-seconds: "10800"
spec:
template:
spec:
restartPolicy: Never
containers:
- name: jax
image: PLACEHOLDER
command:
- bash
- -c
- |
set -exo pipefail

LOG_DIR="/output/${RUN_ID}"
mkdir -p ${LOG_DIR}

# backend-independent tests
test-jax.sh -b backend-independent 2>&1 | tee ${LOG_DIR}/test-backend-independent.log

# single-gpu tests
nvidia-cuda-mps-control -d
test-jax.sh -b single-gpu 2>&1 | tee ${LOG_DIR}/test-single-gpu.log

# multi-gpu tests
test-jax.sh -b multi-gpu 2>&1 | tee ${LOG_DIR}/test-multi-gpu.log
env:
- name: RUN_ID
value: PLACEHOLDER
resources:
limits:
nvidia.com/gpu: 8
volumeMounts:
- name: s3-storage
mountPath: /output
subPath: jax
imagePullSecrets:
- name: PLACEHOLDER
volumes:
- name: s3-storage
persistentVolumeClaim:
claimName: s3-pvc
69 changes: 69 additions & 0 deletions .github/eks-workflow-files/maxtext/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
apiVersion: batch/v1
kind: Job
metadata:
name: PLACEHOLDER
labels:
kueue.x-k8s.io/queue-name: p5-queue
kueue.x-k8s.io/max-exec-time-seconds: "10800"
spec:
template:
spec:
restartPolicy: Never
containers:
- name: maxtext
image: PLACEHOLDER
command:
- bash
- -c
- |
set -exo pipefail

LOG_DIR="/output/${RUN_ID}"
mkdir -p ${LOG_DIR}

# single-process-multi-device: PP=1, DP=1, FSDP=2, TP=4
test-maxtext.sh \
--output ${LOG_DIR}/1DP2FSDP4TP1PP_single_process \
--dtype bfloat16 \
--mem-fraction 0.65 \
--decoder-block default \
--attn-type dot_product \
--batch-per-gpu 2 \
--steps 10 \
--pipeline-parallel 1 \
--data-parallel 1 \
--fsdp 2 \
--tensor-parallel 4 \
--nodes 1

# multi-process: PP=1, DP=2, FSDP=2, TP=2
test-maxtext.sh \
--output ${LOG_DIR}/2DP2FSDP2TP1PP \
--dtype bfloat16 \
--mem-fraction 0.65 \
--decoder-block default \
--attn-type dot_product \
--batch-per-gpu 2 \
--steps 10 \
--pipeline-parallel 1 \
--data-parallel 2 \
--fsdp 2 \
--tensor-parallel 2 \
--nodes 1 \
--multiprocess
env:
- name: RUN_ID
value: PLACEHOLDER
resources:
limits:
nvidia.com/gpu: 8
volumeMounts:
- name: s3-storage
mountPath: /output
subPath: maxtext
imagePullSecrets:
- name: PLACEHOLDER
volumes:
- name: s3-storage
persistentVolumeClaim:
claimName: s3-pvc
48 changes: 48 additions & 0 deletions .github/eks-workflow-files/nsys-jax/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
apiVersion: batch/v1
kind: Job
metadata:
name: PLACEHOLDER
labels:
kueue.x-k8s.io/queue-name: p5-queue
kueue.x-k8s.io/max-exec-time-seconds: "10800"
spec:
template:
spec:
restartPolicy: Never
containers:
- name: nsys-jax
image: PLACEHOLDER
command:
- bash
- -c
- |
set -exo pipefail

LOG_DIR="/output/${RUN_ID}"
mkdir -p ${LOG_DIR}

# nsys-jax is already installed, this is just adding the test dependencies
pip install pytest-reportlog nsys-jax[test]
# abuse knowledge that nsys-jax is installed editable, so the tests exist
test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())')
pytest \
--basetemp=${LOG_DIR}/pytest-tmp \
--report-log=${LOG_DIR}/pytest-report.jsonl \
"${test_path}" \
2>&1 | tee ${LOG_DIR}/test-nsys-jax.log
env:
- name: RUN_ID
value: PLACEHOLDER
resources:
limits:
nvidia.com/gpu: 8
volumeMounts:
- name: s3-storage
mountPath: /output
subPath: nsys-jax
imagePullSecrets:
- name: PLACEHOLDER
volumes:
- name: s3-storage
persistentVolumeClaim:
claimName: s3-pvc
Loading
Loading