Skip to content

Commit f60f94f

Browse files
kainlanclaude
andcommitted
sycl: implement async expert prefetch DMA engine for MoE layer prefetching
Re-implement the ExpertPrefetcher stub methods to enable async H2D DMA of predicted expert weights from host RAM to VRAM. Uses an out-of-order SYCL queue (OOQ) for DMA, separate from the compute queue, with a ring buffer of pre-allocated VRAM slots for zero-allocation-per-hint operation.

Key design:
- hint() checks the placement table and submits an async memcpy on the OOQ if the expert is host-resident (device_ptr==nullptr, host_ptr!=nullptr)
- await() waits on the per-expert sycl::event and updates the placement table device_ptr so the dispatch path routes to GPU instead of CPU
- VRAM pool: 8 slots lazily allocated on first hint, sized to the expert's weight_bytes (~4.3MB each for a 120B model = ~35MB total)
- gc_completed() recycles pool slots from consumed entries; this is safe because hint() runs 1+ layers ahead of dispatch (gc runs after GPU consumption)

Also includes a minor cleanup of the CPU host-resident MUL_MAT dispatch (assert instead of a conditional null check for the host pointer).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4d684bb commit f60f94f

3 files changed

Lines changed: 254 additions & 48 deletions

File tree

ggml/src/ggml-sycl/expert-prefetch.cpp

