
[PG][TENT] Fix first-collective hangs on NVLink/MNNVL bootstrap #1755

Merged
yuechen-sys merged 6 commits into kvcache-ai:main from KMSorSMS:fix_nvlink
Mar 31, 2026

Conversation

@KMSorSMS (Contributor) commented Mar 26, 2026

Description

This PR fixes several bootstrap and first-collective hangs when Mooncake PG runs on TENT, especially on NVLink / MNNVL setups.

The issue comes from a combination of initialization-time problems:

  1. The first reduce kernel can be affected by CUDA lazy loading
    During the first collective, the first reduce kernel may be launched only after a resident/spin kernel is already occupying the stream. In this case, the first reduce kernel effectively becomes a tail kernel.

    Based on our investigation, this is closely related to CUDA lazy loading. NVIDIA documents that lazy loading can affect concurrent kernel execution, and explicitly recommends preloading kernels that are expected to run concurrently. NVIDIA also documents that cudaFuncGetAttributes() can be used to force a kernel to load eagerly at runtime.

    Relevant CUDA documentation: the "Lazy Loading" section of the CUDA C++ Programming Guide, and the cudaFuncGetAttributes() entry in the CUDA Runtime API reference.

    We observed a particularly surprising symptom here: when the first tail kernel has not been preloaded yet, a cudaMemcpyAsync submitted on another stream can remain blocked until the resident spin kernel is released. We did not find this exact cross-stream memcpy symptom described explicitly in the CUDA documentation, but a minimal spin-kernel demo (see "How Has This Been Tested?" below) reproduces it reliably.

    The demo shows:

    • stream A launches a resident spin kernel, then queues a normal kernel behind it
    • stream B issues a cudaMemcpyAsync
    • without preloading the normal kernel, the copy can remain blocked
    • after preloading the normal kernel once, the copy completes normally

    Mooncake PG can hit the same pattern because the reduce kernel may be the first real CUDA kernel touched by the process.

  2. Host-side control/warmup/sync buffers were registered with a GPU-specific location
    CPU sync regions, warmup regions, and P2P control regions were previously registered using the local GPU location. Under TENT NVLink, this can steer transport selection toward a GPU/NVLink path even though these buffers are plain host memory.

    On MNNVL/fabric setups, such host buffers are not valid NVLink fabric targets. This can break warmup/control traffic and lead to bootstrap hangs.

    This is also related to the earlier location changes in:

    PR [PG] force register local memory for P2P memory regions #1690 adjusted the location in a way that helped the TE path work around the issue, but it did not reflect the actual memory type of these host-side regions. PR [TE] Fix simultaneous open handshake in RdmaEndpoint #1733 later addressed the real TE-side issue. After that fix, the PG-side location for these host control/warmup/sync buffers also needs to be corrected accordingly.

    In other words, these buffers should use a host-compatible registration (kWildcardLocation) rather than inherit the GPU location. This is required so that both TENT and TE choose the correct transport behavior.

  3. RPC port conflicts could be reported as successful startup
    TENT selects an RPC port when starting the local RPC server. If two ranks choose the same port, startup should fail and retry. Previously, async_start() errors were not checked, so a failed bind could still be treated as success. That can publish duplicated server_name / segment names into the store and corrupt later peer connection logic.

  4. Inactive ranks could leave invalid task indices in worker threads
    Some ranks may be skipped because they are inactive or excluded by collective semantics. However, their task indices were not explicitly initialized to an invalid value, and later status polling could still touch those entries.

Solution

This PR applies four fixes:

  1. Preload reduce kernels during backend initialization

    • Add preloadReduceKernels()
    • Call it for CUDA backends before the first real collective
    • Force the reduce kernels to load eagerly via cudaFuncGetAttributes()
    • Avoid the first-launch tail-kernel stall described above
  2. Register host control/sync/warmup memory with wildcard location
    The following host-side regions now use kWildcardLocation instead of a GPU-specific location:

    • CPU sync send/recv regions
    • ConnectionContext warmup send/recv regions
    • P2PProxy control send/recv regions

    These buffers are host-side metadata/control memory and should not inherit the GPU location. This aligns PG with the TE-side fix in PR [TE] Fix simultaneous open handshake in RdmaEndpoint #1733 and makes both TENT and TE choose the correct transport behavior.

  3. Fail fast on RPC port conflicts

    • Check the async_start() result in TENT RPC startup
    • Retry port selection instead of falsely reporting success
    • Avoid duplicated server_name publication
  4. Skip inactive ranks safely in worker-thread polling

    • Initialize skipped ranks with an explicit invalid task id
    • Ignore those entries during transfer-status and sync-status polling
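
    The guard in fix 4 can be sketched as follows. This is a minimal, hypothetical illustration: the sentinel name mirrors the kInvalidTaskId idea from this PR, but the function and array shape are made up, not the actual MooncakeWorker code.

    ```cpp
    #include <cstdint>
    #include <vector>

    // Hypothetical sentinel marking ranks skipped at submit time
    // (inactive, or excluded by the collective's semantics).
    constexpr int64_t kInvalidTaskId = -1;

    // Worker-thread polling only visits entries holding a real task id,
    // so skipped ranks never have a stale index dereferenced.
    inline int countPollableTasks(const std::vector<int64_t>& task_ids) {
        int pollable = 0;
        for (int64_t id : task_ids) {
            if (id == kInvalidTaskId) continue;  // inactive rank: nothing to poll
            ++pollable;
        }
        return pollable;
    }
    ```

    The key point is that skipped slots are explicitly initialized to the sentinel rather than left holding whatever index was there before.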

Why This Fix Works

The fragile part of the old path was the very first bootstrap / collective window:

  • the first reduce kernel could be delayed by lazy loading and enter as a tail kernel
  • this could in turn stall other GPU work, including the observed cudaMemcpyAsync path
  • host metadata buffers could be misclassified as GPU-local memory
  • RPC startup could silently publish duplicated addresses
  • inactive ranks could leave invalid worker-thread state behind

This PR removes those failure points, making PG bootstrap and the first collective much more stable on TENT/NVLink.
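
The eager preload in fix 1 boils down to touching each reduce kernel's attributes once before the first launch. A rough sketch, assuming hypothetical kernel names and dtype list (the real Mooncake kernels and instantiations differ):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative stand-in for one of the templated reduce kernels.
template <typename T>
__global__ void reduceKernel(T* buf, size_t n) { /* reduction body elided */ }

// cudaFuncGetAttributes() forces the kernel's module to load eagerly,
// so the first real launch is not delayed by CUDA lazy loading.
template <typename T>
static void preloadOne() {
    cudaFuncAttributes attr{};
    cudaError_t err = cudaFuncGetAttributes(&attr, reduceKernel<T>);
    if (err != cudaSuccess)
        fprintf(stderr, "preload failed: %s\n", cudaGetErrorString(err));
}

// Called once during backend initialization, before the first collective.
void preloadReduceKernels() {
    preloadOne<float>();
    preloadOne<double>();
    preloadOne<int>();
}
```

This matches NVIDIA's documented recommendation to pre-trigger module loading for kernels that must run concurrently with already-resident work.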
cc @caozhanhao @UNIDY2002 @yuechen-sys

Module

  • Transfer Engine (mooncake-transfer-engine)
  • Mooncake Store (mooncake-store)
  • Mooncake EP (mooncake-ep)
  • Integration (mooncake-integration)
  • P2P Store (mooncake-p2p-store)
  • Python Wheel (mooncake-wheel)
  • PyTorch Backend (mooncake-pg)
  • Mooncake RL (mooncake-rl)
  • CI/CD
  • Docs
  • Other

Type of Change

  • Bug fix
  • New feature
  • Refactor
  • Breaking change
  • Documentation update
  • Other

How Has This Been Tested?

Tested with mooncake-wheel/tests/test_mooncake_backend.py.

The CUDA symptom can also be illustrated with a standalone spin-kernel demo.

Expected behavior from the demo:

  • without preloading the tail kernel, the copy may remain blocked until the spin kernel is released
  • after preloading the tail kernel once, the copy completes without waiting for the spin kernel to finish
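
A minimal sketch of such a demo follows. The original demo file is not reproduced here; this is an untested illustration with made-up names, using a host-mapped flag so the spinner can be released without issuing further CUDA work:

```cuda
#include <cuda_runtime.h>
#include <cstdio>
#include <unistd.h>

// Resident kernel: spins on stream A until the host sets *flag.
__global__ void spinKernel(volatile int* flag) {
    while (*flag == 0) { /* busy-wait */ }
}

// Ordinary kernel queued behind the spinner; becomes a "tail" kernel.
__global__ void tailKernel(float* p) { p[0] += 1.0f; }

int main() {
    cudaStream_t a, b;
    cudaStreamCreate(&a);
    cudaStreamCreate(&b);

    // Host-mapped release flag: the host writes it directly, so releasing
    // the spinner does not depend on any further stream operations.
    int* flag;
    cudaHostAlloc(&flag, sizeof(int), cudaHostAllocMapped);
    *flag = 0;

    float* dbuf;
    cudaMalloc(&dbuf, sizeof(float));
    float hbuf = 0.f;

    // Uncomment to preload tailKernel and avoid the stall:
    // cudaFuncAttributes attr; cudaFuncGetAttributes(&attr, tailKernel);

    spinKernel<<<1, 1, 0, a>>>(flag);
    tailKernel<<<1, 1, 0, a>>>(dbuf);  // tail kernel behind the resident spinner

    // Without the preload, this copy on stream B can stay blocked until the
    // spin kernel is released (the surprising cross-stream symptom).
    cudaMemcpyAsync(&hbuf, dbuf, sizeof(float), cudaMemcpyDeviceToHost, b);
    sleep(1);  // give stream B time to make progress
    printf("copy %s while spinner resident\n",
           cudaStreamQuery(b) == cudaSuccess ? "completed" : "blocked");

    *flag = 1;  // release the spinner
    cudaDeviceSynchronize();
    cudaFree(dbuf);
    cudaFreeHost(flag);
    return 0;
}
```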

Checklist

  • I have performed a self-review of my own code.
  • I have formatted my own code using ./scripts/code_format.sh before submitting.
  • I have updated the documentation.
  • I have added tests to prove my changes are effective.

@gemini-code-assist (Contributor) commented

Summary of Changes


This pull request addresses several critical issues that caused bootstrap and first-collective hangs in Mooncake PG, particularly on TENT with NVLink/MNNVL configurations. The changes enhance the stability and reliability of the system by proactively loading CUDA kernels, correctly classifying host memory regions for transport, improving RPC server initialization, and robustly handling inactive ranks. These fixes collectively eliminate various failure points, leading to a much more stable startup and collective operation.

Highlights

  • Kernel Preloading: Preloaded CUDA reduce kernels during backend initialization to prevent first-collective hangs caused by lazy loading.
  • Memory Location Correction: Registered host-side control, sync, and warmup memory regions with a wildcard location (kWildcardLocation) instead of a GPU-specific location, ensuring correct transport selection on NVLink/MNNVL setups.
  • RPC Port Conflict Handling: Implemented robust error checking for RPC server startup to detect and retry on port conflicts, preventing silent failures and duplicated server name publication.
  • Inactive Rank Management: Initialized task indices for inactive ranks to an explicit invalid value and added checks to safely skip these entries during worker-thread polling, avoiding invalid state access.


@gemini-code-assist bot left a comment

Code Review

This pull request refactors memory registration in ConnectionContext and P2PProxy to consistently use kWildcardLocation for host-side memory and initializes newly allocated arrays to zero. It introduces a preloadReduceKernels mechanism to address CUDA lazy loading for various data types. Additionally, the MooncakeWorker is enhanced with kInvalidTaskId checks for improved task management, and the RPC server now includes robust error handling and retry logic for async_start() to prevent silent failures from port conflicts. The std::string location parameter was removed from the ConnectionContext constructor. I have no feedback to provide as all review comments were filtered out.

Comment on lines +72 to +79:

```cpp
const auto err = server_->get_errc();
if (err) {
    LOG(WARNING) << "Failed to start RPC server(async_start) on port " << port
                 << ": " << err.message();
    delete server_;
    server_ = nullptr;
    port = 0;
    continue;
}
```
Severity: critical

The added error checking for server_->async_start() is a critical improvement. Previously, RPC port conflicts could lead to silent failures and incorrect reporting of successful startup. This change ensures that such errors are detected, logged, and handled by retrying port selection, directly addressing solution point 3 and significantly enhancing the robustness of the RPC server initialization.


```diff
 int rc = engine_->registerLocalMemory(resources_.ctrl_send_region_,
                                       kMaxNumRanks * sizeof(P2PControlSlot),
-                                      location_);
+                                      kWildcardLocation);
```
Severity: high

Registering resources_.ctrl_send_region_ with kWildcardLocation is a crucial change. This correctly identifies the P2P control regions as host-side memory, preventing misclassification and ensuring proper transport behavior on NVLink/MNNVL setups, as detailed in solution point 2.

Suggested change:

```diff
-                                      kWildcardLocation);
+                                      kMaxNumRanks * sizeof(P2PControlSlot),
+                                      kWildcardLocation);
```

```diff
 rc = engine_->registerLocalMemory(resources_.ctrl_recv_region_,
                                   kMaxNumRanks * sizeof(P2PControlSlot),
-                                  location_);
+                                  kWildcardLocation);
```
Severity: high

Similarly, registering resources_.ctrl_recv_region_ with kWildcardLocation ensures that this P2P control region is also correctly identified as host-side memory. This consistency is vital for avoiding bootstrap hangs and ensuring reliable P2P communication.

Suggested change:

```diff
-                                  kWildcardLocation);
+                                  kMaxNumRanks * sizeof(P2PControlSlot),
+                                  kWildcardLocation);
```

@codecov-commenter commented

⚠️ Please install the Codecov app to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

✅ All modified and coverable lines are covered by tests.


@alogfans (Collaborator) left a comment

LGTM

@UNIDY2002 (Collaborator) commented

Do we need to merge #1733 first?

@KMSorSMS (Contributor, Author) commented

> Do we need to merge #1733 first?

Yes, for the TE part to work correctly as TENT, we need to merge #1733 first.

@yuechen-sys yuechen-sys merged commit b9a593c into kvcache-ai:main Mar 31, 2026
16 of 17 checks passed
whn09 pushed a commit to whn09/Mooncake that referenced this pull request Apr 4, 2026
…che-ai#1755)

* [PG]: integrate with tent's nvlink problem

* [PG]: code format

* [tent] code format

5 participants