Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/sglang/srt/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor

Expand Down Expand Up @@ -387,10 +388,7 @@ def forward(
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)

# transformers
x = x.unsqueeze(1)
Expand Down
11 changes: 2 additions & 9 deletions python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor

Expand Down Expand Up @@ -434,15 +435,7 @@ def forward(
position_embeddings = (emb.cos(), emb.sin())

# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0)
cu_seqlens = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
cu_seqlens.to(torch.int32),
]
)
cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)

x = x.unsqueeze(1)

Expand Down
23 changes: 23 additions & 0 deletions python/sglang/srt/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================

import numpy as np
import torch

from sglang.srt.layers.radix_attention import RadixAttention
Expand Down Expand Up @@ -59,3 +60,25 @@ def permute_inv(perm: torch.Tensor) -> torch.Tensor:
inv_perm = torch.empty_like(perm)
inv_perm[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
return inv_perm


def compute_cu_seqlens_from_grid_numpy(grid_thw: torch.Tensor) -> torch.Tensor:
"""
Compute cu_seqlens from grid_thw using NumPy.

grid_thw: [T, 3] int tensor on CPU.
columns: [repeat_count, H, W]
Returns:
cu_seqlens: 1D int32 tensor on CPU, shape [N + 1]
"""
assert (
grid_thw.device.type == "cpu"
), "compute_cu_seqlens_from_grid_numpy expects a CPU tensor"
arr = grid_thw.numpy()

cu_seqlens = np.repeat(arr[:, 1] * arr[:, 2], arr[:, 0]).cumsum(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
return cu_seqlens
141 changes: 141 additions & 0 deletions test/srt/ops/test_repeat_interleave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import time
from typing import Tuple

import numpy as np
import pytest
import torch

from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy as cpu_numpy_impl


def torch_ref_impl(grid_thw: torch.Tensor) -> torch.Tensor:
"""
Pure PyTorch implementation of cu_seqlens computation.
Assumes grid_thw is already on the correct device (CPU here).
Shape: [T, 3], columns: [repeat_count, H, W]
"""
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0)
cu_seqlens = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
cu_seqlens.to(torch.int32),
]
)
return cu_seqlens


def benchmark_once(fn, grid_thw, iters: int = 1000):
"""
Run a function `fn` on the same input `grid_thw` for `iters` times
and measure total elapsed time.
"""
start = time.perf_counter()
for _ in range(iters):
out = fn(grid_thw)
end = time.perf_counter()
return (end - start), out


# (T, repeat_min, repeat_max)
GRID_TEST_CONFIGS: list[Tuple[int, int, int]] = [
(16, 1, 4), # small T, small repeat counts
(128, 0, 4), # allow repeat=0 to test edge cases
(512, 1, 8),
(1024, 1, 16),
]

NUM_CASES_PER_CONFIG = 10


def _generate_random_grid(T: int, repeat_min: int, repeat_max: int) -> torch.Tensor:
"""
grid_thw: [T, 3]
col0: repeat count
col1, col2: arbitrary positive integers (here 1..16)
"""
repeats = torch.randint(repeat_min, repeat_max + 1, (T, 1), dtype=torch.int32)
th = torch.randint(1, 17, (T, 1), dtype=torch.int32)
tw = torch.randint(1, 17, (T, 1), dtype=torch.int32)
grid_thw = torch.cat([repeats, th, tw], dim=1)
return grid_thw


class TestRepeatInterleave:
@classmethod
def setup_class(cls):
torch.set_num_threads(1)

def setup_method(self, method):
torch.manual_seed(0)
np.random.seed(0)

@pytest.mark.parametrize(
"T,repeat_min,repeat_max",
GRID_TEST_CONFIGS,
)
@pytest.mark.parametrize("case_idx", range(NUM_CASES_PER_CONFIG))
def test_cpu_correctness_random_cases(
self,
T: int,
repeat_min: int,
repeat_max: int,
case_idx: int,
):
torch.manual_seed(case_idx)
np.random.seed(case_idx)

grid_thw = _generate_random_grid(T, repeat_min, repeat_max)

grid_clone = grid_thw.clone()

out_torch = torch_ref_impl(grid_thw)
out_numpy = cpu_numpy_impl(grid_thw)

assert torch.equal(grid_thw, grid_clone), "Function modified input grid_thw!"

assert (
out_torch.shape == out_numpy.shape
), f"Shape mismatch: torch={out_torch.shape}, numpy={out_numpy.shape}"

assert (
out_torch.dtype == torch.int32
), f"Unexpected torch dtype: {out_torch.dtype}"
assert (
out_numpy.dtype == torch.int32
), f"Unexpected numpy impl dtype: {out_numpy.dtype}"

if not torch.equal(out_torch.cpu(), out_numpy.cpu()):
diff_idx = (out_torch.cpu() != out_numpy.cpu()).nonzero(as_tuple=False)
idx0 = diff_idx[0].item()
pytest.fail(
f"Value mismatch, T={T}, case_idx={case_idx}, first differing index={idx0}, "
f"torch={out_torch[idx0].item()}, "
f"numpy={out_numpy[idx0].item()}"
)

def test_zero_repeat_edge_case(self):
T = 4
grid_thw = torch.tensor(
[
[0, 4, 4],
[1, 2, 3], # 6
[2, 1, 5], # 5, 5
[0, 7, 7], # 0
],
dtype=torch.int32,
)

grid_clone = grid_thw.clone()

out_torch = torch_ref_impl(grid_thw)
out_numpy = cpu_numpy_impl(grid_thw)

assert torch.equal(
grid_thw, grid_clone
), "Function modified input grid_thw with zero repeats!"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refactor the test into a unit-test style, keeping it consistent with the other tests under the srt directory.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored test to unit-test style.

$python -m pytest ./test/srt/ops/test_repeat_interleave.py
/opt/conda/lib/python3.10/site-packages/pytest_asyncio/plugin.py:252: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
============================================================================================= test session starts =============================================================================================
platform linux -- Python 3.10.13, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/luoyuan.luo/workspace/sglang_dev
plugins: anyio-4.11.0, asyncio-1.2.0
asyncio: mode=strict, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 41 items                                                                                                                                                                                            

test_repeat_interleave.py .........................................                                                                                                                                     [100%]

============================================================================================== warnings summary ===============================================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================= 41 passed, 2 warnings in 6.46s ========================================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

assert torch.equal(
out_torch.cpu(), out_numpy.cpu()
), f"Zero-repeat case mismatch: torch={out_torch}, numpy={out_numpy}"
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
TestFile("openai_server/validation/test_matched_stop.py", 60),
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
TestFile("openai_server/validation/test_request_length_validation.py", 31),
TestFile("ops/test_repeat_interleave.py", 60),
TestFile("quant/test_block_int8.py", 22),
TestFile("quant/test_fp8_kernel.py", 8),
TestFile("quant/test_int8_kernel.py", 8),
Expand Down
Loading