Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,24 +321,27 @@ def merge(self, other: MultimodalInputs):
"""
merge image inputs when requests are being merged
"""

# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.

# args needed to be merged
optional_args = [
"items",
"image_offsets",
"image_pad_len",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
]
for arg in optional_args:
self_arg = getattr(self, arg, None)
if self_arg is not None:
setattr(self, arg, self_arg + getattr(other, arg))
# other args would be kept intact
# Merge mm_items
if not self.mm_items:
self.mm_items = other.mm_items
else:
self.mm_items.extend(other.mm_items)

# Merge image_pad_len if exists
if self.image_pad_len is not None and other.image_pad_len is not None:
self.image_pad_len.extend(other.image_pad_len)
elif other.image_pad_len is not None:
self.image_pad_len = other.image_pad_len

# Merge num_image_tokens if exists
if self.num_image_tokens is not None and other.num_image_tokens is not None:
self.num_image_tokens += other.num_image_tokens
elif other.num_image_tokens is not None:
self.num_image_tokens = other.num_image_tokens

# Merge mrope_position_delta if exists
if other.mrope_position_delta is not None:
self.mrope_position_delta = other.mrope_position_delta


class Req:
Expand Down
46 changes: 43 additions & 3 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,21 +417,61 @@ def _compute_mrope_positions(
] * 3
else:
                # TODO: qwen2-vl does not currently support the radix cache because of the mrope position calculation
image_items = [
item
for item in multimodal_inputs.mm_items
if item.image_grid_thws is not None
]
image_grid_thw = (
torch.concat(
[item.image_grid_thws for item in image_items],
dim=0,
)
if image_items
else None
)

video_items = [
item
for item in multimodal_inputs.mm_items
if item.video_grid_thws is not None
]
video_grid_thw = (
torch.concat(
[item.video_grid_thws for item in video_items],
dim=0,
)
if video_items
else None
)

second_per_grid_ts = (
torch.concat(
[
item.second_per_grid_ts
for item in multimodal_inputs.mm_items
],
dim=0,
)
if video_items
else None
)

mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=multimodal_inputs.image_grid_thws,
video_grid_thw=multimodal_inputs.video_grid_thws,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=multimodal_inputs.im_token_id,
video_token_id=multimodal_inputs.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
seq_len=len(self.input_ids),
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
second_per_grid_ts=second_per_grid_ts,
tokens_per_second=hf_config.vision_config.tokens_per_second,
)
)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def model_is_mrope(self) -> bool:
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
return "mrope_section" in rope_scaling

def save_remote_model(self, url: str):
from sglang.srt.model_loader.loader import RemoteModelLoader
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,14 +553,18 @@ def forward(
        otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
if "mrope_section" in self.config.rope_scaling:
positions = forward_batch.mrope_positions
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)

if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
if "mrope_section" in self.config.rope_scaling:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
Expand Down