Jet-Nemotron — EAGLE3 + Varlen Dynamic Conv #13025
alex-t-hu wants to merge 2 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request focuses on enhancing the performance and capabilities of the JetNemotron model, a hybrid architecture. The core objective is to accelerate inference through optimized dynamic convolution kernels and the introduction of speculative decoding. These changes aim to provide a more efficient and versatile model, as evidenced by significant throughput improvements in various benchmarking scenarios.
Code Review
This pull request introduces significant performance optimizations for the JetNemotron model by implementing custom Triton kernels for dynamic convolution. It also includes a major refactoring of the model implementation, making it more modular and aligned with HuggingFace standards. The introduction of speculative decoding support is another key feature. Overall, the changes are well-structured and the performance gains reported are impressive.
I have a few comments on potential issues, including a critical bug in one of the new Triton kernels, a performance bottleneck in a Python loop, and a couple of medium-severity issues related to code safety and clarity. Please see the detailed comments below.
for i in range(B):
    start_idx = cu_seqlens[i].item()
    end_idx = cu_seqlens[i + 1].item()
    cache_idx = cache_indices[i].item()
    if end_idx - start_idx >= W - 1:
        cache[cache_idx, :, 1:] = x[
            end_idx - W + 1 : end_idx
        ].transpose(0, 1)
    else:
        num_beginning = W - 1 - (end_idx - start_idx)
        if has_initial_state[i].item():
            cache[cache_idx, :, 1 : num_beginning + 1] = cache[
                cache_idx, :, -num_beginning:
            ]
        else:
            cache[cache_idx, :, 1 : num_beginning + 1] = 0
        cache[cache_idx, :, num_beginning + 1 :] = x[
            start_idx:end_idx
        ].transpose(0, 1)
This Python loop for updating the convolution cache state in varlen prefill mode can be a significant performance bottleneck. It iterates over the batch size and contains .item() calls, which cause CPU-GPU synchronization. This logic should be moved into a dedicated CUDA kernel (e.g., using Triton) to avoid the synchronization overhead and to perform the updates in parallel on the GPU.
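A rough sketch of what that fused kernel could look like (my illustration, not code from this PR): one Triton program per (request, channel block) replaces the Python loop and its .item() synchronizations. Tensor layouts follow the snippet above (x: [total_tokens, dim] contiguous, cache: [num_caches, dim, W]); the kernel name, launch configuration, and block size are assumptions.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _varlen_conv_cache_update_kernel(
    x_ptr, cache_ptr, cu_seqlens_ptr, cache_indices_ptr, has_initial_state_ptr,
    dim, stride_cache_entry, stride_cache_dim,
    W: tl.constexpr, BLOCK_D: tl.constexpr,
):
    pid_b = tl.program_id(0)                  # request index
    pid_d = tl.program_id(1)                  # channel-block index
    start = tl.load(cu_seqlens_ptr + pid_b)
    end = tl.load(cu_seqlens_ptr + pid_b + 1)
    seqlen = end - start
    cache_idx = tl.load(cache_indices_ptr + pid_b)
    has_init = tl.load(has_initial_state_ptr + pid_b)

    d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d < dim
    row = cache_ptr + cache_idx * stride_cache_entry + d * stride_cache_dim

    # Unified form of the branches in the Python loop above: slot j should
    # hold input position p = end - W + j when p lies inside this chunk;
    # otherwise it is shifted from old slot j + seqlen (or zeroed when the
    # request has no prior state). Writing slots in increasing j never
    # clobbers a slot that still needs to be read, since j + seqlen >= j.
    for j in tl.static_range(1, W):
        p = end - W + j
        from_x = tl.load(x_ptr + p * dim + d, mask=d_mask & (p >= start), other=0.0)
        from_old = tl.load(
            row + (j + seqlen),
            mask=d_mask & (p < start) & (has_init != 0),
            other=0.0,
        )
        tl.store(row + j, tl.where(p >= start, from_x, from_old), mask=d_mask)


def update_conv_cache_varlen(x, cache, cu_seqlens, cache_indices, has_initial_state):
    """Drop-in replacement for the loop above: one kernel launch, no .item() syncs."""
    B = cu_seqlens.numel() - 1
    dim = x.shape[1]
    W = cache.shape[-1]
    BLOCK_D = 128
    grid = (B, triton.cdiv(dim, BLOCK_D))
    _varlen_conv_cache_update_kernel[grid](
        x, cache,
        cu_seqlens.to(torch.int32),
        cache_indices.to(torch.int32),
        has_initial_state.to(torch.int32),
        dim, cache.stride(0), cache.stride(1),
        W=W, BLOCK_D=BLOCK_D,
    )
```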
@zhaochenyang20 can you review?
Nice job!
There may be some subsequent modifications in #12448 (which this PR is based on) after this PR, such as replacing nn.Linear with SGLang's parallel linear implementation. Could you please check the diff again and minimize the modifications?
@futrime thank you very much! I made it more consistent with your implementation. After combining JetBlock's q, k, v, a, b, g projections into …
Hi, SGLang's spec framework has recently implemented a higher-performance EAGLE v2 (based on a better overlap mechanism). You can enable it with export SGLANG_ENABLE_SPEC_V2=1, and you'll typically get better performance results. You can refer to the sglang.srt.speculative.eagle_worker_v2.EAGLEWorkerV2 class to understand how it works.
Thank you @attack204. There seems to be a bug when running speculative decoding v2, so maybe we can merge this in first and then fix the bug for everyone later?
top k > 1 in eagle v2 is still in development (#11839), so I think testing top k = 1 is enough.
Thank you @attack204. I see there are some differences in GSM8K and MMLU accuracy when I use v2 spec dec over v1.
yizhang2077 left a comment
Could we add some unit tests for EAGLE3 + Jet-Nemotron?
    self.target_worker.model_runner.jet_nemotron_config is not None
    and spec_info.topk == 1
):
    self.target_worker.model_runner.attn_backend.update_jet_nemotron_topk1_state_after_mtp_verify(
Why is the jet_nemotron topk=1 logic different? The same question applies to eagle_worker_v2.py.
I think for topk=1 we can make the memory layout special to improve performance by reducing the amount of memory transfer.
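A toy illustration of that point (my own sketch, not code from this PR; all names and shapes are made up): with topk = 1 the draft tokens form a single chain per request, so the dynamic-conv window for every draft position is a contiguous slice of one small slab, whereas a topk > 1 tree has to gather each node's window along its ancestor path.

```python
import torch

T = 4      # draft tokens per request (made-up value)
W = 4      # dynamic-conv kernel width (made-up value)
dim = 8    # channels (made-up value)

# topk == 1: one contiguous slab per request stores the last W-1 accepted
# inputs followed by the T draft tokens; the window for draft position t is
# just a slice, so verification streams one contiguous buffer per request.
slab = torch.randn(dim, (W - 1) + T)
windows_chain = torch.stack([slab[:, t : t + W] for t in range(T)])  # [T, dim, W]

# topk > 1: drafts form a tree, so each node's window must be gathered along
# its ancestor path (non-contiguous indices), which means more memory traffic.
# Toy tree: column 3 is the root draft, columns 4 and 5 are its children.
tree_inputs = torch.randn(dim, (W - 1) + 3)
ancestor_windows = torch.tensor([[0, 1, 2, 3],   # root: accepted 0..2 + itself
                                 [1, 2, 3, 4],   # child 4: path through root
                                 [1, 2, 3, 5]])  # child 5: path through root
windows_tree = tree_inputs[:, ancestor_windows]  # gather -> [dim, 3, W]
```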
size=(
    num_mamba_layers,
    size + 1,
    speculative_num_draft_tokens + conv_shape[1] - 2,
Why is the topk=1 memory layout in jet_nemotron special?
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
linear_attn_backend = GDNAttnBackend(runner)
if runner.jet_nemotron_config is not None:
    linear_attn_backend = JetNemotronAttnBackend(runner)
maybe we could separate jet_nemotron out of hybrid_gdn_config?
this could lead to more code changes since hybrid_gdn_config appears in a lot of places?
      0, bs + 1, dtype=torch.int32, device=self.device
  )
- elif forward_batch.forward_mode.is_extend():
+ elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
I am not sure if hybrid linear attention is compatible with eagle v2; do you have any idea? @hebiao064
else:
    verified_id = torch.empty((0,), device=self.device, dtype=torch.int32)

if self.target_worker.model_runner.hybrid_gdn_config is not None:
maybe we need to wrap it into another function?
Motivation
Built off this PR: #12448
This PR makes the existing JetNemotron implementation faster, and implements speculative decoding for JetNemotron.
Modifications
Topk>1 and topk=1 tree verification kernel for DynamicConv
Varlen prefill kernel for DynamicConv
Decode kernel for DynamicConv
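For readers unfamiliar with what these kernels compute, here is a rough PyTorch sketch of a single decode step of a causal dynamic (per-token, depthwise) convolution; the shapes and the way the per-token weights arrive are assumptions for illustration, not the exact JetBlock definition.

```python
import torch

def dynamic_conv_decode_step(x_t, conv_state, w_t):
    """One decode step of a causal depthwise conv with per-token weights.

    x_t:        [batch, dim]     features of the newly decoded token
    conv_state: [batch, dim, W]  sliding window of the last W inputs
    w_t:        [batch, dim, W]  dynamically generated kernel for this step
    """
    # Shift the window left by one and append the new token.
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
    conv_state[..., -1] = x_t
    # Depthwise convolution = per-channel weighted sum over the window.
    y_t = (conv_state * w_t).sum(dim=-1)
    return y_t, conv_state
```

The varlen prefill and tree-verification kernels listed above fuse the same window update and weighted sum across many tokens at once (and, for topk > 1, across tree branches) instead of one token at a time.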
Accuracy Tests
1 RTX A6000, CUDA graph on:
Benchmarking and Profiling
New implementation (no spec dec):
Original implementation:
Speculative Decoding result
We trained an EAGLE3 checkpoint for jet-nemotron-2b-instruct and evaluated it on MT-Bench.
