Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mooncake-pg/include/connection_poller.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ConnectionContext {

public:
ConnectionContext(int backendIndex, int rank, int size, bool isDummy,
uint64_t* local2global_rank_map, std::string location,
uint64_t* local2global_rank_map,
c10::intrusive_ptr<::c10d::Store> store,
std::shared_ptr<TransferGroupMeta> meta,
std::shared_ptr<P2PProxy> p2p_proxy,
Expand Down
1 change: 1 addition & 0 deletions mooncake-pg/include/mooncake_worker.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void launchReduceKernel(at::Tensor dst, size_t pos, size_t realSize, void* src,

void launchReduceCpu(at::Tensor dst, size_t pos, size_t realSize, void* src,
size_t numRanks, c10d::ReduceOp op, bool* activeRanks);
void preloadReduceKernels();

class ConnectionContext;
class MooncakeWorker {
Expand Down
10 changes: 5 additions & 5 deletions mooncake-pg/src/connection_poller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <algorithm>
#include <cstring>
#include <limits>
#include "memory_location.h"
#include "mooncake_worker.cuh"

namespace mooncake {
Expand Down Expand Up @@ -42,7 +43,6 @@ static bool supportFabricMem() {
ConnectionContext::ConnectionContext(int backendIndex, int rank, int size,
bool isDummy,
uint64_t* local2global_rank_map,
std::string location,
c10::intrusive_ptr<::c10d::Store> store,
std::shared_ptr<TransferGroupMeta> meta,
std::shared_ptr<P2PProxy> p2p_proxy,
Expand All @@ -67,15 +67,15 @@ ConnectionContext::ConnectionContext(int backendIndex, int rank, int size,
return;
}

warmup_send_region_ = new int32_t[kMaxNumRanks];
warmup_send_region_ = new int32_t[kMaxNumRanks]{};
warmup_send_region_[0] = 1;
int rc = engine_->registerLocalMemory(
warmup_send_region_, kMaxNumRanks * sizeof(int32_t), location);
warmup_send_region_, kMaxNumRanks * sizeof(int32_t), kWildcardLocation);
TORCH_CHECK(!rc, "Failed to register local memory for context.");

warmup_recv_region_ = new int32_t[kMaxNumRanks]{};
rc = engine_->registerLocalMemory(warmup_recv_region_,
kMaxNumRanks * sizeof(int32_t), location);
rc = engine_->registerLocalMemory(
warmup_recv_region_, kMaxNumRanks * sizeof(int32_t), kWildcardLocation);
TORCH_CHECK(!rc, "Failed to register local memory for context.");
}

Expand Down
20 changes: 13 additions & 7 deletions mooncake-pg/src/mooncake_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <atomic>
#include <memory>
#include "connection_poller.h"
#include "memory_location.h"
#include "mooncake_worker.cuh"

namespace mooncake {
Expand Down Expand Up @@ -177,16 +178,18 @@ MooncakeBackend::MooncakeBackend(
TORCH_CHECK(static_cast<size_t>(size) <= kMaxNumRanks,
"The number of ranks exceeds the limit.");
for (size_t i = 0; i < 2; i++) {
cpu_sync_send_region_[i] = new int32_t[kMaxNumRanks];
int rc = engine_->registerLocalMemory(
cpu_sync_send_region_[i], kMaxNumRanks * sizeof(int32_t), location);
cpu_sync_send_region_[i] = new int32_t[kMaxNumRanks]{};
int rc = engine_->registerLocalMemory(cpu_sync_send_region_[i],
kMaxNumRanks * sizeof(int32_t),
kWildcardLocation);
TORCH_CHECK(!rc, REGISTER_BUFFER_ERROR_MSG);
}

for (size_t i = 0; i < 2; i++) {
cpu_sync_recv_region_[i] = new int32_t[kMaxNumRanks];
int rc = engine_->registerLocalMemory(
cpu_sync_recv_region_[i], kMaxNumRanks * sizeof(int32_t), location);
cpu_sync_recv_region_[i] = new int32_t[kMaxNumRanks]{};
int rc = engine_->registerLocalMemory(cpu_sync_recv_region_[i],
kMaxNumRanks * sizeof(int32_t),
kWildcardLocation);
TORCH_CHECK(!rc, REGISTER_BUFFER_ERROR_MSG);
}

Expand All @@ -203,6 +206,9 @@ MooncakeBackend::MooncakeBackend(
worker_ = worker_mgr.GetCPUWorker();
else
worker_ = worker_mgr.GetCUDAWorker(cuda_device_index);
if (!isCpu_) {
preloadReduceKernels();
}
worker_->Start();

p2p_proxy_ = std::make_shared<P2PProxy>(
Expand All @@ -218,7 +224,7 @@ MooncakeBackend::MooncakeBackend(
meta_ = std::make_shared<TransferGroupMeta>();
connection_ctx_ = std::make_shared<ConnectionContext>(
backendIndex_, rank, size, options_ && options_->isExtension_,
local2global_rank_map_, location, store, meta_, p2p_proxy_, engine_);
local2global_rank_map_, store, meta_, p2p_proxy_, engine_);

rank_info.send_buffer[0] = (uint64_t)send_buffer_[0];
rank_info.send_buffer[1] = (uint64_t)send_buffer_[1];
Expand Down
27 changes: 26 additions & 1 deletion mooncake-pg/src/mooncake_worker.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ __global__ void reduceKernel(scalar_t* dst, const scalar_t* src,
}
}

namespace {

// Forces the CUDA runtime to load/instantiate reduceKernel<scalar_t> by
// querying its function attributes. `name` is only used to produce a
// readable error message; it carries no functional meaning.
// Aborts via TORCH_CHECK if the kernel cannot be resolved (e.g. the module
// failed to load for this architecture).
template <typename scalar_t>
void preload_reduce_kernel(const char* name) {
    cudaFuncAttributes attr{};
    // cudaFuncGetAttributes触发 is side-effect-free apart from triggering the
    // lazy module load; the returned attributes are intentionally discarded.
    auto err = cudaFuncGetAttributes(
        &attr, reinterpret_cast<const void*>(reduceKernel<scalar_t>));
    TORCH_CHECK(err == cudaSuccess, "Failed to preload kernel ", name, ": ",
                cudaGetErrorString(err));
}

}  // namespace

void launchReduceKernel(at::Tensor dst, size_t pos, size_t realSize, void* src,
size_t numRanks, c10d::ReduceOp op, bool* activeRanks,
cudaStream_t stream) {
Expand Down Expand Up @@ -244,6 +257,18 @@ void launchReduceCpu(at::Tensor dst, size_t pos, size_t realSize, void* src,
}
}

// Warm up every reduceKernel<T> instantiation used by the backend by
// querying its attributes once at startup (see preload_reduce_kernel).
// This moves the CUDA lazy module-load cost out of the first collective
// operation. Each call aborts via TORCH_CHECK on failure.
void preloadReduceKernels() {
    preload_reduce_kernel<uint8_t>("reduceKernel<uint8_t>");
    preload_reduce_kernel<int8_t>("reduceKernel<int8_t>");
    preload_reduce_kernel<int16_t>("reduceKernel<int16_t>");
    preload_reduce_kernel<int>("reduceKernel<int>");
    preload_reduce_kernel<int64_t>("reduceKernel<int64_t>");
    preload_reduce_kernel<float>("reduceKernel<float>");
    preload_reduce_kernel<double>("reduceKernel<double>");
    preload_reduce_kernel<bool>("reduceKernel<bool>");
    preload_reduce_kernel<at::BFloat16>("reduceKernel<BFloat16>");
    // NOTE(review): at::Half is not preloaded — confirm whether fp16 reduce
    // is unsupported or dispatched through another path.
}

MooncakeWorker::MooncakeWorker(int cuda_device_index)
: cuda_device_index_(cuda_device_index) {
int deviceCount = 0;
Expand Down Expand Up @@ -378,4 +403,4 @@ c10::intrusive_ptr<c10d::Work> MooncakeWorker::putTaskCuda(
return c10::make_intrusive<MooncakeWorkCuda>(opType, event, meta);
}

} // namespace mooncake
} // namespace mooncake
15 changes: 15 additions & 0 deletions mooncake-pg/src/mooncake_worker_thread.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <cuda_runtime.h>
#include <thread>
#include <mooncake_worker.cuh>
#include <glog/logging.h>
#include <transfer_engine.h>

namespace mooncake {
Expand All @@ -12,6 +13,8 @@ enum WorkerTaskStatus {
DONE = 3,
};

static constexpr size_t kInvalidTaskId = static_cast<size_t>(-1);

void MooncakeWorker::Start() {
bool expected = false;
if (started_.compare_exchange_strong(expected, true)) {
Expand Down Expand Up @@ -51,6 +54,9 @@ void MooncakeWorker::startWorker() {
std::memory_order_release);
continue;
}
for (size_t j = 0; j < kMaxNumRanks; ++j) {
rankToTaskId[i][j] = kInvalidTaskId;
}
std::vector<TransferRequest> entries;
for (int j = 0; j < group->size; ++j) {
if (!group->activeRanks[j]) {
Expand Down Expand Up @@ -131,6 +137,9 @@ void MooncakeWorker::startWorker() {
if (!group->activeRanks[j]) {
continue;
}
if (rankToTaskId[i][j] == kInvalidTaskId) {
continue;
}
group->engine->getTransferStatus(
task.batchID, rankToTaskId[i][j], status);
if (status.s != TransferStatusEnum::COMPLETED) {
Expand Down Expand Up @@ -175,6 +184,9 @@ void MooncakeWorker::startWorker() {
auto source_ptr = (int32_t*)group->segmentInfos[group->rank]
.send_sync[task.bufferOffset];

for (size_t j = 0; j < kMaxNumRanks; ++j) {
rankToTaskId[i][j] = kInvalidTaskId;
}
std::vector<TransferRequest> entries;
for (int j = 0; j < group->size; ++j) {
if (!group->activeRanks[j]) {
Expand Down Expand Up @@ -213,6 +225,9 @@ void MooncakeWorker::startWorker() {
if (!group->activeRanks[j]) {
continue;
}
if (rankToTaskId[i][j] == kInvalidTaskId) {
continue;
}
group->engine->getTransferStatus(
task.batchID, rankToTaskId[i][j], status);
if (signal_ptr[j] != 1 ||
Expand Down
6 changes: 3 additions & 3 deletions mooncake-pg/src/p2p_proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cstring>
#include <limits>
#include <thread>
#include "memory_location.h"

namespace mooncake {

Expand Down Expand Up @@ -161,15 +162,14 @@ void P2PProxy::AllocateResources() {
}
}
}

int rc = engine_->registerLocalMemory(resources_.ctrl_send_region_,
kMaxNumRanks * sizeof(P2PControlSlot),
location_);
kWildcardLocation);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Registering resources_.ctrl_send_region_ with kWildcardLocation is a crucial change. This correctly identifies the P2P control regions as host-side memory, preventing misclassification and ensuring proper transport behavior on NVLink/MNNVL setups, as detailed in solution point 2.

Suggested change
kWildcardLocation);
kMaxNumRanks * sizeof(P2PControlSlot),
kWildcardLocation);

TORCH_CHECK(rc == 0, "Failed to register P2P ctrl send region");

rc = engine_->registerLocalMemory(resources_.ctrl_recv_region_,
kMaxNumRanks * sizeof(P2PControlSlot),
location_);
kWildcardLocation);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similarly, registering resources_.ctrl_recv_region_ with kWildcardLocation ensures that this P2P control region is also correctly identified as host-side memory. This consistency is vital for avoiding bootstrap hangs and ensuring reliable P2P communication.

Suggested change
kWildcardLocation);
kMaxNumRanks * sizeof(P2PControlSlot),
kWildcardLocation);

TORCH_CHECK(rc == 0, "Failed to register P2P ctrl recv region");
}

Expand Down
10 changes: 10 additions & 0 deletions mooncake-transfer-engine/tent/src/rpc/rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ Status CoroRpcAgent::start(uint16_t& port, bool ipv6) {
server_ = new coro_rpc::coro_rpc_server(kRpcThreads, port);
server_->register_handler<&CoroRpcAgent::process>(this);
server_->async_start();
const auto err = server_->get_errc();
if (err) {
LOG(WARNING)
<< "Failed to start RPC server(async_start) on port "
<< port << ": " << err.message();
delete server_;
server_ = nullptr;
port = 0;
continue;
Comment on lines +72 to +80
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The added error checking for server_->async_start() is a critical improvement. Previously, RPC port conflicts could lead to silent failures and incorrect reporting of successful startup. This change ensures that such errors are detected, logged, and handled by retrying port selection, directly addressing solution point 3 and significantly enhancing the robustness of the RPC server initialization.

            const auto err = server_->get_errc();
            if (err) {
                LOG(WARNING) << "Failed to start RPC server(async_start) on port " << port
                             << ": " << err.message();
                delete server_;
                server_ = nullptr;
                port = 0;
                continue;
            }

}
running_ = true;
return Status::OK();
} catch (const std::exception& e) {
Expand Down
Loading