Jet-Nemotron — EAGLE3 + Varlen Dynamic Conv#13025

Open
alex-t-hu wants to merge 2 commits into sgl-project:main from alex-t-hu:faster_jet_nemotron

Conversation


@alex-t-hu alex-t-hu commented Nov 11, 2025

Motivation

Built on top of #12448.
This PR makes the existing JetNemotron implementation faster and adds speculative decoding (EAGLE3) support for JetNemotron.

Modifications

  • Tree-verification kernel for DynamicConv, covering both topk>1 and topk=1 (a reference sketch of the decode-step convolution follows this list)
  • Varlen prefill kernel for DynamicConv
  • Decode kernel for DynamicConv
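
For context, here is a minimal PyTorch reference of what a single decode step of a short causal conv with per-token weights computes. This is my sketch, not the PR's Triton kernel; the shapes, the window layout, and the name dynamic_conv_decode_step are assumptions for illustration.

```python
import torch

def dynamic_conv_decode_step(x_t, cache, weights_t):
    """One decode step of a causal depthwise conv with per-token weights.

    Illustrative reference only (layout conventions in the PR may differ):
      x_t:       [B, D]     current token's features
      cache:     [B, D, W]  rolling window of the last W tokens
      weights_t: [B, D, W]  per-token conv weights from a kernel generator
    """
    # shift the window left by one slot and append the new token
    cache = torch.cat([cache[:, :, 1:], x_t.unsqueeze(-1)], dim=-1)
    # depthwise causal conv: per-channel weighted sum over the window
    y_t = (cache * weights_t).sum(dim=-1)  # [B, D]
    return y_t, cache
```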

Accuracy Tests

1× RTX A6000, CUDA graph on:

| Model + Inference Type | GSM8K | MMLU |
| --- | --- | --- |
| Jet-Nemotron-2B | 0.763 | 0.622 |
| Jet-Nemotron-2B + EAGLE3 (topk=1, num_draft_tokens=6, num_steps=5) | 0.766 | 0.621 |
| Jet-Nemotron-2B + EAGLE3 (topk=4, num_draft_tokens=8, num_steps=5) | 0.757 | 0.621 |
| Jet-Nemotron-2B + EAGLE3 v2 (topk=1, num_draft_tokens=6, num_steps=5) | 0.787 | 0.606 |

| Model + Inference Type | MMMU |
| --- | --- |
| Jet-Nemotron-2B-vlm | 0.371 |
| Jet-Nemotron-2B-vlm + EAGLE3 (topk=4, num_draft_tokens=8, num_steps=5) | 0.371 |

Benchmarking and Profiling

New implementation (no speculative decoding):

(image: throughput benchmark screenshot)

Original implementation:

(image: throughput benchmark screenshot)

Speculative decoding results

We trained an EAGLE3 draft checkpoint for jet-nemotron-2b-instruct and evaluated it on MT-Bench:
(image: MT-Bench results screenshot)

@gemini-code-assist
Contributor

Summary of Changes

Hello @alex-t-hu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance and capabilities of the JetNemotron model, a hybrid architecture. The core objective is to accelerate inference through optimized dynamic convolution kernels and the introduction of speculative decoding. These changes aim to provide a more efficient and versatile model, as evidenced by significant throughput improvements in various benchmarking scenarios.

Highlights

  • Performance Improvement: The JetNemotron model implementation has been significantly optimized, leading to substantial speedups. Benchmarks show an increase from 664 token/s to 1991 token/s without speculative decoding, and 812-926 token/s with speculative decoding enabled.
  • Speculative Decoding Support: Speculative decoding has been implemented for the JetNemotron model, allowing for faster inference by generating draft tokens and verifying them against the target model.
  • New DynamicConv Kernels: New Triton kernels have been introduced for DynamicConv, specifically for topk>1 and topk=1 tree verification, variable-length prefill, and decode operations, enhancing efficiency and flexibility.
  • Refactored Model Architecture: The internal architecture of the JetNemotron model, including JetBlock, JetNemotronAttention, and JetNemotronMLP, has been refactored to integrate the new dynamic convolution and speculative decoding logic more effectively.
  • Eagle3 Integration: Dedicated components (JetNemotronDecoderLayerEagle3, JetNemotronModelEagle3, JetNemotronForCausalLMEagle3) have been added to support the Eagle3 architecture, including specific handling for midlayer types and hidden state capturing.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant performance optimizations for the JetNemotron model by implementing custom Triton kernels for dynamic convolution. It also includes a major refactoring of the model implementation, making it more modular and aligned with HuggingFace standards. The introduction of speculative decoding support is another key feature. Overall, the changes are well-structured and the performance gains reported are impressive.

I have a few comments on potential issues, including a critical bug in one of the new Triton kernels, a performance bottleneck in a Python loop, and a couple of medium-severity issues related to code safety and clarity. Please see the detailed comments below.

Comment on lines +191 to +209
```python
for i in range(B):
    start_idx = cu_seqlens[i].item()
    end_idx = cu_seqlens[i + 1].item()
    cache_idx = cache_indices[i].item()
    if end_idx - start_idx >= W - 1:
        cache[cache_idx, :, 1:] = x[end_idx - W + 1 : end_idx].transpose(0, 1)
    else:
        num_beginning = W - 1 - (end_idx - start_idx)
        if has_initial_state[i].item():
            cache[cache_idx, :, 1 : num_beginning + 1] = cache[
                cache_idx, :, -num_beginning:
            ]
        else:
            cache[cache_idx, :, 1 : num_beginning + 1] = 0
        cache[cache_idx, :, num_beginning + 1 :] = x[start_idx:end_idx].transpose(0, 1)
```

Severity: high

This Python loop for updating the convolution cache state in varlen prefill mode can be a significant performance bottleneck. It iterates over the batch size and contains .item() calls, which cause CPU-GPU synchronization. This logic should be moved into a dedicated CUDA kernel (e.g., using Triton) to avoid the synchronization overhead and to perform the updates in parallel on the GPU.
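
To make the suggestion concrete, here is one possible shape of such a kernel. This is a hedged sketch, not the PR's implementation: it assumes x has shape [total_tokens, D] and cache has shape [num_caches, D, W] as implied by the loop above, launches one program per (sequence, channel block), and reproduces the branchy logic with masks so no .item() synchronization is needed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _conv_cache_update_kernel(
    x_ptr, cache_ptr, cu_seqlens_ptr, cache_indices_ptr, has_initial_state_ptr,
    D, stride_xt, stride_xd, stride_cn, stride_cd, stride_cw,
    W: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_W: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    start = tl.load(cu_seqlens_ptr + pid_b)
    end = tl.load(cu_seqlens_ptr + pid_b + 1)
    cache_idx = tl.load(cache_indices_ptr + pid_b)
    has_init = tl.load(has_initial_state_ptr + pid_b)
    seqlen = end - start

    d_off = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_off < D
    w_off = 1 + tl.arange(0, BLOCK_W)  # refresh cache slots 1..W-1
    w_mask = w_off < W

    # slot w is fed by token end - W + w when that token exists in this sequence
    src_t = end - W + w_off
    from_x = src_t >= start
    keep_old = src_t < start
    x_vals = tl.load(
        x_ptr + src_t[None, :] * stride_xt + d_off[:, None] * stride_xd,
        mask=d_mask[:, None] & w_mask[None, :] & from_x[None, :], other=0.0,
    )
    # ...otherwise by old cache slot w + seqlen (the shift branch),
    # or zero when there is no initial state
    old_vals = tl.load(
        cache_ptr + cache_idx * stride_cn + d_off[:, None] * stride_cd
        + (w_off + seqlen)[None, :] * stride_cw,
        mask=d_mask[:, None] & w_mask[None, :] & keep_old[None, :] & (has_init != 0),
        other=0.0,
    )
    out = tl.where(from_x[None, :], x_vals, old_vals)
    tl.store(
        cache_ptr + cache_idx * stride_cn + d_off[:, None] * stride_cd
        + w_off[None, :] * stride_cw,
        out, mask=d_mask[:, None] & w_mask[None, :],
    )


def conv_cache_update(x, cache, cu_seqlens, cache_indices, has_initial_state):
    D, W = cache.shape[1], cache.shape[2]
    B = cu_seqlens.numel() - 1
    grid = (B, triton.cdiv(D, 64))
    _conv_cache_update_kernel[grid](
        x, cache, cu_seqlens, cache_indices, has_initial_state.to(torch.int32),
        D, x.stride(0), x.stride(1),
        cache.stride(0), cache.stride(1), cache.stride(2),
        W=W, BLOCK_D=64, BLOCK_W=triton.next_power_of_2(W),
    )
```

Within a program all loads are issued before the store, so the overlapping shift from the old cache tail into the low slots stays safe, and distinct programs touch disjoint (cache row, channel block) regions.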

@alex-t-hu alex-t-hu changed the title Faster jet nemotron Jet-Nemotron Speculative Decoding + Varlen Dynamic Conv. Faster. Nov 11, 2025
@alex-t-hu alex-t-hu changed the title Jet-Nemotron Speculative Decoding + Varlen Dynamic Conv. Faster. Faster Jet-Nemotron — EAGLE3 + Varlen Dynamic Conv Nov 11, 2025
@alex-t-hu alex-t-hu changed the title Faster Jet-Nemotron — EAGLE3 + Varlen Dynamic Conv Jet-Nemotron — EAGLE3 + Varlen Dynamic Conv Nov 11, 2025
@alex-t-hu
Author

@zhaochenyang20 can you review?

@futrime
Contributor

futrime commented Nov 12, 2025

Nice job!

@futrime
Contributor

futrime commented Nov 13, 2025

#12448, which this PR is based on, may have received subsequent modifications after this PR branched off it, such as replacing nn.Linear with SGLang's parallel linear implementation. Could you please check the diff again and minimize the changes?

@alex-t-hu
Author

@futrime thank you very much! I made it more consistent with your implementation. After combining JetBlock's q, k, v, a, b, g projections into qkvabz_proj, I needed to remove fused_gdn_gating to preserve model performance.
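
For readers following along, here is a minimal sketch of the projection-fusion pattern described above. The class name and dims are hypothetical; the PR's qkvabz_proj fuses JetBlock's projections analogously and splits the output afterwards.

```python
import torch
from torch import nn

class FusedProjections(nn.Module):
    """Toy sketch: fuse several linear projections into one GEMM."""

    def __init__(self, hidden_size: int, dims: dict):
        super().__init__()
        self.dims = dims  # e.g. {"q": 1024, "k": 512, "v": 512, "a": 64, "b": 64, "z": 1024}
        self.proj = nn.Linear(hidden_size, sum(dims.values()), bias=False)

    def forward(self, x: torch.Tensor) -> dict:
        out = self.proj(x)  # one matmul instead of one per projection
        parts = out.split(list(self.dims.values()), dim=-1)
        return dict(zip(self.dims.keys(), parts))

# usage
fused = FusedProjections(2048, {"q": 1024, "k": 512, "v": 512, "a": 64, "b": 64, "z": 1024})
outs = fused(torch.randn(4, 2048))  # outs["q"].shape == (4, 1024)
```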

@attack204
Contributor

Hi, SGLang's spec framework has recently implemented a higher-performance EAGLE v2 (based on a better overlap mechanism): #11398.
I think it would be best to adapt to it, which typically requires almost no changes.

You can enable it with export SGLANG_ENABLE_SPEC_V2=1, and you'll typically get better performance results.

You can refer to the sglang.srt.speculative.eagle_worker_v2.EAGLEWorkerV2 class to understand how it works.
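
For reference, launching with the v2 path could look roughly like this. This is a hedged sketch: the model and draft-checkpoint paths are placeholders, and the keyword arguments mirror SGLang's speculative-decoding server args at the time of writing.

```python
import os
os.environ["SGLANG_ENABLE_SPEC_V2"] = "1"  # opt into the overlap-based v2 path

import sglang as sgl

# placeholder paths; substitute the Jet-Nemotron target model and the
# EAGLE3 draft checkpoint trained in this PR
llm = sgl.Engine(
    model_path="jet-ai/Jet-Nemotron-2B",
    speculative_algorithm="EAGLE3",
    speculative_draft_model_path="<eagle3-draft-ckpt>",
    speculative_num_steps=5,
    speculative_eagle_topk=1,
    speculative_num_draft_tokens=6,
)
print(llm.generate("The capital of France is", {"max_new_tokens": 16}))
llm.shutdown()
```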

@alex-t-hu
Author

alex-t-hu commented Nov 15, 2025

Thank you @attack204.

There's this bug when running speculative decoding v2: #13352

So maybe we can merge this in first, and then fix the bug for everyone later?

@attack204
Contributor

> Thank you @attack204.
>
> There's this bug when running speculative decoding v2: #13352
>
> So maybe we can merge this in first, and then fix the bug for everyone later?

topk > 1 in EAGLE v2 is still in development: #11839

So I think testing topk = 1 is enough.

@alex-t-hu
Author

alex-t-hu commented Nov 17, 2025

> > Thank you @attack204.
> > There's this bug when running speculative decoding v2: #13352
> > So maybe we can merge this in first, and then fix the bug for everyone later?
>
> topk > 1 in EAGLE v2 is still in development: #11839
>
> So I think testing topk = 1 is enough.

Thank you @attack204.

I see there are some differences in GSM8K and MMLU accuracy when I use v2 spec dec instead of v1.

@github-actions github-actions bot added the Multi-modal multi-modal language model label Nov 18, 2025
@hnyls2002 hnyls2002 self-assigned this Nov 20, 2025

@yizhang2077 yizhang2077 left a comment


Could we add some unit tests for EAGLE3 + Jet-Nemotron?
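
A rough outline of what such a test could look like, modeled on SGLang's existing EAGLE test files. Helper names (popen_launch_server, run_eval, ...) are assumed from the test suite and should be checked against the repo; the model and draft-checkpoint paths are placeholders, and the accuracy threshold is loosely derived from the table in the PR description.

```python
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestJetNemotronEagle3(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.process = popen_launch_server(
            "jet-ai/Jet-Nemotron-2B",  # placeholder model path
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm", "EAGLE3",
                "--speculative-draft-model-path", "<eagle3-draft-ckpt>",  # placeholder
                "--speculative-num-steps", "5",
                "--speculative-eagle-topk", "1",
                "--speculative-num-draft-tokens", "6",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5, data_path=None, num_questions=200,
            max_new_tokens=512, parallel=64,
            host="http://127.0.0.1",
            port=int(DEFAULT_URL_FOR_TEST.split(":")[-1]),
        )
        metrics = run_eval(args)
        # ~0.76 reported in the accuracy table above; loose lower bound
        self.assertGreater(metrics["accuracy"], 0.70)


if __name__ == "__main__":
    unittest.main()
```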

```python
    self.target_worker.model_runner.jet_nemotron_config is not None
    and spec_info.topk == 1
):
    self.target_worker.model_runner.attn_backend.update_jet_nemotron_topk1_state_after_mtp_verify(
```
Collaborator


Why is the jet_nemotron topk=1 logic different? Shouldn't it be the same as in eagle_worker_v2.py?

Author


I think for topk=1 we can use a specialized memory layout to improve performance by reducing the amount of memory transfer.
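
A toy illustration of the point (my sketch; the window width matches the allocation shown below, but the layout details are assumptions):

```python
import torch

# With topk=1 the draft is a single chain, so a request's per-layer conv state
# can be one contiguous window of width num_draft_tokens + W - 2 (W = conv
# kernel size). Verification then advances the state with one contiguous slice
# instead of gathering per-node states as a topk>1 tree requires.
W, num_draft_tokens, D = 4, 6, 8
window = torch.zeros(D, num_draft_tokens + W - 2)
accept_len = 3  # tokens accepted by the verifier this round
new_state = window[:, accept_len : accept_len + W - 1].clone()  # one contiguous copy
```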

```python
size=(
    num_mamba_layers,
    size + 1,
    speculative_num_draft_tokens + conv_shape[1] - 2,
```
Collaborator


Why is the topk=1 memory layout in jet_nemotron special?

```python
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
linear_attn_backend = GDNAttnBackend(runner)
if runner.jet_nemotron_config is not None:
    linear_attn_backend = JetNemotronAttnBackend(runner)
```
Collaborator


Maybe we could separate jet_nemotron out of hybrid_gdn_config?

Author


This could lead to more code changes, since hybrid_gdn_config appears in a lot of places.

```diff
     0, bs + 1, dtype=torch.int32, device=self.device
 )
-elif forward_batch.forward_mode.is_extend():
+elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
```
Collaborator


I am not sure if hybrid linear attention is compatible with EAGLE v2; do you have any idea, @hebiao064?

```python
else:
    verified_id = torch.empty((0,), device=self.device, dtype=torch.int32)

if self.target_worker.model_runner.hybrid_gdn_config is not None:
```
Collaborator


Maybe we need to wrap it into another function?

Author


Nice, yup, did that.
