
[DLLM] Add JointThreshold algorithm for joint M2T and T2T decoding#18171

Merged
ispobock merged 2 commits into sgl-project:main from edwardzjl:dllm-editing on Feb 9, 2026

Conversation

@edwardzjl (Contributor) commented Feb 3, 2026

Motivation

This PR introduces the JointThreshold algorithm, which enables the model to simultaneously fill in masks and refine previously generated tokens in a single iterative loop.
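For reviewers unfamiliar with the two modes, a minimal sketch of one decoding step over a single block may help. It mirrors the loop reviewed below, but the free-standing function and its argument names (input_ids, prompt_mask, mask_id) are illustrative assumptions, not the PR's exact API:

import torch
import torch.nn.functional as F

def joint_threshold_step(input_ids, logits, prompt_mask, mask_id,
                         threshold=0.5, edit_threshold=1.0):
    """One illustrative JointThreshold step over a single 1-D block."""
    probs = F.softmax(logits, dim=-1)
    x = probs.argmax(dim=-1)                            # most likely token per position
    p = probs.gather(-1, x.unsqueeze(-1)).squeeze(-1)   # its probability

    mask_index = input_ids == mask_id

    # M2T: commit masked positions whose top-token confidence clears the threshold.
    m2t = mask_index & (p > threshold)
    if mask_index.any() and not m2t.any():
        # Guarantee progress: commit the single most confident masked position.
        conf = torch.where(mask_index, p, torch.full_like(p, float("-inf")))
        m2t[conf.argmax()] = True

    # T2T: re-edit already generated (non-prompt, non-mask) tokens that the model
    # now disagrees with, provided the new prediction is confident enough.
    t2t = ~mask_index & ~prompt_mask & (p > edit_threshold) & (input_ids != x)

    transfer = m2t | t2t
    input_ids[transfer] = x[transfer]
    return input_ids, bool(transfer.any())

The loop repeats until no position changes, or until max_post_edit_steps refinement passes have run after all masks are filled.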

Modifications

Added JointThreshold Algorithm: Implemented a new DLLM algorithm that supports both Mask-to-Token (M2T) and Token-to-Token (T2T) decoding strategies.

Note on Model Compatibility:

The effectiveness of the T2T (Token-to-Token) refinement depends heavily on the underlying model. This algorithm is most effective with models specifically trained or fine-tuned for token-to-token editing tasks. For standard DLLM models without such training, the T2T component may even harm final generation quality.

Accuracy Tests

Since compatible models specifically trained for token-to-token editing are still awaiting public release, unit tests for this commit are currently omitted. I will add test cases once these models are available.

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @edwardzjl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new JointThreshold algorithm into the DLLM framework. This innovative approach aims to improve text generation by unifying Mask-to-Token (M2T) and Token-to-Token (T2T) decoding. The algorithm iteratively fills masked tokens and refines existing ones, offering a more dynamic and potentially higher-quality output, particularly when used with models specifically fine-tuned for T2T editing.

Highlights

  • New Decoding Algorithm: Introduced a new DLLM algorithm named JointThreshold for enhanced text generation.
  • Combined Decoding Strategies: This algorithm uniquely combines Mask-to-Token (M2T) and Token-to-Token (T2T) decoding strategies within a single iterative loop.
  • Iterative Refinement: It allows for simultaneous filling of masked tokens and refinement of previously generated tokens, aiming for more dynamic and accurate outputs.
  • Model Compatibility Note: The effectiveness of the T2T component is highlighted as being highly dependent on the underlying model's specific training or fine-tuning for token editing tasks.


Changelog
  • python/sglang/srt/dllm/algorithm/joint_threshold.py
    • Added a new Python file defining the JointThreshold class, which implements a novel DLLM decoding algorithm.
    • The __init__ method initializes configurable parameters such as threshold, edit_threshold, and max_post_edit_steps (see the configuration sketch after this list).
    • The run method contains the core logic for iterative decoding, handling both M2T (mask filling based on confidence) and T2T (token refinement based on prediction and confidence) operations.
    • Includes logic to manage batch processing, track finished sequences, and determine when to perform additional forward passes for KV cache persistence.
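Based on the changelog above, selecting the algorithm would be configured roughly as below. The keys match the parameters read in JointThreshold.__init__, but how algorithm_config reaches the server depends on sglang's DLLM setup, so treat this as a sketch:

# Hypothetical configuration sketch (exact plumbing may differ in sglang):
algorithm_config = {
    "threshold": 0.9,          # min confidence to commit a masked position (M2T)
    "edit_threshold": 1.0,     # 1.0 effectively disables T2T re-editing
    "max_post_edit_steps": 2,  # extra refinement passes after all masks are filled
}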

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the JointThreshold algorithm for joint Mask-to-Token and Token-to-Token decoding. The implementation is clear and follows the described logic. My review includes suggestions to improve performance by vectorizing parts of the code, a correction to a type hint for better code quality, and a recommendation to change a default parameter value to prevent potentially harmful behavior for users, as noted in the pull request description.

):
    super().__init__(config)
    self.threshold = config.algorithm_config.get("threshold", 0.5)
    self.edit_threshold = config.algorithm_config.get("edit_threshold", 0)

high

The default value for edit_threshold is 0. Since p is a probability from softmax, it will always be greater than 0. This means the Token-to-Token (T2T) editing is enabled by default and will aggressively replace any generated token that is not the most likely one according to the current logits. The PR description warns that T2T can harm generation quality for models not trained for it. A safer default would be 1.0, which would effectively disable T2T unless explicitly configured by the user with a lower value.

Suggested change:
-        self.edit_threshold = config.algorithm_config.get("edit_threshold", 0)
+        self.edit_threshold = config.algorithm_config.get("edit_threshold", 1.0)
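To see why 0 is an aggressive default: softmax assigns strictly positive probability to every token, so the top-token probability p always satisfies p > 0, and T2T fires at every position where the argmax differs from the current token. A quick standalone check illustrates this (not part of the PR):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 32000)                    # arbitrary logits over a vocabulary
p = F.softmax(logits, dim=-1).max(dim=-1).values  # confidence of each argmax token
assert (p > 0).all()                              # always holds, so `p > 0` never gates T2T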

    self,
    model_runner: ModelRunner,
    forward_batch: ForwardBatch,
) -> tuple[LogitsProcessorOutput | torch.Tensor, torch.Tensor | None, bool]:

medium

The return type hint for the run method's second element is torch.Tensor | None, but the implementation returns a list of tensors ([] or next_token_ids_list). This should be updated to list[torch.Tensor] for correctness and consistency with other algorithms like LowConfidence.

Suggested change:
-) -> tuple[LogitsProcessorOutput | torch.Tensor, torch.Tensor | None, bool]:
+) -> tuple[LogitsProcessorOutput | torch.Tensor, list[torch.Tensor], bool]:

Comment on lines +38 to +47
start_list = []
prompt_masks = []
for i in range(batch_size):
    block_start = i * self.block_size
    block_end = block_start + self.block_size
    block_input_ids = forward_batch.input_ids[block_start:block_end]

    prompt_mask = block_input_ids != self.mask_id
    prompt_masks.append(prompt_mask)
    start_list.append(prompt_mask.sum().item())

medium

This loop to compute start_list and prompt_masks can be vectorized for better performance. By reshaping forward_batch.input_ids and using batched tensor operations, you can avoid iterating over the batch size in Python.

        reshaped_input_ids = forward_batch.input_ids.view(batch_size, self.block_size)
        prompt_masks_tensor = reshaped_input_ids != self.mask_id
        prompt_masks = list(torch.unbind(prompt_masks_tensor))
        start_list = prompt_masks_tensor.sum(dim=1).tolist()

Comment on lines 67 to 118
for i in range(batch_size):
    if finished[i]:
        continue

    block_start = i * self.block_size
    block_end = block_start + self.block_size

    curr_input_ids = forward_batch.input_ids[block_start:block_end]
    curr_logits = logits_output.full_logits[block_start:block_end]
    curr_prompt_mask = prompt_masks[i]

    x = torch.argmax(curr_logits, dim=-1)
    p = torch.squeeze(
        torch.gather(
            F.softmax(curr_logits, dim=-1),
            dim=-1,
            index=torch.unsqueeze(x, -1),
        ),
        -1,
    )

    mask_index = curr_input_ids == self.mask_id
    has_mask = mask_index.any()

    # Mask to token (M2T)
    mask_transfer_index = torch.zeros_like(mask_index)
    if has_mask:
        confidence = torch.where(mask_index, p, -np.inf)
        mask_transfer_index = confidence > self.threshold

        if not mask_transfer_index.any():
            _, select_index = torch.topk(confidence, k=1)
            mask_transfer_index[select_index] = True
    else:
        post_edit_steps[i] += 1
        if post_edit_steps[i] > self.max_post_edit_steps:
            finished[i] = True
            continue

    # Token to token (T2T)
    edit_mask = ~mask_index & ~curr_prompt_mask
    edit_transfer_index = (
        (p > self.edit_threshold) & (curr_input_ids != x) & edit_mask
    )

    transfer_index = mask_transfer_index | edit_transfer_index
    if not transfer_index.any():
        finished[i] = True
        continue

    curr_input_ids[transfer_index] = x[transfer_index]
    any_changed_in_last_step = True

medium

The main logic inside the decoding loop iterates over each sequence in the batch individually. This can be a performance bottleneck. Most of the operations within this loop are tensor-based and could be vectorized to operate on the entire batch at once. This would involve reshaping inputs and using masks to handle per-sequence conditional logic. While more complex to implement, it would significantly improve efficiency.
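As a rough illustration of the suggested direction (not a drop-in replacement: it assumes finished is kept as a bool tensor, reuses prompt_masks_tensor from the earlier suggestion, and omits the per-sequence post_edit_steps bookkeeping, which needs masked updates of its own):

# Illustrative batched core, assuming [batch_size, block_size] views:
ids = forward_batch.input_ids.view(batch_size, self.block_size)
logits = logits_output.full_logits.view(batch_size, self.block_size, -1)
probs = F.softmax(logits, dim=-1)
x = probs.argmax(dim=-1)
p = probs.gather(-1, x.unsqueeze(-1)).squeeze(-1)

mask_index = ids == self.mask_id
active = ~finished.unsqueeze(-1)        # finished: bool tensor of shape [batch_size]

# M2T for the whole batch at once.
m2t = mask_index & (p > self.threshold) & active
conf = torch.where(mask_index, p, torch.full_like(p, float("-inf")))
need_fallback = mask_index.any(dim=-1) & ~m2t.any(dim=-1) & ~finished
top1 = conf.argmax(dim=-1)
m2t[need_fallback, top1[need_fallback]] = True  # commit top-1 where nothing cleared

# T2T for the whole batch at once.
t2t = ~mask_index & ~prompt_masks_tensor & (p > self.edit_threshold) & (ids != x) & active

transfer = m2t | t2t
ids[transfer] = x[transfer]             # in-place on the view updates input_ids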

@ClawSeven (Contributor) commented

Please add unit tests for the edit pattern, and include accuracy and performance benchmarks. An analysis using the GSM8K dataset would be sufficient.

github-actions bot added the documentation label Feb 8, 2026
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Signed-off-by: Junlin Zhou <zhoujunlin.zjl@antgroup.com>
@ispobock (Collaborator) commented Feb 8, 2026

/tag-and-rerun-ci

github-actions bot added the run-ci label Feb 8, 2026
@ClawSeven (Contributor) commented

/rerun-failed-ci

@zhaochenyang20 (Collaborator) commented

/rerun-failed-ci

@ispobock ispobock merged commit 1465224 into sgl-project:main Feb 9, 2026
58 of 95 checks passed
@wenxuewuhd commented
Hi, is it possible to have the accuracy tests and benchmark results on GSM8K with LLaDA2.1? Thanks.

Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
…gl-project#18171)

Signed-off-by: Junlin Zhou <zhoujunlin.zjl@antgroup.com>
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>

Labels

documentation · high priority · run-ci