Lines changed: 220 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,16 @@ ExpertPrefetcher::~ExpertPrefetcher() {
2727
if (initialized_ && !ggml_sycl_is_shutting_down()) {
2828
shutdown();
2929
}
30-
// During static destruction, intentionally leak the queue handle.
30+
// Free VRAM pool slots.
31+
if (!ggml_sycl_is_shutting_down() && dma_queue_) {
32+
for (auto & slot : vram_pool_) {
33+
if (slot.ptr) {
34+
sycl::free(slot.ptr, *dma_queue_);
35+
slot.ptr = nullptr;
36+
}
37+
}
38+
}
39+
// During static destruction, intentionally leak the queue handle + VRAM pool.
3140
// The OS reclaims all process memory at exit.
3241
if (ggml_sycl_is_shutting_down() && dma_queue_) {
3342
(void) dma_queue_.release();
@@ -64,26 +73,113 @@ void ExpertPrefetcher::shutdown() {
6473

6574
cancel_all();
6675
initialized_ = false;
67-
GGML_LOG_INFO("[SYCL] Expert prefetcher shut down (completed=%d)\n", completed_count_);
76+
GGML_LOG_INFO("[SYCL] Expert prefetcher shut down (prefetched=%d, already_cached=%d)\n",
77+
completed_count_, prefetch_hits_);
6878
}
6979

7080
// ============================================================================
7181
// Hint: schedule a non-blocking async H2D prefetch on dma_queue_
7282
// ============================================================================
7383

7484
bool ExpertPrefetcher::hint(int layer_idx, int expert_idx) {
75-
// Prefetching disabled after ExpertCache removal.
76-
// Weight management is now handled by unified cache.
77-
(void) layer_idx;
78-
(void) expert_idx;
79-
return false;
85+
if (!initialized_ || !dma_queue_) {
86+
return false;
87+
}
88+
89+
std::lock_guard<std::mutex> lock(mutex_);
90+
91+
expert_key key{ layer_idx, expert_idx };
92+
93+
// Already in-flight or completed — skip.
94+
if (inflight_.count(key)) {
95+
return false;
96+
}
97+
98+
// Check capacity (GC first to free completed slots).
99+
gc_completed();
100+
if (!has_capacity()) {
101+
return false;
102+
}
103+
104+
// Look up expert in placement table.
105+
auto & ptable = get_expert_placement_table();
106+
if (!ptable.is_initialized()) {
107+
return false;
108+
}
109+
110+
auto placement = ptable.get(layer_idx, expert_idx);
111+
112+
// Already in VRAM — nothing to prefetch.
113+
if (placement.device_ptr) {
114+
prefetch_hits_++;
115+
return false;
116+
}
117+
118+
// No host pointer — cannot prefetch.
119+
if (!placement.host_ptr || placement.weight_bytes == 0) {
120+
return false;
121+
}
122+
123+
// Lazily allocate VRAM pool on first use, sized to this expert's weight_bytes.
124+
if (vram_pool_.empty()) {
125+
vram_slot_bytes_ = placement.weight_bytes;
126+
vram_pool_.resize(max_inflight_);
127+
for (auto & slot : vram_pool_) {
128+
try {
129+
slot.ptr = sycl::malloc_device(vram_slot_bytes_, *dma_queue_);
130+
slot.free = (slot.ptr != nullptr);
131+
} catch (const sycl::exception &) {
132+
slot.ptr = nullptr;
133+
slot.free = false;
134+
}
135+
}
136+
size_t allocated = 0;
137+
for (const auto & slot : vram_pool_) {
138+
if (slot.ptr) { allocated++; }
139+
}
140+
GGML_LOG_INFO("[SYCL] Expert prefetch VRAM pool: %zu/%d slots (%.1f MB each)\n",
141+
allocated, max_inflight_, vram_slot_bytes_ / (1024.0 * 1024.0));
142+
}
143+
144+
// Skip if expert is larger than pool slots (model changed mid-run).
145+
if (placement.weight_bytes > vram_slot_bytes_) {
146+
return false;
147+
}
148+
149+
// Acquire a VRAM slot.
150+
int slot = acquire_vram_slot();
151+
if (slot < 0) {
152+
return false;
153+
}
154+
155+
// Submit async H2D DMA on the OOQ.
156+
void * dst = vram_pool_[slot].ptr;
157+
try {
158+
sycl::event ev = dma_queue_->memcpy(dst, placement.host_ptr, placement.weight_bytes);
159+
160+
PrefetchRequest req;
161+
req.key = key;
162+
req.event = ev;
163+
req.device_ptr = dst;
164+
req.pool_slot = slot;
165+
req.completed = false;
166+
inflight_[key] = std::move(req);
167+
168+
GGML_SYCL_DEBUG("[PREFETCH] hint L%d E%d: H2D %.1f KB -> slot %d\n",
169+
layer_idx, expert_idx, placement.weight_bytes / 1024.0, slot);
170+
return true;
171+
} catch (const sycl::exception & e) {
172+
release_vram_slot(slot);
173+
GGML_LOG_WARN("[SYCL] Prefetch H2D failed for L%d E%d: %s\n",
174+
layer_idx, expert_idx, e.what());
175+
return false;
176+
}
80177
}
81178

82179
void ExpertPrefetcher::hint_batch(int layer_idx, const std::vector<int> & expert_indices) {
83-
// Prefetching disabled after ExpertCache removal.
84-
// Weight management is now handled by unified cache.
85-
(void) layer_idx;
86-
(void) expert_indices;
180+
for (int eid : expert_indices) {
181+
hint(layer_idx, eid);
182+
}
87183
}
88184

89185
// ============================================================================
@@ -95,36 +191,96 @@ std::vector<int> ExpertPrefetcher::hint_batch_adaptive(
95191
const std::vector<int> & expert_indices,
96192
int n_miss_total)
97193
{
98-
// Prefetching disabled after ExpertCache removal.
99-
// Weight management is now handled by unified cache.
100-
(void) layer_idx;
101-
(void) expert_indices;
102-
(void) n_miss_total;
103-
104-
// Return empty vector (no experts to CPU dispatch)
105-
return std::vector<int>();
194+
std::vector<int> cpu_dispatch;
195+
196+
// When miss count exceeds capacity, overflow experts go to CPU.
197+
int budget = max_inflight_;
198+
{
199+
std::lock_guard<std::mutex> lock(mutex_);
200+
gc_completed();
201+
budget = max_inflight_ - static_cast<int>(inflight_.size());
202+
}
203+
204+
int scheduled = 0;
205+
for (int eid : expert_indices) {
206+
if (scheduled < budget && n_miss_total <= max_inflight_) {
207+
hint(layer_idx, eid);
208+
scheduled++;
209+
} else {
210+
cpu_dispatch.push_back(eid);
211+
}
212+
}
213+
214+
return cpu_dispatch;
106215
}
107216

108217
// ============================================================================
109218
// Await: block until a specific expert's DMA completes, return VRAM ptr
110219
// ============================================================================
111220

112221
void * ExpertPrefetcher::await(int layer_idx, int expert_idx) {
113-
// Prefetching disabled after ExpertCache removal.
114-
// Weight management is now handled by unified cache.
115-
// Return nullptr to indicate no prefetch available.
116-
(void) layer_idx;
117-
(void) expert_idx;
118-
return nullptr;
222+
if (!initialized_) {
223+
return nullptr;
224+
}
225+
226+
std::lock_guard<std::mutex> lock(mutex_);
227+
228+
expert_key key{ layer_idx, expert_idx };
229+
auto it = inflight_.find(key);
230+
if (it == inflight_.end()) {
231+
// No in-flight prefetch for this expert.
232+
return nullptr;
233+
}
234+
235+
auto & req = it->second;
236+
if (!req.completed) {
237+
// Wait for the specific DMA event to complete.
238+
try {
239+
req.event.wait();
240+
} catch (const sycl::exception & e) {
241+
GGML_LOG_WARN("[SYCL] Prefetch await failed for L%d E%d: %s\n",
242+
layer_idx, expert_idx, e.what());
243+
release_vram_slot(req.pool_slot);
244+
inflight_.erase(it);
245+
return nullptr;
246+
}
247+
req.completed = true;
248+
completed_count_++;
249+
250+
// Update placement table so the dispatch path finds device_ptr.
251+
auto & ptable = get_expert_placement_table();
252+
if (ptable.is_initialized()) {
253+
ptable.set_device_ptr(layer_idx, expert_idx, 0, req.device_ptr);
254+
}
255+
256+
GGML_SYCL_DEBUG("[PREFETCH] await L%d E%d: DMA complete, device_ptr=%p\n",
257+
layer_idx, expert_idx, req.device_ptr);
258+
}
259+
260+
return req.device_ptr;
119261
}
120262

121263
// ============================================================================
122264
// Cancel: drain all in-flight prefetches
123265
// ============================================================================
124266

125267
void ExpertPrefetcher::cancel_all() {
126-
// Prefetching disabled after ExpertCache removal.
127-
// No in-flight operations to cancel.
268+
std::lock_guard<std::mutex> lock(mutex_);
269+
270+
// Wait for all in-flight DMAs.
271+
if (dma_queue_) {
272+
try {
273+
dma_queue_->wait();
274+
} catch (const sycl::exception &) {
275+
// Best effort during shutdown.
276+
}
277+
}
278+
279+
// Release all pool slots and clear tracking.
280+
for (auto & [key, req] : inflight_) {
281+
release_vram_slot(req.pool_slot);
282+
}
283+
inflight_.clear();
128284
}
129285

130286
// ============================================================================
@@ -153,10 +309,28 @@ int ExpertPrefetcher::completed_count() const {
153309

154310
void ExpertPrefetcher::gc_completed() {
155311
// Called with mutex_ held.
156-
// Remove entries that have been completed and consumed by await().
312+
// Remove completed tracking entries and release their VRAM pool slots.
313+
//
314+
// Safety: gc_completed() is called from hint(), which runs for future
315+
// layers (L+1..L+depth). By the time hint(L+1) runs, layer L-1's GPU
316+
// dispatch has completed (ggml_sycl_mul_mat_id runs synchronously, and
317+
// dispatch_cpu_and_scatter includes stream->wait()). So completed
318+
// entries from layer L-1 are safe to gc because the GPU has consumed
319+
// their pool slot data.
320+
//
321+
// Note: entries become completed in await() for layer L, and are gc'd
322+
// by hint() for layer L+1 or later. Since hint() runs 1+ layers ahead,
323+
// there's at least one full dispatch cycle between completed and gc.
157324
auto it = inflight_.begin();
158325
while (it != inflight_.end()) {
159326
if (it->second.completed) {
327+
// Clear placement table device_ptr since the pool slot
328+
// will be recycled for a different expert.
329+
auto & ptable = get_expert_placement_table();
330+
if (ptable.is_initialized()) {
331+
ptable.set_device_ptr(it->second.key.layer, it->second.key.expert_id, 0, nullptr);
332+
}
333+
release_vram_slot(it->second.pool_slot);
160334
it = inflight_.erase(it);
161335
} else {
162336
++it;
@@ -169,6 +343,24 @@ bool ExpertPrefetcher::has_capacity() const {
169343
return static_cast<int>(inflight_.size()) < max_inflight_;
170344
}
171345

346+
// Claim the lowest-indexed pool slot that is both allocated and free.
// Marks the slot as in use and returns its index, or -1 if no slot qualifies.
// Caller must hold mutex_.
int ExpertPrefetcher::acquire_vram_slot() {
    const int n_slots = static_cast<int>(vram_pool_.size());

    int idx = 0;
    while (idx < n_slots) {
        auto & candidate = vram_pool_[idx];
        if (candidate.free && candidate.ptr) {
            candidate.free = false;
            return idx;
        }
        ++idx;
    }

    return -1;
}
356+
357+
void ExpertPrefetcher::release_vram_slot(int slot) {
358+
// Called with mutex_ held.
359+
if (slot >= 0 && slot < static_cast<int>(vram_pool_.size())) {
360+
vram_pool_[slot].free = true;
361+
}
362+
}
363+
172364
// ============================================================================
173365
// ExpertPredictor: pre-attention expert prediction
174366
// ============================================================================

ggml/src/ggml-sycl/expert-prefetch.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ namespace ggml_sycl {
3737
struct PrefetchRequest {
3838
expert_key key;
3939
sycl::event event; // DMA completion event from dma_queue_
40-
bool completed = false;
40+
void * device_ptr = nullptr; // VRAM destination of the H2D DMA
41+
int pool_slot = -1; // Index into vram_pool_ (-1 = no slot)
42+
bool completed = false;
4143
};
4244

4345
// Async DMA engine for prefetching MoE expert weights from host RAM to VRAM.
@@ -126,10 +128,26 @@ class ExpertPrefetcher {
126128
// In-flight prefetch tracking. Key = expert_key.
127129
std::unordered_map<expert_key, PrefetchRequest, expert_key_hash> inflight_;
128130

131+
// VRAM prefetch pool: fixed set of pre-allocated device memory slots, reused via per-slot free flags.
132+
// Each slot holds one expert's worth of weight data.
133+
// Allocated lazily on first hint() with weight_bytes > 0.
134+
// One entry of the VRAM prefetch pool.
struct vram_slot {
    // Device allocation backing this slot; nullptr if allocation failed or was freed.
    void * ptr{ nullptr };
    // True when the slot is not currently assigned to a prefetch request.
    bool free{ true };
};
138+
std::vector<vram_slot> vram_pool_;
139+
size_t vram_slot_bytes_ = 0; // Size of each pool slot
140+
141+
// Acquire a free VRAM slot. Returns slot index or -1 if none available.
142+
int acquire_vram_slot();
143+
// Release a VRAM slot back to the pool.
144+
void release_vram_slot(int slot);
145+
129146
mutable std::mutex mutex_;
130147

131148
// Stats
132149
int completed_count_ = 0;
150+
int prefetch_hits_ = 0; // Experts found already in VRAM (no DMA needed)
133151

134152
// Garbage-collect completed requests to free tracking slots.
135153
void gc_completed();

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21997,7 +21997,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx,
2199721997
const size_t src1_bytes = static_cast<size_t>(K) * sizeof(float);
2199821998
const size_t dst_bytes = static_cast<size_t>(N) * sizeof(float);
2199921999

22000-
// Check that CPU vec_dot is available for this type
22000+
// Early-out before D2H memcpy — vec_dot_rows checks internally too,
22001+
// but we avoid the unnecessary src1 copy when vec_dot is unavailable.
2200122002
const auto * cpu_traits = ggml_get_type_traits_cpu(src0->type);
2200222003
if (cpu_traits && cpu_traits->vec_dot) {
2200322004
// Thread-local staging buffers for D2H(src1) and H2D(dst)
@@ -22011,27 +22012,22 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx,
2201122012
// D2H: copy activations from GPU to host
2201222013
stream->memcpy(tl_src1_host.data(), src1->data, src1_bytes).wait();
2201322014

22014-
// Get host pointer to weight data (already host-accessible)
22015-
const void * src0_host = src0->data;
22016-
if (!src0_host) {
22017-
// Fallback: try get_data_ptr which may return a host pointer
22018-
src0_host = ggml_sycl_get_data_ptr(src0, ctx.device);
22019-
}
22015+
// Host-resident weights always have a valid host pointer
22016+
GGML_ASSERT(src0->data && "host-resident weight must have valid host pointer");
22017+
const void * weight_host = src0->data;
2202022018

22021-
if (src0_host) {
22022-
GGML_SYCL_DEBUG("[CPU-HOST-MAT] type=%d K=%lld N=%lld tensor=%s\n",
22023-
(int) src0->type, (long long) K, (long long) N,
22024-
src0->name ? src0->name : "?");
22019+
GGML_SYCL_DEBUG("[CPU-HOST-MAT] type=%d K=%lld N=%lld tensor=%s\n",
22020+
(int) src0->type, (long long) K, (long long) N,
22021+
src0->name ? src0->name : "?");
2202522022

22026-
// CPU vec_dot for all output rows
22027-
ggml_sycl_cpu_vec_dot_rows(src0->type, static_cast<int>(K),
22028-
src0_host, tl_src1_host.data(),
22029-
tl_dst_host.data(), static_cast<int>(N));
22023+
// CPU vec_dot for all output rows
22024+
ggml_sycl_cpu_vec_dot_rows(src0->type, static_cast<int>(K),
22025+
weight_host, tl_src1_host.data(),
22026+
tl_dst_host.data(), static_cast<int>(N));
2203022027

22031-
// H2D: copy result back to GPU
22032-
stream->memcpy(dst->data, tl_dst_host.data(), dst_bytes).wait();
22033-
return;
22034-
}
22028+
// H2D: copy result back to GPU
22029+
stream->memcpy(dst->data, tl_dst_host.data(), dst_bytes).wait();
22030+
return;
2203522031
}
2203622032
}
2203722033
}

0 commit comments

Comments
 (0)