diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index f85aff3ccc24..19efc8d44b20 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -96,7 +96,7 @@ def cleanup(self): def ensure_workspace_initialized( - max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False + max_token_num: int = 16384, hidden_dim: int = 4096, use_fp32_lamport: bool = False ): """Ensure workspace is initialized""" if not is_flashinfer_available() or _flashinfer_comm is None: @@ -128,7 +128,7 @@ def flashinfer_allreduce_residual_rmsnorm( residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, - max_token_num: int = 2048, + max_token_num: int = 16384, use_oneshot: Optional[bool] = None, trigger_completion_at_end: bool = False, fp32_acc: bool = False,