[sgl-kernel][Deepseek V3.2] Add row_starts to topk kernel #12582
Fridge003 merged 4 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request addresses an issue in the topk kernel's handling of ragged inputs, which is essential for the Deepseek V3.2 model's prefill stage. By integrating a row_starts parameter, the kernel can now determine the correct starting index for each row of a ragged tensor, ensuring correct top-k computation. This improves the correctness of the model's prefill path and directly contributes to resolving a reported bug.
Code Review
This pull request correctly adds the row_starts parameter to the top-k kernels to support ragged tensors, which is a good improvement. The changes are propagated through the C++ kernels, PyTorch op definitions, and Python wrappers. My review has identified a few areas for improvement: there are a couple of missing input validation checks in the C++ code, a minor typo in a Python assertion message, and most importantly, the tests for the new functionality are incomplete as they don't cover cases with non-zero row_starts. I've provided suggestions to address these points.
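As a sketch of what a test covering the non-zero row_starts case could look like, here is a pure-Python stand-in for the kernel under test (the function name, signature, and layout are hypothetical illustrations, not the actual sgl-kernel op or test suite):

```python
def topk_with_row_starts(scores, row_starts, row_len, k):
    """Stand-in for the kernel under test: for each row, take the top-k
    over scores[start : start + row_len]. Hypothetical signature,
    for illustration only."""
    out = []
    for start in row_starts:
        row = scores[start:start + row_len]
        # Indices of the k largest values, relative to the row...
        idx = sorted(range(row_len), key=row.__getitem__, reverse=True)[:k]
        # ...shifted back to global positions in the flat buffer.
        out.append([start + i for i in idx])
    return out

# Cover the previously untested case: non-zero row starts.
scores = [5.0, 1.0, 4.0, 2.0, 3.0, 0.0]
print(topk_with_row_starts(scores, row_starts=[2, 3], row_len=3, k=2))
```

A real test would call the sgl-kernel op with the same non-zero row_starts and compare against a per-row reference like this one.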
That's exactly how I tested it. The 20-shot GSM8K results now match the vLLM results in https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2-Exp.html#additional-resources
@hlu1 We can add the 20-shot test to CI after your PRs are merged, to test long-context accuracy.
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Motivation
Part 1 of the fix for the bug in #11629.
Modifications
In the topk kernel for prefill, the q and k inputs are both ragged. We need to pass the correct start indices of k for each q token to the kernel.
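As an illustration of the semantics (a pure-Python reference, not the actual sgl-kernel implementation; the names `row_starts`, `row_ends`, and the flat score layout are assumptions for this sketch):

```python
def ragged_topk_all(scores, row_starts, row_ends, k):
    """Reference top-k over ragged rows of a flattened score buffer.

    scores: flat list of all q-k scores, laid out row by row
            (hypothetical layout for this sketch);
    row_starts[i] / row_ends[i]: half-open [start, end) slice of
            `scores` holding the valid keys for query token i.
    """
    out = []
    for s, e in zip(row_starts, row_ends):
        row = scores[s:e]
        # Indices of the k largest scores, relative to the row start...
        idx = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        # ...then shifted back to global indices into the flat buffer.
        out.append([s + i for i in idx])
    return out

# Two query tokens with different valid key ranges:
scores = [0.1, 0.9, 0.5, 0.3, 0.7, 0.2, 0.8]
print(ragged_topk_all(scores, row_starts=[0, 3], row_ends=[3, 7], k=2))
```

Without per-row start indices, every row would be scored from offset 0, which is exactly the bug for ragged prefill inputs.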
In `fast_topk_transform_interface`, I changed the `is_decode` criterion from `prefill_bs == B` to `!row_starts_opt.has_value() and prefill_bs == B`. Using `prefill_bs == B` alone can be a bit ambiguous for a single token.

Accuracy Tests

8xB200, fp8 checkpoint, bf16 kvcache, radix cache on:
fp8 kvcache gives a similar score.
Checklist