From 4ce872e4839d07e7e213d07f75f59839a0614509 Mon Sep 17 00:00:00 2001
From: "luoyuan.luo"
Date: Wed, 19 Nov 2025 21:10:35 +0800
Subject: [PATCH 1/2] Replace torch.repeat_interleave with faster np.repeat for qwen-vl

---
 python/sglang/srt/models/qwen2_vl.py   |   6 +-
 python/sglang/srt/models/qwen3_vl.py   |  11 +-
 python/sglang/srt/models/utils.py      |  23 ++++
 test/srt/ops/test_repeat_interleave.py | 153 +++++++++++++++++++++++++
 test/srt/run_suite.py                  |   1 +
 5 files changed, 181 insertions(+), 13 deletions(-)
 create mode 100644 test/srt/ops/test_repeat_interleave.py

diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 84c15ee776b..4518d087971 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -44,6 +44,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
@@ -387,10 +388,7 @@ def forward(
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)
 
         # transformers
         x = x.unsqueeze(1)
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
index c4d9456bc9e..7026777129a 100644
--- a/python/sglang/srt/models/qwen3_vl.py
+++ b/python/sglang/srt/models/qwen3_vl.py
@@ -46,6 +46,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen3 import Qwen3Model
+from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
@@ -434,15 +435,7 @@ def forward(
         position_embeddings = (emb.cos(), emb.sin())
 
         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0)
-        cu_seqlens = torch.cat(
-            [
-                torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
-                cu_seqlens.to(torch.int32),
-            ]
-        )
+        cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)
 
         x = x.unsqueeze(1)
diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py
index 8100e673474..15c50e8a7da 100644
--- a/python/sglang/srt/models/utils.py
+++ b/python/sglang/srt/models/utils.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import numpy as np
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -59,3 +60,25 @@ def permute_inv(perm: torch.Tensor) -> torch.Tensor:
     inv_perm = torch.empty_like(perm)
     inv_perm[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
     return inv_perm
+
+
+def compute_cu_seqlens_from_grid_numpy(grid_thw: torch.Tensor) -> torch.Tensor:
+    """
+    Compute cu_seqlens from grid_thw using NumPy.
+
+    Args:
+        grid_thw: [T, 3] int tensor on CPU; columns are [repeat_count, H, W].
+    Returns:
+        cu_seqlens: 1D int32 CPU tensor, shape [N + 1], N = grid_thw[:, 0].sum().
+    """
+    assert (
+        grid_thw.device.type == "cpu"
+    ), "compute_cu_seqlens_from_grid_numpy expects a CPU tensor"
+    arr = grid_thw.numpy()
+
+    cu_seqlens = np.repeat(arr[:, 1] * arr[:, 2], arr[:, 0]).cumsum(
+        axis=0, dtype=np.int32
+    )
+    cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+    cu_seqlens = torch.from_numpy(cu_seqlens)
+    return cu_seqlens
diff --git a/test/srt/ops/test_repeat_interleave.py b/test/srt/ops/test_repeat_interleave.py
new file mode 100644
index 00000000000..bef791227a8
--- /dev/null
+++ b/test/srt/ops/test_repeat_interleave.py
@@ -0,0 +1,153 @@
+import time
+
+import numpy as np
+import torch
+
+from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy as cpu_numpy_impl
+
+
+def torch_ref_impl(grid_thw: torch.Tensor) -> torch.Tensor:
+    """
+    Pure PyTorch implementation of the cu_seqlens computation.
+    Assumes grid_thw is already on the correct device (CPU here).
+    Shape: [T, 3], columns: [repeat_count, H, W]
+    """
+    cu_seqlens = torch.repeat_interleave(
+        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+    ).cumsum(dim=0)
+    cu_seqlens = torch.cat(
+        [
+            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
+            cu_seqlens.to(torch.int32),
+        ]
+    )
+    return cu_seqlens
+
+
+def benchmark_once(fn, grid_thw, iters: int = 1000):
+    """
+    Run a function `fn` on the same input `grid_thw` for `iters` iterations
+    and measure the total elapsed time.
+    """
+    start = time.perf_counter()
+    for _ in range(iters):
+        out = fn(grid_thw)
+    end = time.perf_counter()
+    return (end - start), out
+
+
+def check_correctness_cpu():
+    """
+    Perform multiple CPU-side correctness checks:
+    - Different sizes of grid_thw
+    - Different ranges of repeat counts
+    - Check that inputs are not modified
+    - Check shape, dtype, and values are exactly the same
+      between torch_ref_impl and cpu_numpy_impl
+    """
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    # (T, repeat_min, repeat_max)
+    test_configs = [
+        (16, 1, 4),  # small T, small repeat counts
+        (128, 0, 4),  # allow repeat=0 to test edge cases
+        (512, 1, 8),
+        (1024, 1, 16),
+    ]
+
+    num_cases_per_config = 10
+
+    for T, repeat_min, repeat_max in test_configs:
+        for _ in range(num_cases_per_config):
+            # grid_thw: [T, 3]
+            # col0: repeat count
+            # col1, col2: arbitrary positive integers (here 1..16)
+            repeats = torch.randint(
+                repeat_min, repeat_max + 1, (T, 1), dtype=torch.int32
+            )
+            th = torch.randint(1, 17, (T, 1), dtype=torch.int32)
+            tw = torch.randint(1, 17, (T, 1), dtype=torch.int32)
+            grid_thw = torch.cat([repeats, th, tw], dim=1)
+
+            # Save a copy to ensure functions do not modify inputs
+            grid_clone = grid_thw.clone()
+
+            out_torch = torch_ref_impl(grid_thw)
+            out_numpy = cpu_numpy_impl(grid_thw)
+
+            # Input should not be modified
+            assert torch.equal(
+                grid_thw, grid_clone
+            ), "Function modified input grid_thw!"
+
+            # Shapes must be the same
+            assert (
+                out_torch.shape == out_numpy.shape
+            ), f"Shape mismatch: torch={out_torch.shape}, numpy={out_numpy.shape}"
+
+            # Dtypes must be the same (should both be int32)
+            assert (
+                out_torch.dtype == out_numpy.dtype == torch.int32
+            ), f"dtype mismatch: torch={out_torch.dtype}, numpy={out_numpy.dtype}"
+
+            # Values must be exactly the same
+            if not torch.equal(out_torch.cpu(), out_numpy.cpu()):
+                diff_idx = (out_torch.cpu() != out_numpy.cpu()).nonzero(as_tuple=False)
+                idx0 = diff_idx[0].item()
+                raise AssertionError(
+                    f"Value mismatch, T={T}, first differing index={idx0}, "
+                    f"torch={out_torch[idx0].item()}, "
+                    f"numpy={out_numpy[idx0].item()}"
+                )
+
+    print("CPU correctness check: PASSED.")
+
+
+def main():
+    # Setting number of threads to reduce noise from thread scheduling;
+    # you can comment this out if you prefer default behavior.
+    torch.set_num_threads(1)
+
+    # --------------- Correctness check ---------------
+    check_correctness_cpu()
+    print("\nAll correctness checks passed. Starting benchmark...\n")
+
+    # --------------- Performance benchmark ---------------
+    # Typical scales:
+    #   T = number of rows in grid_thw
+    #   H, W only participate in multiplication
+    configs = [
+        (128, 8, 8),
+        (512, 8, 8),
+        (2048, 8, 8),
+        (8192, 8, 8),
+    ]
+
+    iters = 2000  # number of iterations per configuration
+
+    print("=== CPU benchmark ===")
+    for T, H, W in configs:
+        # Construct grid_thw: [T, 3]
+        # col0: repeat count
+        # col1, col2: multiplicative factors
+        grid_thw = torch.randint(1, 5, (T, 3), dtype=torch.int32)
+        grid_thw[:, 1] = H
+        grid_thw[:, 2] = W
+
+        t_torch, out_torch = benchmark_once(torch_ref_impl, grid_thw, iters=iters)
+        t_numpy, out_numpy = benchmark_once(cpu_numpy_impl, grid_thw, iters=iters)
+
+        # Additional safety check: results should match
+        same = torch.equal(out_torch.cpu(), out_numpy.cpu())
+
+        print(
+            f"[CPU] T={T:5d}, iters={iters:4d} | "
+            f"torch={t_torch*1e3:7.2f} ms, "
+            f"numpy={t_numpy*1e3:7.2f} ms, "
+            f"same={same}"
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 48510c959ee..32adcd03f7d 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -46,6 +46,7 @@
         TestFile("openai_server/validation/test_matched_stop.py", 60),
         TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
         TestFile("openai_server/validation/test_request_length_validation.py", 31),
+        TestFile("ops/test_repeat_interleave.py", 60),
         TestFile("quant/test_block_int8.py", 22),
         TestFile("quant/test_fp8_kernel.py", 8),
         TestFile("quant/test_int8_kernel.py", 8),

From d6143370820c40b37975100a74b93f6393c47cf9 Mon Sep 17 00:00:00 2001
From: "luoyuan.luo"
Date: Sat, 22 Nov 2025 21:04:20 +0800
Subject: [PATCH 2/2] Address review comments

---
 test/srt/ops/test_repeat_interleave.py | 206 ++++++++++++-------------
 1 file changed, 97 insertions(+), 109 deletions(-)

diff --git a/test/srt/ops/test_repeat_interleave.py b/test/srt/ops/test_repeat_interleave.py
index bef791227a8..9aa0a859d17 100644
--- a/test/srt/ops/test_repeat_interleave.py
+++ b/test/srt/ops/test_repeat_interleave.py
@@ -1,6 +1,8 @@
 import time
+from typing import List, Tuple
 
 import numpy as np
+import pytest
 import torch
 
 from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy as cpu_numpy_impl
@@ -36,118 +38,104 @@ def benchmark_once(fn, grid_thw, iters: int = 1000):
     return (end - start), out
 
 
-def check_correctness_cpu():
+# (T, repeat_min, repeat_max)
+GRID_TEST_CONFIGS: List[Tuple[int, int, int]] = [
+    (16, 1, 4),  # small T, small repeat counts
+    (128, 0, 4),  # allow repeat=0 to test edge cases
+    (512, 1, 8),
+    (1024, 1, 16),
+]
+
+NUM_CASES_PER_CONFIG = 10
+
+
+def _generate_random_grid(T: int, repeat_min: int, repeat_max: int) -> torch.Tensor:
     """
-    Perform multiple CPU-side correctness checks:
-    - Different sizes of grid_thw
-    - Different ranges of repeat counts
-    - Check that inputs are not modified
-    - Check shape, dtype, and values are exactly the same
-      between torch_ref_impl and cpu_numpy_impl
+    grid_thw: [T, 3]
+        col0: repeat count
+        col1, col2: arbitrary positive integers (here 1..16)
     """
-    torch.manual_seed(0)
-    np.random.seed(0)
-
-    # (T, repeat_min, repeat_max)
-    test_configs = [
-        (16, 1, 4),  # small T, small repeat counts
-        (128, 0, 4),  # allow repeat=0 to test edge cases
-        (512, 1, 8),
-        (1024, 1, 16),
-    ]
-
-    num_cases_per_config = 10
-
-    for T, repeat_min, repeat_max in test_configs:
-        for _ in range(num_cases_per_config):
-            # grid_thw: [T, 3]
-            # col0: repeat count
-            # col1, col2: arbitrary positive integers (here 1..16)
-            repeats = torch.randint(
-                repeat_min, repeat_max + 1, (T, 1), dtype=torch.int32
+    repeats = torch.randint(repeat_min, repeat_max + 1, (T, 1), dtype=torch.int32)
+    th = torch.randint(1, 17, (T, 1), dtype=torch.int32)
+    tw = torch.randint(1, 17, (T, 1), dtype=torch.int32)
+    grid_thw = torch.cat([repeats, th, tw], dim=1)
+    return grid_thw
+
+
+class TestRepeatInterleave:
+    @classmethod
+    def setup_class(cls):
+        torch.set_num_threads(1)
+
+    def setup_method(self, method):
+        torch.manual_seed(0)
+        np.random.seed(0)
+
+    @pytest.mark.parametrize(
+        "T,repeat_min,repeat_max",
+        GRID_TEST_CONFIGS,
+    )
+    @pytest.mark.parametrize("case_idx", range(NUM_CASES_PER_CONFIG))
+    def test_cpu_correctness_random_cases(
+        self,
+        T: int,
+        repeat_min: int,
+        repeat_max: int,
+        case_idx: int,
+    ):
+        torch.manual_seed(case_idx)
+        np.random.seed(case_idx)
+
+        grid_thw = _generate_random_grid(T, repeat_min, repeat_max)
+
+        grid_clone = grid_thw.clone()
+
+        out_torch = torch_ref_impl(grid_thw)
+        out_numpy = cpu_numpy_impl(grid_thw)
+
+        assert torch.equal(grid_thw, grid_clone), "Function modified input grid_thw!"
+
+        assert (
+            out_torch.shape == out_numpy.shape
+        ), f"Shape mismatch: torch={out_torch.shape}, numpy={out_numpy.shape}"
+
+        assert (
+            out_torch.dtype == torch.int32
+        ), f"Unexpected torch dtype: {out_torch.dtype}"
+        assert (
+            out_numpy.dtype == torch.int32
+        ), f"Unexpected numpy impl dtype: {out_numpy.dtype}"
+
+        if not torch.equal(out_torch.cpu(), out_numpy.cpu()):
+            diff_idx = (out_torch.cpu() != out_numpy.cpu()).nonzero(as_tuple=False)
+            idx0 = diff_idx[0].item()
+            pytest.fail(
+                f"Value mismatch, T={T}, case_idx={case_idx}, first differing index={idx0}, "
+                f"torch={out_torch[idx0].item()}, "
+                f"numpy={out_numpy[idx0].item()}"
             )
-            th = torch.randint(1, 17, (T, 1), dtype=torch.int32)
-            tw = torch.randint(1, 17, (T, 1), dtype=torch.int32)
-            grid_thw = torch.cat([repeats, th, tw], dim=1)
-
-            # Save a copy to ensure functions do not modify inputs
-            grid_clone = grid_thw.clone()
-
-            out_torch = torch_ref_impl(grid_thw)
-            out_numpy = cpu_numpy_impl(grid_thw)
-
-            # Input should not be modified
-            assert torch.equal(
-                grid_thw, grid_clone
-            ), "Function modified input grid_thw!"
-
-            # Shapes must be the same
-            assert (
-                out_torch.shape == out_numpy.shape
-            ), f"Shape mismatch: torch={out_torch.shape}, numpy={out_numpy.shape}"
-
-            # Dtypes must be the same (should both be int32)
-            assert (
-                out_torch.dtype == out_numpy.dtype == torch.int32
-            ), f"dtype mismatch: torch={out_torch.dtype}, numpy={out_numpy.dtype}"
-
-            # Values must be exactly the same
-            if not torch.equal(out_torch.cpu(), out_numpy.cpu()):
-                diff_idx = (out_torch.cpu() != out_numpy.cpu()).nonzero(as_tuple=False)
-                idx0 = diff_idx[0].item()
-                raise AssertionError(
-                    f"Value mismatch, T={T}, first differing index={idx0}, "
-                    f"torch={out_torch[idx0].item()}, "
-                    f"numpy={out_numpy[idx0].item()}"
-                )
-
-    print("CPU correctness check: PASSED.")
-
-
-def main():
-    # Setting number of threads to reduce noise from thread scheduling;
-    # you can comment this out if you prefer default behavior.
-    torch.set_num_threads(1)
-
-    # --------------- Correctness check ---------------
-    check_correctness_cpu()
-    print("\nAll correctness checks passed. Starting benchmark...\n")
-
-    # --------------- Performance benchmark ---------------
-    # Typical scales:
-    #   T = number of rows in grid_thw
-    #   H, W only participate in multiplication
-    configs = [
-        (128, 8, 8),
-        (512, 8, 8),
-        (2048, 8, 8),
-        (8192, 8, 8),
-    ]
-
-    iters = 2000  # number of iterations per configuration
-
-    print("=== CPU benchmark ===")
-    for T, H, W in configs:
-        # Construct grid_thw: [T, 3]
-        # col0: repeat count
-        # col1, col2: multiplicative factors
-        grid_thw = torch.randint(1, 5, (T, 3), dtype=torch.int32)
-        grid_thw[:, 1] = H
-        grid_thw[:, 2] = W
-
-        t_torch, out_torch = benchmark_once(torch_ref_impl, grid_thw, iters=iters)
-        t_numpy, out_numpy = benchmark_once(cpu_numpy_impl, grid_thw, iters=iters)
-
-        # Additional safety check: results should match
-        same = torch.equal(out_torch.cpu(), out_numpy.cpu())
-
-        print(
-            f"[CPU] T={T:5d}, iters={iters:4d} | "
-            f"torch={t_torch*1e3:7.2f} ms, "
-            f"numpy={t_numpy*1e3:7.2f} ms, "
-            f"same={same}"
+
+    def test_zero_repeat_edge_case(self):
+        T = 4
+        grid_thw = torch.tensor(
+            [
+                [0, 4, 4],
+                [1, 2, 3],  # 6
+                [2, 1, 5],  # 5, 5
+                [0, 7, 7],  # 0
+            ],
+            dtype=torch.int32,
         )
+        grid_clone = grid_thw.clone()
+
+        out_torch = torch_ref_impl(grid_thw)
+        out_numpy = cpu_numpy_impl(grid_thw)
+
+        assert torch.equal(
+            grid_thw, grid_clone
+        ), "Function modified input grid_thw with zero repeats!"
 
-
-if __name__ == "__main__":
-    main()
+        assert torch.equal(
+            out_torch.cpu(), out_numpy.cpu()
+        ), f"Zero-repeat case mismatch: torch={out_torch}, numpy={out_numpy}"
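
For reference, a minimal standalone sketch of the cu_seqlens equivalence this series relies on (illustrative values; this snippet is not part of either patch). Each grid_thw row [t, h, w] contributes t segments of h * w patches, and cu_seqlens is the cumulative segment length with a leading zero:

import numpy as np
import torch

# Two inputs: one frame of 4*6 patches, then two frames of 2*3 patches each.
grid_thw = torch.tensor([[1, 4, 6], [2, 2, 3]], dtype=torch.int32)

# Original torch path: per-row H*W repeated by the row's repeat count.
seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_torch = torch.cat([seqlens.new_zeros(1), seqlens])

# NumPy path used by the new helper (grid_thw must live on the CPU).
arr = grid_thw.numpy()
seqlens_np = np.repeat(arr[:, 1] * arr[:, 2], arr[:, 0]).cumsum(dtype=np.int32)
cu_numpy = torch.from_numpy(
    np.concatenate([np.zeros(1, dtype=np.int32), seqlens_np])
)

assert torch.equal(cu_torch, cu_numpy)  # both are [0, 24, 30, 36]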