diff --git a/.github/actions/build-container/action.yml b/.github/actions/build-container/action.yml index f53e14bea..b945ca10f 100644 --- a/.github/actions/build-container/action.yml +++ b/.github/actions/build-container/action.yml @@ -54,6 +54,14 @@ inputs: description: "URL of the Bazel remote cache to use for building the image" required: true default: "" + ENABLE_BAZEL_DISK_CACHE: + description: "Enable Bazel disk cache via actions/cache" + required: false + default: "false" + ENABLE_BAZEL_REPO_CACHE: + description: "Enable Bazel repository cache via actions/cache" + required: false + default: "false" outputs: DOCKER_TAG_MEALKIT: @@ -72,6 +80,15 @@ runs: run: | echo 'UPLD_IMAGE=ghcr.io/nvidia/jax-toolbox-internal' >> $GITHUB_ENV echo "BADGE_FILENAME_FULL=${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json" >> $GITHUB_ENV + # Cap Docker client API version to match the daemon on NVKS runners + echo 'DOCKER_API_VERSION=1.43' >> $GITHUB_ENV + # When disk cache is enabled use the BuildKit cache mount path; + # otherwise fall back to the remote cache URL (internal infra runners). 
+ if [[ "${{ inputs.ENABLE_BAZEL_DISK_CACHE }}" == "true" ]]; then + echo 'BAZEL_CACHE_ARG=BAZEL_CACHE=/cache/bazel-disk' >> $GITHUB_ENV + else + echo 'BAZEL_CACHE_ARG=BAZEL_CACHE=${{ inputs.bazel-remote-cache-url }}' >> $GITHUB_ENV + fi - name: Setup SSH id: setup-ssh @@ -91,8 +108,7 @@ runs: uses: docker/setup-buildx-action@v3 with: driver-opts: | - image=moby/buildkit:v0.12.1 - version: v0.30.1 + image=moby/buildkit:v0.19.0 - name: Download nsys-jax version.py uses: actions/download-artifact@v4 @@ -106,6 +122,38 @@ runs: mv version.py .github/container/nsys_jax/nsys_jax/ cat .github/container/nsys_jax/nsys_jax/version.py + # BAZEL CACHE RESTORE + - name: Restore Bazel disk cache + if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true' + uses: actions/cache/restore@v4 + with: + path: /tmp/bazel-disk.tar + key: bazel-disk-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }} + restore-keys: | + bazel-disk-cache-${{ inputs.ARCHITECTURE }}- + + - name: Restore Bazel repo cache + if: inputs.ENABLE_BAZEL_REPO_CACHE == 'true' + uses: actions/cache/restore@v4 + with: + path: /tmp/bazel-repo.tar + key: bazel-repo-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }} + restore-keys: | + bazel-repo-cache-${{ inputs.ARCHITECTURE }}- + + # Extract restored tars into seed dirs; create empty dirs on first run + - name: Prepare Bazel cache seed directories + shell: bash + run: | + mkdir -p /tmp/bazel-disk-cache + if [[ -f /tmp/bazel-disk.tar ]]; then + tar -xf /tmp/bazel-disk.tar -C /tmp/bazel-disk-cache + fi + mkdir -p /tmp/bazel-repo-cache + if [[ -f /tmp/bazel-repo.tar ]]; then + tar -xf /tmp/bazel-repo.tar -C /tmp/bazel-repo-cache + fi + # MEALKIT BUILD - name: Set docker metadata - mealkit id: mealkit-metadata @@ -134,9 +182,11 @@ runs: ssh: default secret-files: | "SSH_KNOWN_HOSTS=${{ steps.setup-ssh.outputs.known-hosts-file }}" + build-contexts: | + bazel-disk-seed=/tmp/bazel-disk-cache build-args: | BASE_IMAGE=${{ inputs.BASE_IMAGE }} - BAZEL_CACHE=${{ 
inputs.bazel-remote-cache-url }} + ${{ env.BAZEL_CACHE_ARG }} BUILD_DATE=${{ inputs.BUILD_DATE }} ${{ inputs.EXTRA_BUILD_ARGS }} # FINAL IMAGE BUILD @@ -169,10 +219,68 @@ runs: "SSH_KNOWN_HOSTS=${{ steps.setup-ssh.outputs.known-hosts-file }}" + # Same seed context as the mealkit build so the builder stage sees identical mount inputs and its layer can be reused + build-contexts: | + bazel-disk-seed=/tmp/bazel-disk-cache build-args: | BASE_IMAGE=${{ inputs.BASE_IMAGE }} - BAZEL_CACHE=${{ inputs.bazel-remote-cache-url }} + ${{ env.BAZEL_CACHE_ARG }} BUILD_DATE=${{ inputs.BUILD_DATE }} ${{ inputs.EXTRA_BUILD_ARGS }} + # BAZEL CACHE EXPORT + # Snapshots are captured first; prune runs after to free space before upload. + # type=tar streams a single archive instead of per-file copies + - name: Export Bazel disk cache + if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true' + uses: docker/build-push-action@v5 + with: + context: ${{ inputs.DOCKER_CONTEXT }} + push: false + file: ${{ inputs.DOCKERFILE }} + platforms: linux/${{ inputs.ARCHITECTURE }} + target: bazel-disk-export + outputs: type=tar,dest=/tmp/bazel-disk.tar + build-contexts: | + bazel-disk-seed=/tmp/bazel-disk-cache + build-args: | + BASE_IMAGE=${{ inputs.BASE_IMAGE }} + BUILD_DATE=${{ inputs.BUILD_DATE }} + ${{ inputs.EXTRA_BUILD_ARGS }} + + - name: Export Bazel repo cache + if: inputs.ENABLE_BAZEL_REPO_CACHE == 'true' + uses: docker/build-push-action@v5 + with: + context: ${{ inputs.DOCKER_CONTEXT }} + push: false + file: ${{ inputs.DOCKERFILE }} + platforms: linux/${{ inputs.ARCHITECTURE }} + target: bazel-repo-export + outputs: type=tar,dest=/tmp/bazel-repo.tar + build-args: | + BASE_IMAGE=${{ inputs.BASE_IMAGE }} + BUILD_DATE=${{ inputs.BUILD_DATE }} + ${{ inputs.EXTRA_BUILD_ARGS }} + + # Prune layer cache after snapshots are captured to free disk space before upload + - name: Prune BuildKit layer cache before upload + if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true' || inputs.ENABLE_BAZEL_REPO_CACHE == 'true' + shell: bash + run: docker buildx prune --force + + - name: Save Bazel disk cache + if: inputs.ENABLE_BAZEL_DISK_CACHE == 'true' + uses: actions/cache/save@v4 + with: + path: 
/tmp/bazel-disk.tar + key: bazel-disk-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }}-${{ github.run_attempt }} # cache entries are immutable, so each re-run attempt needs its own key + + - name: Save Bazel repo cache + if: inputs.ENABLE_BAZEL_REPO_CACHE == 'true' + uses: actions/cache/save@v4 + with: + path: /tmp/bazel-repo.tar + key: bazel-repo-cache-${{ inputs.ARCHITECTURE }}-${{ github.run_id }}-${{ github.run_attempt }} # NOTE(review): no bazel-repo-seed build context is visible in this change - confirm the restored repo cache is actually consumed + # SITREP GENERATION - name: Generate sitrep if: "!cancelled()" diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index da7c2a29e..f964df995 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -19,14 +19,24 @@ ARG SRC_PATH_TRANSFORMER_ENGINE=/opt/transformer-engine ARG GIT_USER_NAME="JAX Toolbox" ARG GIT_USER_EMAIL=jax@nvidia.com -ARG BAZEL_CACHE=/tmp +ARG BAZEL_CACHE=/cache/bazel-disk ARG BUILD_DATE +############################################################################### +## Bazel disk cache seed (overridden via --build-context on cache hit) +############################################################################### + +# On first run this is empty (FROM scratch). When actions/cache restores a +# previous disk cache to /tmp/bazel-disk-cache on the runner, the caller passes +# --build-context bazel-disk-seed=/tmp/bazel-disk-cache to inject it. 
+FROM scratch AS bazel-disk-seed + ############################################################################### ## Build JAX ############################################################################### FROM ${BASE_IMAGE} AS builder +ARG TARGETARCH ARG URLREF_JAX ARG URLREF_TRANSFORMER_ENGINE ARG URLREF_XLA @@ -54,9 +64,14 @@ RUN ARCH="$(dpkg --print-architecture)" && \ chmod +x /usr/local/bin/bazel # Populate ${BUILD_PATH_JAXLIB} with editable wheels; --no-install because # (a) this is the builder stage, and (b) pip-finalize.sh does the install -RUN mkdir -p /builder/extra-targets/{bin,python} && \ +RUN --mount=type=cache,id=bazel-disk-${TARGETARCH},target=/cache/bazel-disk,sharing=locked \ + --mount=type=cache,id=bazel-repo-${TARGETARCH},target=/cache/bazel-repo,sharing=locked \ + --mount=type=bind,from=bazel-disk-seed,source=.,target=/tmp/bazel-disk-seed,readonly \ + cp -a /tmp/bazel-disk-seed/. /cache/bazel-disk/ 2>/dev/null || true && \ + mkdir -p /builder/extra-targets/{bin,python} && \ build-jax.sh \ --bazel-cache ${BAZEL_CACHE} \ + --build-param --bazel_options=--repository_cache=/cache/bazel-repo \ --build-path-jaxlib ${BUILD_PATH_JAXLIB} \ --extra-targets "${EXTRA_BAZEL_TARGETS}" \ --extra-target-dest /builder/extra-targets \ @@ -148,3 +163,29 @@ RUN install-nsys-jax.sh ${SRC_PATH_NSYS_JAX} FROM mealkit AS final RUN pip-finalize.sh + +############################################################################### +## Bazel cache export stages (used by CI to persist caches via actions/cache) +############################################################################### + +# ARG BUILD_DATE ensures this always re-executes (never a registry cache hit), +# so the snapshot always reflects the current run's cache mount content. 
+FROM ${BASE_IMAGE} AS bazel-disk-snapshot +ARG TARGETARCH +ARG BUILD_DATE +RUN --mount=type=cache,id=bazel-disk-${TARGETARCH},target=/cache/bazel-disk,sharing=locked,readonly \ + mkdir -p /bazel-disk-snapshot && \ + cp -rp /cache/bazel-disk/. /bazel-disk-snapshot/ + +FROM scratch AS bazel-disk-export +COPY --from=bazel-disk-snapshot /bazel-disk-snapshot / + +FROM ${BASE_IMAGE} AS bazel-repo-snapshot +ARG TARGETARCH +ARG BUILD_DATE +RUN --mount=type=cache,id=bazel-repo-${TARGETARCH},target=/cache/bazel-repo,sharing=locked,readonly \ + mkdir -p /bazel-repo-snapshot && \ + cp -rp /cache/bazel-repo/. /bazel-repo-snapshot/ + +FROM scratch AS bazel-repo-export +COPY --from=bazel-repo-snapshot /bazel-repo-snapshot / diff --git a/.github/eks-workflow-files/jax/test.yml b/.github/eks-workflow-files/jax/test.yml new file mode 100644 index 000000000..c6b6c4bbe --- /dev/null +++ b/.github/eks-workflow-files/jax/test.yml @@ -0,0 +1,48 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER + labels: + kueue.x-k8s.io/queue-name: p5-queue + kueue.x-k8s.io/max-exec-time-seconds: "10800" +spec: + template: + spec: + restartPolicy: Never + containers: + - name: jax + image: PLACEHOLDER + command: + - bash + - -c + - | + set -exo pipefail + + LOG_DIR="/output/${RUN_ID}" + mkdir -p ${LOG_DIR} + + # backend-independent tests + test-jax.sh -b backend-independent 2>&1 | tee ${LOG_DIR}/test-backend-independent.log + + # single-gpu tests + nvidia-cuda-mps-control -d + test-jax.sh -b single-gpu 2>&1 | tee ${LOG_DIR}/test-single-gpu.log + + # multi-gpu tests + test-jax.sh -b multi-gpu 2>&1 | tee ${LOG_DIR}/test-multi-gpu.log + env: + - name: RUN_ID + value: PLACEHOLDER + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - name: s3-storage + mountPath: /output + subPath: jax + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: s3-storage + persistentVolumeClaim: + claimName: s3-pvc diff --git a/.github/eks-workflow-files/maxtext/test.yml 
b/.github/eks-workflow-files/maxtext/test.yml new file mode 100644 index 000000000..455cf2f4d --- /dev/null +++ b/.github/eks-workflow-files/maxtext/test.yml @@ -0,0 +1,69 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER + labels: + kueue.x-k8s.io/queue-name: p5-queue + kueue.x-k8s.io/max-exec-time-seconds: "10800" +spec: + template: + spec: + restartPolicy: Never + containers: + - name: maxtext + image: PLACEHOLDER + command: + - bash + - -c + - | + set -exo pipefail + + LOG_DIR="/output/${RUN_ID}" + mkdir -p ${LOG_DIR} + + # single-process-multi-device: PP=1, DP=1, FSDP=2, TP=4 + test-maxtext.sh \ + --output ${LOG_DIR}/1DP2FSDP4TP1PP_single_process \ + --dtype bfloat16 \ + --mem-fraction 0.65 \ + --decoder-block default \ + --attn-type dot_product \ + --batch-per-gpu 2 \ + --steps 10 \ + --pipeline-parallel 1 \ + --data-parallel 1 \ + --fsdp 2 \ + --tensor-parallel 4 \ + --nodes 1 + + # multi-process: PP=1, DP=2, FSDP=2, TP=2 + test-maxtext.sh \ + --output ${LOG_DIR}/2DP2FSDP2TP1PP \ + --dtype bfloat16 \ + --mem-fraction 0.65 \ + --decoder-block default \ + --attn-type dot_product \ + --batch-per-gpu 2 \ + --steps 10 \ + --pipeline-parallel 1 \ + --data-parallel 2 \ + --fsdp 2 \ + --tensor-parallel 2 \ + --nodes 1 \ + --multiprocess + env: + - name: RUN_ID + value: PLACEHOLDER + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - name: s3-storage + mountPath: /output + subPath: maxtext + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: s3-storage + persistentVolumeClaim: + claimName: s3-pvc diff --git a/.github/eks-workflow-files/nsys-jax/test.yml b/.github/eks-workflow-files/nsys-jax/test.yml new file mode 100644 index 000000000..789eca17c --- /dev/null +++ b/.github/eks-workflow-files/nsys-jax/test.yml @@ -0,0 +1,48 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER + labels: + kueue.x-k8s.io/queue-name: p5-queue + kueue.x-k8s.io/max-exec-time-seconds: "10800" +spec: + template: + spec: + restartPolicy: 
Never + containers: + - name: nsys-jax + image: PLACEHOLDER + command: + - bash + - -c + - | + set -exo pipefail + + LOG_DIR="/output/${RUN_ID}" + mkdir -p ${LOG_DIR} + + # nsys-jax is already installed, this is just adding the test dependencies + pip install pytest-reportlog nsys-jax[test] + # abuse knowledge that nsys-jax is installed editable, so the tests exist + test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())') + pytest \ + --basetemp=${LOG_DIR}/pytest-tmp \ + --report-log=${LOG_DIR}/pytest-report.jsonl \ + "${test_path}" \ + 2>&1 | tee ${LOG_DIR}/test-nsys-jax.log + env: + - name: RUN_ID + value: PLACEHOLDER + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - name: s3-storage + mountPath: /output + subPath: nsys-jax + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: s3-storage + persistentVolumeClaim: + claimName: s3-pvc diff --git a/.github/workflows/_build_base.yaml b/.github/workflows/_build_base.yaml index 930ae9cfa..3251d5dd9 100644 --- a/.github/workflows/_build_base.yaml +++ b/.github/workflows/_build_base.yaml @@ -58,7 +58,7 @@ permissions: jobs: build-base: - runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small] + runs-on: ${{ inputs.ARCHITECTURE == 'amd64' && 'linux-amd64-cpu16m' || 'linux-arm64-cpu16m' }} env: BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json outputs: @@ -137,7 +137,7 @@ jobs: BUILD_DATE=${{ inputs.BUILD_DATE }} JAX_TOOLBOX_REF=${{ github.head_ref || github.sha }} ${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }} - + - name: Generate sitrep if: "!cancelled()" shell: bash -x -e {0} diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index ca6bd3c41..251352f68 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -63,7 +63,7 @@ jobs: build-jax: needs: build-base - runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", 
"large"] + runs-on: ${{ inputs.ARCHITECTURE == 'amd64' && 'linux-amd64-cpu32m' || 'linux-arm64-cpu32m' }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -82,7 +82,9 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} github-token: ${{ secrets.GITHUB_TOKEN }} - bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }} + bazel-remote-cache-url: "" + ENABLE_BAZEL_DISK_CACHE: 'true' + ENABLE_BAZEL_REPO_CACHE: 'true' EXTRA_BUILD_ARGS: | URLREF_JAX=${{ fromJson(inputs.SOURCE_URLREFS).JAX }} URLREF_XLA=${{ fromJson(inputs.SOURCE_URLREFS).XLA }} @@ -94,7 +96,7 @@ jobs: build-equinox: needs: build-jax - runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"] + runs-on: ${{ inputs.ARCHITECTURE == 'amd64' && 'linux-amd64-cpu16m' || 'linux-arm64-cpu16m' }} outputs: DOCKER_TAG_MEALKIT: ${{ steps.build-equinox.outputs.DOCKER_TAG_MEALKIT }} DOCKER_TAG_FINAL: ${{ steps.build-equinox.outputs.DOCKER_TAG_FINAL }} @@ -116,13 +118,13 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} github-token: ${{ secrets.GITHUB_TOKEN }} - bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }} + bazel-remote-cache-url: "" EXTRA_BUILD_ARGS: | URLREF_EQUINOX=${{ fromJson(inputs.SOURCE_URLREFS).EQUINOX }} build-maxtext: needs: build-jax - runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"] + runs-on: ${{ inputs.ARCHITECTURE == 'amd64' && 'linux-amd64-cpu16m' || 'linux-arm64-cpu16m' }} outputs: DOCKER_TAG_MEALKIT: ${{ steps.build-maxtext.outputs.DOCKER_TAG_MEALKIT }} DOCKER_TAG_FINAL: ${{ steps.build-maxtext.outputs.DOCKER_TAG_FINAL }} @@ -144,13 +146,13 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} github-token: ${{ secrets.GITHUB_TOKEN }} - bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }} + bazel-remote-cache-url: "" EXTRA_BUILD_ARGS: | URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }} 
build-torchax: needs: build-jax - runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"] + runs-on: ${{ inputs.ARCHITECTURE == 'amd64' && 'linux-amd64-cpu16m' || 'linux-arm64-cpu16m' }} outputs: DOCKER_TAG_MEALKIT: ${{ steps.build-torchax.outputs.DOCKER_TAG_MEALKIT }} DOCKER_TAG_FINAL: ${{ steps.build-torchax.outputs.DOCKER_TAG_FINAL }} @@ -172,13 +174,13 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} github-token: ${{ secrets.GITHUB_TOKEN }} - bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }} + bazel-remote-cache-url: "" EXTRA_BUILD_ARGS: | URLREF_TORCHAX=${{ fromJson(inputs.SOURCE_URLREFS).TORCHAX }} build-axlearn: needs: build-jax - runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "large"] + runs-on: ${{ inputs.ARCHITECTURE == 'amd64' && 'linux-amd64-cpu16m' || 'linux-arm64-cpu16m' }} outputs: DOCKER_TAG_MEALKIT: ${{ steps.build-axlearn.outputs.DOCKER_TAG_MEALKIT }} DOCKER_TAG_FINAL: ${{ steps.build-axlearn.outputs.DOCKER_TAG_FINAL }} @@ -200,7 +202,7 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} github-token: ${{ secrets.GITHUB_TOKEN }} - bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }} + bazel-remote-cache-url: "" EXTRA_BUILD_ARGS: | URLREF_AXLEARN=${{ fromJson(inputs.SOURCE_URLREFS).AXLEARN }} @@ -238,7 +240,7 @@ jobs: echo "TAGS=${TAGS}" >> $GITHUB_OUTPUT - test-jax: + test-jax-eks: needs: build-jax if: >- inputs.ARCHITECTURE == 'amd64' && @@ -246,119 +248,102 @@ jobs: inputs.MODE == 'full' || inputs.MODE == 'jax' ) - uses: ./.github/workflows/_test_unit.yaml - with: - TEST_NAME: jax - EXECUTE: | - docker run -i --shm-size=1g --gpus all \ - ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ - bash <<"EOF" |& tee test-backend-independent.log - test-jax.sh -b backend-independent - EOF - docker run -i --shm-size=1g --gpus all \ - ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ - bash <<"EOF" |& tee test-single-gpu.log - 
nvidia-cuda-mps-control -d - test-jax.sh -b single-gpu - EOF - docker run -i --shm-size=1g --gpus all \ - ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ - bash <<"EOF" |& tee test-multi-gpu.log - nvidia-cuda-mps-control -d - test-jax.sh -b multi-gpu - EOF - STATISTICS_SCRIPT: | - errors=$(cat test-*.log | grep -c 'ERROR:' || true) - failed_tests=$(cat test-*.log | grep -c 'FAILED in' || true) - passed_tests=$(cat test-*.log | grep -c 'PASSED in' || true) + runs-on: eks + env: + JAX_DOCKER_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} + JOB_NAME: jax-${{ github.run_id }} + steps: + - name: Check out the repository + uses: actions/checkout@v6 + - name: Login to GitHub Container Registry + uses: docker/login-action@v4 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: K8s GHCR store and delete token + id: store-token + uses: ./.github/actions/store-delete-k8s-ghcr + - name: Configure JAX test job + run: | + yq -i ea ' + select(di == 0).metadata.name = strenv(JOB_NAME) + | select(di == 0).spec.template.spec.containers[0].image = strenv(JAX_DOCKER_IMAGE) + | select(di == 0).spec.template.spec.containers[0].env[0].value = "${{ github.run_id }}" + | select(di == 0).spec.template.spec.imagePullSecrets[0].name = "${{ steps.store-token.outputs.token-name }}"' \ + .github/eks-workflow-files/jax/test.yml + git diff .github/eks-workflow-files/jax/test.yml + - name: Submit & delete JAX unit test job + uses: ./.github/actions/submit-delete-k8s-job + with: + job-config-file: ".github/eks-workflow-files/jax/test.yml" + job-name: ${{ env.JOB_NAME }} + - name: Download logs from S3 + id: log-s3 + if: ${{ !cancelled() }} + run: | + mkdir -p jax-output + aws s3 cp s3://jax-toolbox-eks-output/jax/${{ github.run_id }}/ jax-output/ --recursive + + errors=$(cat jax-output/test-*.log | grep -c 'ERROR:' || true) + failed_tests=$(cat jax-output/test-*.log | grep -c 'FAILED in' || true) + passed_tests=$(cat 
jax-output/test-*.log | grep -c 'PASSED in' || true) total_tests=$((failed_tests + passed_tests)) - echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT - echo "ERRORS=${errors}" >> $GITHUB_OUTPUT - echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT - echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT - ARTIFACTS: | - test-backend-independent.log - test-multi-gpu.log - test-single-gpu.log - secrets: inherit - test-nsys-jax: - needs: build-jax - if: >- - inputs.ARCHITECTURE == 'amd64' && - ( - inputs.MODE == 'full' || - inputs.MODE == 'jax' - ) - uses: ./.github/workflows/_test_unit.yaml - with: - TEST_NAME: nsys-jax - EXECUTE: | - set -o pipefail - mkdir -p output-results - docker run -i --shm-size=1g --gpus all \ - -v $PWD/output-results:/opt/output \ - ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ - bash <<"EOF" |& tee test-nsys-jax.log - # nsys-jax is already installed, this is just adding the test dependencies - pip install pytest-reportlog nsys-jax[test] - # abuse knowledge that nsys-jax is installed editable, so the tests exist - test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())') - pytest --basetemp=/opt/output/pytest-tmp --report-log=/opt/output/pytest-report.jsonl "${test_path}" - chmod -R a+rwX /opt/output - EOF - STATISTICS_SCRIPT: | - summary_line=$(tail -n1 test-nsys-jax.log) - num_errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}') - passed_tests=$(cat output-results/pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l) - failed_tests=$(cat output-results/pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l) - total_tests=$(( passed_tests + failed_tests )) - echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT - echo "ERRORS=${num_errors}" >> $GITHUB_OUTPUT - 
echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT - echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT - ARTIFACTS: | - # pytest-driven part - test-nsys-jax.log - output-results/pytest-report.jsonl - output-results/pytest-tmp/ - secrets: inherit + echo "Passed tests: $passed_tests" + echo "Failed tests: $failed_tests" + echo "Total tests: $total_tests" + echo "PASSED_TESTS=$passed_tests" >> $GITHUB_OUTPUT + echo "FAILED_TESTS=$failed_tests" >> $GITHUB_OUTPUT + echo "TOTAL_TESTS=$total_tests" >> $GITHUB_OUTPUT + echo "ERRORS=$errors" >> $GITHUB_OUTPUT - # test-nsys-jax generates several fresh .zip archive outputs by running nsys-jax with real GPU hardware; this test - # runs on a regular GitHub Actions runner and checks that offline post-processing works in an environment that does - # not already have nsys-jax installed - test-nsys-jax-archive: - needs: test-nsys-jax - if: >- - inputs.ARCHITECTURE == 'amd64' && - ( - inputs.MODE == 'full' || - inputs.MODE == 'jax' - ) - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-24.04, macOS-latest] - runs-on: ${{ matrix.os }} - steps: - - name: Download nsys-jax output .zip files - uses: actions/download-artifact@v4 - with: - name: nsys-jax-unit-test-A100 - - name: Extract archives and execute install scripts + if [[ $failed_tests -gt 0 ]] || [[ $errors -gt 0 ]]; then + exit 1 + fi + - name: Generate sitrep + id: sitrep + if: ${{ !cancelled() }} + shell: bash -x -e {0} run: | - pip install virtualenv # for install.sh - for zip in $(ls *.zip); do - ZIP="${PWD}/${zip}" - pushd $(mktemp -d) - unzip "${ZIP}" - ls -l - # TODO: verify this isn't needed, or make sure it isn't needed - chmod 755 install.sh - # Run the notebook with IPython, not Jupyter Lab, so it exits and prints something informative to stdout - # Skip executing Jupyter lab - NSYS_JAX_JUPYTER_EXECUTE_NOT_LAB=1 ./install.sh - popd - done + # bring in utility functions + source .github/workflows/scripts/to_json.sh + + badge_label='JAX EKS unittest (8)' + + 
total_tests=${{ steps.log-s3.outputs.TOTAL_TESTS || '0' }} \ + failed_tests=${{ steps.log-s3.outputs.FAILED_TESTS || '0' }} \ + passed_tests=${{ steps.log-s3.outputs.PASSED_TESTS || '0' }} \ + errors=${{ steps.log-s3.outputs.ERRORS || '0' }} \ + summary="All tests: $total_tests. Passed: $passed_tests. Failed: $failed_tests." \ + badge_message="Passed $passed_tests out of $total_tests." \ + badge_color="brightgreen" + if [ "$failed_tests" -gt 0 ] || [ "$errors" -gt 0 ] || [ "$total_tests" -eq 0 ]; then # no results at all (e.g. log download failed) must not show green + badge_color="red" + fi \ + + to_json \ + summary \ + errors total_tests passed_tests failed_tests \ + badge_label badge_color badge_message \ + > sitrep.json + + schemaVersion=1 \ + label="${badge_label}" \ + message="Passed $passed_tests out of $total_tests." \ + color=$badge_color \ + to_json schemaVersion label message color \ + > badge-jax-unit-test-eks.json + + - name: Upload artifacts + if: ${{ !cancelled() }} + uses: actions/upload-artifact@v4 + with: + name: "jax-unit-test-H100-eks" + path: | + sitrep.json + badge-jax-unit-test-eks.json + jax-output/* test-nsys-jax-eks: needs: build-jax @@ -451,84 +436,105 @@ jobs: CI_NAME: jax-cutlass secrets: inherit - test-te-a100: - needs: build-jax - secrets: inherit + test-maxtext-eks: + needs: build-maxtext if: >- inputs.ARCHITECTURE == 'amd64' && ( inputs.MODE == 'full' || - inputs.MODE == 'te' + inputs.MODE == 'maxtext' ) - uses: ./.github/workflows/_test_unit.yaml - with: - TEST_NAME: te - EXECUTE: | - docker run -i --gpus all --shm-size=1g -v $PWD:/log \ - ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ - bash <<"EOF" |& tee test-te.log - set -xu -o pipefail - - LOG_DIR=/log - - pip install pytest-reportlog pytest-xdist - # Start MPS daemon - nvidia-cuda-mps-control -d - # TE's default is slightly different, without the hyphen - export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE} - # 1 GPU per worker, 3 workers per GPU - pytest-xdist.sh 1 3 ${LOG_DIR}/pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh - ## 8 GPUs per worker, 1 worker per GPU. 
pytest-xdist.sh allows aggregation - ## into a single .jsonl file of results from multiple pytest invocations - ## inside the test.sh script, so it's useful even with a single worker per - ## device. - pytest-xdist.sh 8 1 ${LOG_DIR}/pytest-report-L0-distributed-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_distributed_unittest/test.sh - - # merge the log files - cat \ - ${LOG_DIR}/pytest-report-L0-unittest.jsonl \ - ${LOG_DIR}/pytest-report-L0-distributed-unittest.jsonl \ - > ${LOG_DIR}/pytest-report.jsonl - - EOF - STATISTICS_SCRIPT: | - report_json=pytest-report.jsonl - summary_line=$(tail -n1 test-te.log) - errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}') - passed_tests=$(cat $report_json | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l) - failed_tests=$(cat $report_json | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l) - total_tests=$((failed_tests + passed_tests)) - echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT - echo "ERRORS=${errors}" >> $GITHUB_OUTPUT - echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT - echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT - - echo "$failed_tests tests failed" - if [[ $failed_tests -gt 0 ]]; then - exit 1 - else - exit 0 + runs-on: [eks] + env: + MAXTEXT_DOCKER_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }} + JOB_NAME: maxtext-${{ github.run_id }} + steps: + - name: Check out the repository + uses: actions/checkout@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: K8s GHCR store and delete token + id: store-token + uses: ./.github/actions/store-delete-k8s-ghcr + - name: Configure maxtext test job + run: | + yq -i ea ' + select(di == 0).metadata.name = strenv(JOB_NAME) + | select(di 
== 0).spec.template.spec.containers[0].image = strenv(MAXTEXT_DOCKER_IMAGE) + | select(di == 0).spec.template.spec.containers[0].env[0].value = "${{ github.run_id }}" + | select(di == 0).spec.template.spec.imagePullSecrets[0].name = "${{ steps.store-token.outputs.token-name }}"' \ + .github/eks-workflow-files/maxtext/test.yml + git diff .github/eks-workflow-files/maxtext/test.yml + - name: Submit & delete maxtext test job + uses: ./.github/actions/submit-delete-k8s-job + with: + job-config-file: ".github/eks-workflow-files/maxtext/test.yml" + job-name: ${{ env.JOB_NAME }} + - name: Download results from S3 + id: s3-download + if: ${{ !cancelled() }} + run: | + mkdir -p maxtext-output + aws s3 cp s3://jax-toolbox-eks-output/maxtext/${{ github.run_id }}/ maxtext-output/ --recursive + - name: Run metrics + id: metrics + if: ${{ !cancelled() }} + run: | + pip install 'numpy<2.0.0' pytest pytest-reportlog tensorboard + RESULTS_DIR=maxtext-output BASELINES_DIR=MAXTEXT/upstream \ + pytest --report-log=report.jsonl .github/workflows/baselines/test_maxtext_metrics.py || true + - name: Generate sitrep + id: sitrep + if: ${{ !cancelled() }} + shell: bash -x -e {0} + run: | + # bring in utility functions + source .github/workflows/scripts/to_json.sh + + badge_label='MaxText EKS' + + passed_tests=$(cat report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l) + failed_tests=$(cat report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l) + total_tests=$(( passed_tests + failed_tests )) + errors=0 + badge_color="brightgreen" + if [ "$failed_tests" -gt 0 ]; then + badge_color="red" fi - TIMEOUT_MINUTES: 120 - ARTIFACTS: | - test-te.log - pytest-report.jsonl - pytest-report-L0-unittest.jsonl - pytest-report-L0-distributed-unittest.jsonl + total_tests=$total_tests \ + failed_tests=$failed_tests \ + passed_tests=$passed_tests \ + errors=$errors \ 
+ summary="All metrics tests: $total_tests. Passed: $passed_tests. Failed: $failed_tests." \ + badge_message="Passed $passed_tests out of $total_tests." \ + badge_color=$badge_color \ + to_json \ + summary errors total_tests passed_tests failed_tests \ + badge_label badge_color badge_message \ + > sitrep.json - test-maxtext: - needs: build-maxtext - if: >- - inputs.ARCHITECTURE == 'amd64' && - ( - inputs.MODE == 'full' || - inputs.MODE == 'maxtext' - ) - uses: ./.github/workflows/_test_maxtext.yaml - with: - MAXTEXT_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }} - secrets: inherit + schemaVersion=1 \ + label="${badge_label}" \ + message="Passed $passed_tests out of $total_tests." \ + color=$badge_color \ + to_json schemaVersion label message color \ + > badge-maxtext-test-eks.json + + - name: Upload artifacts + if: ${{ !cancelled() }} + uses: actions/upload-artifact@v4 + with: + name: "maxtext-test-H100-eks" + path: | + sitrep.json + badge-maxtext-test-eks.json + maxtext-output/ + report.jsonl test-maxtext-gke: needs: build-maxtext diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d711da63d..f199635e8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,12 +4,10 @@ on: schedule: - cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC - cron: '0 0 * * 6' #midnight every Saturday UTC for scale-training - pull_request: - types: - - opened - - reopened - - ready_for_review - - synchronize + push: + # we need this to allow nv-gha-runners to run + branches: + - "**" paths-ignore: - '**.md' - '.github/triage/**' diff --git a/simplefile b/simplefile new file mode 100644 index 000000000..e69de29bb diff --git a/simplefile2trigger b/simplefile2trigger new file mode 100644 index 000000000..e69de29bb