
multimodal: precompute hash for MultimodalDataItem#14354

Merged
Kangyan-Zhou merged 3 commits into sgl-project:main from openanolis:sufeng-buaa/precompute_hash
Dec 18, 2025

Conversation

@sufeng-buaa (Collaborator)

Co-authored-by: Junjie Mao junjie.mao@linux.alibaba.com

Motivation

In VLM scenarios, computing the hash code of a MultimodalDataItem is a relatively time-consuming operation. Currently, this step is implemented in the scheduler, and for large input data, it directly blocks the entire scheduler, resulting in reduced throughput and increased latency.
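To get a feel for why hashing sits on the critical path, here is a minimal timing sketch (not the sglang code; the actual item hashing may use a different hash function) over a buffer roughly the size of the multimodal payload used in the benchmark below:

```python
import hashlib
import time

# Time a SHA-256 over a ~20 MB buffer, comparable to one multimodal item
# in the benchmark. On typical hardware this lands in the tens of
# milliseconds; accumulated over a batch, it can stall a scheduler loop.
payload = b"\x00" * (20 << 20)

t0 = time.perf_counter()
digest = hashlib.sha256(payload).hexdigest()
elapsed = time.perf_counter() - t0

print(f"hashed {len(payload) >> 20} MB in {elapsed * 1e3:.1f} ms")
```

The exact number depends on hardware and the hash function sglang actually uses; the point is only that the cost is non-trivial per item.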

Modifications

This PR moves this step earlier into the tokenizer and introduces an environment variable SGLANG_MM_PRECOMPUTE_HASH to control whether this feature is enabled.
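A minimal sketch of the control flow this introduces (illustrative names only; `Item`, `compute_hash`, and the stage functions are hypothetical stand-ins, not the sglang implementation): when `SGLANG_MM_PRECOMPUTE_HASH` is truthy, the tokenizer stage hashes each item up front, so the scheduler's later call finds the hash already cached.

```python
import hashlib
import os

# Hypothetical gating sketch: the env var decides whether the tokenizer
# stage does the expensive hashing before items reach the scheduler.
PRECOMPUTE = os.getenv("SGLANG_MM_PRECOMPUTE_HASH", "false").lower() in ("1", "true")

class Item:
    def __init__(self, raw: bytes):
        self.raw = raw
        self.hash = None  # computed lazily, at most once

    def compute_hash(self) -> int:
        if self.hash is None:
            self.hash = int.from_bytes(hashlib.sha256(self.raw).digest()[:8], "big")
        return self.hash

def tokenizer_stage(items):
    if PRECOMPUTE:
        for it in items:
            it.compute_hash()  # expensive work off the scheduler's path
    return items

def scheduler_stage(items):
    for it in items:
        it.compute_hash()  # idempotent: instant if already precomputed
    return items
```

Either way the scheduler calls the same routine; the flag only controls whether that call is a cache hit.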

Accuracy Tests

Benchmarking and Profiling

Test parameters

server:

export SGLANG_ENABLE_TORCH_INFERENCE_MODE=true
export OMP_NUM_THREADS=4
export SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1
export SGLANG_ENABLE_JIT_DEEPGEMM=0
export FLASHINFER_DISABLE_VERSION_CHECK=1
python3 -m sglang.launch_server --model-path Qwen/Qwen3-VL-4B-Instruct-FP8 --enable-multimodal --cuda-graph-max-bs 128 --context-length 2560 --page-size 16 --stream-interval 300 --mem-fraction-static 0.7 --port 30260 --base-gpu-id 0 --kv-cache-dtype fp8_e4m3 --mm-attention-backend sdpa --enable-torch-compile --torch-compile-max-bs 128

client:

vllm bench serve --port 30260 --backend openai-chat --model Qwen/Qwen3-VL-4B-Instruct-FP8 --trust-remote-code --percentile-metrics ttft,tpot,itl,e2el --num-prompts 1024 --dataset-name random-mm --random-mm-base-items-per-request 1 --random-mm-num-mm-items-range-ratio 0 --random-mm-bucket-config '{(1280, 720, 1): 1.0}' --seed 18347 --request-rate inf --random-prefix-len 500 --random-input-len 600 --random-output-len 230 --endpoint /v1/chat/completions --max-concurrency 128

Result

With multimodal data of approximately 20 MB per request, this change achieves a ~15% throughput improvement.

Original:

============ Serving Benchmark Result ============
Successful requests:                     1024
Maximum request concurrency:             128
Benchmark duration (s):                  176.54
Total input tokens:                      1125742
Total generated tokens:                  235520
Request throughput (req/s):              5.80
Output token throughput (tok/s):         1334.08
Peak output token throughput (tok/s):    260.00
Peak concurrent requests:                256.00
Total Token throughput (tok/s):          7710.74
---------------Time to First Token----------------
Mean TTFT (ms):                          9802.54
Median TTFT (ms):                        9904.23
P99 TTFT (ms):                           16485.20
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          53.15
Median TPOT (ms):                        52.82
P99 TPOT (ms):                           91.85
---------------Inter-token Latency----------------
Mean ITL (ms):                           4056.81
Median ITL (ms):                         0.01
P99 ITL (ms):                            19928.47
----------------End-to-end Latency----------------
Mean E2EL (ms):                          21972.98
Median E2EL (ms):                        21944.38
P99 E2EL (ms):                           23086.19
==================================================

Optimized:

============ Serving Benchmark Result ============
Successful requests:                     1024
Maximum request concurrency:             128
Benchmark duration (s):                  152.39
Total input tokens:                      1125915
Total generated tokens:                  235520
Request throughput (req/s):              6.72
Output token throughput (tok/s):         1545.52
Peak output token throughput (tok/s):    256.00
Peak concurrent requests:                253.00
Total Token throughput (tok/s):          8933.94
---------------Time to First Token----------------
Mean TTFT (ms):                          7355.01
Median TTFT (ms):                        6660.50
P99 TTFT (ms):                           13171.44
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          50.04
Median TPOT (ms):                        52.04
P99 TPOT (ms):                           77.22
---------------Inter-token Latency----------------
Mean ITL (ms):                           3819.53
Median ITL (ms):                         0.01
P99 ITL (ms):                            17521.99
----------------End-to-end Latency----------------
Mean E2EL (ms):                          18813.59
Median E2EL (ms):                        19404.14
P99 E2EL (ms):                           19809.06
==================================================
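The quoted ~15% follows directly from the two tables above; a quick arithmetic check:

```python
# Numbers copied from the benchmark tables above.
orig_tput, opt_tput = 5.80, 6.72    # request throughput, req/s
orig_dur, opt_dur = 176.54, 152.39  # benchmark duration, s

tput_gain = opt_tput / orig_tput - 1  # relative throughput improvement
dur_drop = 1 - opt_dur / orig_dur     # relative wall-clock reduction

print(f"throughput: +{tput_gain:.1%}, duration: -{dur_drop:.1%}")
# → throughput: +15.9%, duration: -13.7%
```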

Checklist

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Dec 3, 2025
@gemini-code-assist (Contributor)

Summary of Changes

This pull request significantly enhances the performance of multimodal data processing in VLM systems by optimizing the timing of hash code generation. By shifting the hash computation for MultimodalDataItem from the scheduler to the tokenizer, the system avoids bottlenecks that previously degraded throughput and increased latency. The change is controlled by a new environment variable, providing flexibility for deployment, and has been shown to yield substantial performance gains.

Highlights

  • Performance Optimization: Addresses a significant performance bottleneck in VLM (Vision-Language Model) scenarios by moving the time-consuming hash computation for MultimodalDataItem from the scheduler to an earlier stage within the tokenizer.
  • Scheduler Efficiency: Prevents the scheduler from being blocked by hash computations, which previously led to reduced throughput and increased latency, especially with large multimodal input data.
  • Configurability: Introduces a new environment variable, SGLANG_MM_PRECOMPUTE_HASH, allowing users to enable or disable this precomputation feature, with a default setting of false.
  • Benchmarked Improvement: Demonstrates a 15% improvement in request throughput for multimodal data approximately 20 MB in size, as validated by detailed benchmarking results.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively moves the hash precomputation for MultimodalDataItem to the tokenizer, which, as the benchmarks show, provides a significant performance improvement by unblocking the scheduler. The changes are well-contained and controlled by a new environment variable.

One concern is that the original hash computation in python/sglang/srt/managers/schedule_batch.py (lines 331-332 in MultimodalInputs.from_dict) seems to be still present based on the full file content. This would cause set_pad_value() to be called twice. Although the implementation of set_pad_value prevents re-hashing, it's best to remove the redundant call from schedule_batch.py to complete the "move" of this logic and keep the code clean. If this was an oversight, please consider adding the change to this file.

I've also left a minor suggestion in tokenizer_manager.py to improve code readability by reducing nesting.

@yuan-luo yuan-luo added Multi-modal multi-modal language model vlm and removed documentation Improvements or additions to documentation labels Dec 3, 2025
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Dec 3, 2025
@sufeng-buaa (Collaborator, Author)

> Code Review
>
> This pull request effectively moves the hash precomputation for MultimodalDataItem to the tokenizer, which, as the benchmarks show, provides a significant performance improvement by unblocking the scheduler. The changes are well-contained and controlled by a new environment variable.
>
> One concern is that the original hash computation in python/sglang/srt/managers/schedule_batch.py (lines 331-332 in MultimodalInputs.from_dict) seems to be still present based on the full file content. This would cause set_pad_value() to be called twice. Although the implementation of set_pad_value prevents re-hashing, it's best to remove the redundant call from schedule_batch.py to complete the "move" of this logic and keep the code clean. If this was an oversight, please consider adding the change to this file.

The cost of calling set_pad_value twice is essentially one extra execution of self.pad_value = self.hash % (1 << 30), which has negligible overhead. However, I believe the second call is necessary because input data might bypass the tokenizer and be delivered directly to the scheduler through other paths (for example, via gRPC from a model gateway), so the scheduler must perform the check again.
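The idempotence argument above can be sketched as follows (an illustrative stand-in class, not the real MultimodalDataItem; the actual hash function may differ): the expensive hashing runs at most once, and only the cheap modulo repeats.

```python
import hashlib

class MultimodalDataItemSketch:
    """Illustrative stand-in for MultimodalDataItem; not the sglang class."""
    def __init__(self, raw: bytes):
        self.raw = raw
        self.hash = None
        self.pad_value = None

    def set_pad_value(self):
        if self.hash is None:
            # Expensive part: runs at most once, whether the first caller
            # is the tokenizer (precompute on) or the scheduler (fallback).
            self.hash = int.from_bytes(hashlib.sha256(self.raw).digest()[:8], "big")
        # Cheap part: safe to repeat from both stages.
        self.pad_value = self.hash % (1 << 30)

item = MultimodalDataItemSketch(b"example image payload")
item.set_pad_value()        # tokenizer path (precompute)
first_hash = item.hash
item.set_pad_value()        # scheduler path re-checks; no re-hash occurs
assert item.hash == first_hash
```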

@stmatengss (Collaborator)

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Dec 4, 2025
@yuan-luo (Collaborator) commented Dec 4, 2025

LGTM. @yhyang201 Could you help to review?

Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
@sufeng-buaa sufeng-buaa force-pushed the sufeng-buaa/precompute_hash branch from 6770b08 to d00c2a0 on December 8, 2025 03:11
@JustinTong0323 (Collaborator)

Great! Could you conduct a more in-depth comparison regarding the enabling and disabling of this feature? This should include scenarios with low-bs and high-bs. Maybe we could make it default?

@sufeng-buaa (Collaborator, Author)

> Great! Could you conduct a more in-depth comparison regarding the enabling and disabling of this feature? This should include scenarios with low-bs and high-bs. Maybe we could make it default?

  • If the batch size is small and the scheduler runs in overlap mode, and the accumulated hashing time of the batch does not exceed the overlap duration, the scheduler will not be blocked, and requests arriving early can be scheduled promptly.
  • If the batch size is large, or the accumulated hashing time exceeds the overlap time, the scheduler will be blocked, and this blocking affects all incoming requests.
  • One drawback of moving the hashing into the tokenizer is that it may block the tokenizer itself. If the blocking lasts too long, new requests arriving during that period may not receive timely HTTP responses, potentially causing clients to mistakenly assume the server has crashed. But I suppose MultiTokenizer could help mitigate this issue?

Should we consult more people before making it the default?

@stmatengss (Collaborator)

/rerun-failed-ci

@yhyang201 (Collaborator)

/tag-and-rerun-ci

@JustinTong0323 (Collaborator)

LGTM for now; maybe we could add an auto-switching feature based on traffic in the future.

@JustinTong0323 (Collaborator)

/rerun-failed-ci

@sufeng-buaa (Collaborator, Author)

Do we need to wait for the CI to be fixed before we can merge? The failed items don't seem related to my code. @JustinTong0323 @yuan-luo

@yhyang201 (Collaborator)

/rerun-failed-ci

@sufeng-buaa (Collaborator, Author)

@yhyang201 All test cases have passed. Can this PR be merged? Thanks

@yhyang201 (Collaborator)

I would like to verify if my understanding of this feature is correct:

First, this feature offloads the hash calculation to the tokenizer manager. Since the tokenizer manager runs asynchronously with the scheduler, this creates an execution 'overlap', preventing the hash calculation from blocking the scheduler and thereby improving overall throughput.

Second, this feature is controlled via an environment variable because it should not be enabled under high concurrency. The rationale is that if the tokenizer becomes blocked, new requests might fail to establish a connection entirely. In contrast, if the scheduler handles the load, requests would simply be queued. Therefore, enabling this feature under heavy load could prevent requests from entering the system, leading to a high number of timeout failures.

Is this understanding accurate? Please let me know if I missed anything.

@sufeng-buaa (Collaborator, Author)

> I would like to verify if my understanding of this feature is correct:
>
> First, this feature offloads the hash calculation to the tokenizer manager. Since the tokenizer manager runs asynchronously with the scheduler, this creates an execution 'overlap', preventing the hash calculation from blocking the scheduler and thereby improving overall throughput.
>
> Second, this feature is controlled via an environment variable because it should not be enabled under high concurrency. The rationale is that if the tokenizer becomes blocked, new requests might fail to establish a connection entirely. In contrast, if the scheduler handles the load, requests would simply be queued. Therefore, enabling this feature under heavy load could prevent requests from entering the system, leading to a high number of timeout failures.
>
> Is this understanding accurate? Please let me know if I missed anything.

Yes, your understanding is correct.

@Kangyan-Zhou Kangyan-Zhou merged commit 29e8f7f into sgl-project:main Dec 18, 2025
366 of 397 checks passed
xiaobaicxy added a commit to xiaobaicxy/sglang that referenced this pull request Dec 19, 2025 (a merge of 136 commits from sgl-project/sglang main, including sgl-project#14354).
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 23, 2025
Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>

Labels

documentation · Multi-modal (multi-modal language model) · run-ci · vlm


6 participants