Merged
python/sglang/srt/eplb/expert_location.py (22 changes: 13 additions & 9 deletions)

@@ -284,9 +284,17 @@ def update(
     # -------------------------------- usage ------------------------------------
 
     def logical_to_all_physical(
-        self, layer_id: int, logical_expert_id: int
+        self,
+        layer_id: int,
+        logical_expert_id: int,
+        require_global_experts: bool = False,
     ) -> List[int]:
         # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
+        if require_global_experts:
+            num_physical_experts = self.logical_to_all_physical_map_cpu[layer_id].shape[
+                -1
+            ]
+            return list(torch.arange(0, num_physical_experts))
         return [
             physical_expert_id
             for physical_expert_id in self.logical_to_all_physical_map_cpu[

@@ -355,14 +363,10 @@ def _compute_logical_to_all_physical_map(
                     )
 
                     # Replace by the nearest physical expert
-                    mapped_physical_experts = logical_to_all_physical_map[layer_id][
-                        logical_expert_id
-                    ]
-                    if (
-                        nearest_expert != -1
-                        and nearest_expert not in mapped_physical_experts
-                    ):
-                        mapped_physical_experts[0] = nearest_expert
+                    if nearest_expert != -1:
+                        logical_to_all_physical_map[layer_id][logical_expert_id] = [
+                            nearest_expert
+                        ]
 
     logical_to_all_physical_map = _pad_nested_array(
         logical_to_all_physical_map, pad_value=-1
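The new `require_global_experts` flag short-circuits the usual replica lookup and instead returns every physical expert slot for the layer. The sketch below illustrates the two code paths in isolation; `ToyExpertLocation` and its hand-built map are hypothetical stand-ins for the real metadata object in expert_location.py, assuming `logical_to_all_physical_map_cpu` is a `[num_layers, num_logical_experts, max_replicas]` CPU tensor padded with -1.

```python
# Minimal sketch, not the sglang class: a toy object with the same
# logical_to_all_physical() signature, assuming logical_to_all_physical_map_cpu
# is a [num_layers, num_logical_experts, max_replicas] CPU tensor padded with -1.
from typing import List

import torch


class ToyExpertLocation:  # hypothetical stand-in for the real metadata object
    def __init__(self):
        # 1 layer, 2 logical experts, 4 physical slots; -1 means "no replica".
        self.logical_to_all_physical_map_cpu = torch.tensor(
            [[[0, 2, -1, -1], [1, 3, -1, -1]]]
        )

    def logical_to_all_physical(
        self,
        layer_id: int,
        logical_expert_id: int,
        require_global_experts: bool = False,
    ) -> List[int]:
        if require_global_experts:
            # New path from the PR: every physical slot in the layer,
            # regardless of which logical expert it currently serves.
            num_physical_experts = self.logical_to_all_physical_map_cpu[
                layer_id
            ].shape[-1]
            return list(torch.arange(0, num_physical_experts))
        # Existing path: only the slots actually mapped to this logical expert.
        return [
            int(p)
            for p in self.logical_to_all_physical_map_cpu[layer_id][logical_expert_id]
            if int(p) != -1
        ]


loc = ToyExpertLocation()
print(loc.logical_to_all_physical(0, 0))  # [0, 2]
print(loc.logical_to_all_physical(0, 0, require_global_experts=True))
# [tensor(0), tensor(1), tensor(2), tensor(3)] -- list(torch.arange(...)) yields
# 0-dim tensors rather than ints; they still work as indices downstream.
```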
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (5 changes: 4 additions & 1 deletion)

@@ -517,9 +517,12 @@ def weight_loader(
             # This is a shared expert.
             physical_expert_ids = [expert_id]
         else:
+            require_global_experts = getattr(
+                param, "_sglang_require_global_experts", False
+            )
             physical_expert_ids = (
                 global_expert_location_metadata.logical_to_all_physical(
-                    self.layer_id, expert_id
+                    self.layer_id, expert_id, require_global_experts
                 )
             )
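On the consumer side, the loader opts into the new behaviour per parameter through the `_sglang_require_global_experts` attribute read via `getattr`. The sketch below is a simplified, self-contained illustration of that wiring; `ToyMetadata` and `resolve_target_experts` are hypothetical helpers, not sglang's actual `weight_loader` or `global_expert_location_metadata`.

```python
# Minimal sketch, not sglang's weight_loader: shows how a parameter tagged with
# _sglang_require_global_experts changes which physical experts receive a
# logical expert's weight. ToyMetadata and resolve_target_experts are
# hypothetical helpers used only for this illustration.
from typing import Dict, List

import torch


class ToyMetadata:  # hypothetical stand-in for global_expert_location_metadata
    def __init__(self, num_physical_experts: int, replicas: Dict[int, List[int]]):
        self.num_physical_experts = num_physical_experts
        self.replicas = replicas  # logical expert id -> physical replica ids

    def logical_to_all_physical(
        self,
        layer_id: int,
        logical_expert_id: int,
        require_global_experts: bool = False,
    ) -> List[int]:
        if require_global_experts:
            # Mirror of the PR's new path: every physical slot in the layer.
            return list(range(self.num_physical_experts))
        return self.replicas[logical_expert_id]


def resolve_target_experts(
    param: torch.nn.Parameter, metadata: ToyMetadata, layer_id: int, expert_id: int
) -> List[int]:
    # Same getattr() pattern as the diff: a missing attribute means False.
    require_global_experts = getattr(param, "_sglang_require_global_experts", False)
    return metadata.logical_to_all_physical(layer_id, expert_id, require_global_experts)


meta = ToyMetadata(num_physical_experts=4, replicas={0: [0, 2], 1: [1, 3]})
param = torch.nn.Parameter(torch.zeros(8))
print(resolve_target_experts(param, meta, layer_id=0, expert_id=0))  # [0, 2]

param._sglang_require_global_experts = True  # opt this parameter in
print(resolve_target_experts(param, meta, layer_id=0, expert_id=0))  # [0, 1, 2, 3]
```

In the PR itself, only the `getattr` lookup and the extra argument to `logical_to_all_physical` are new; the surrounding classes above are scaffolding for the example.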