[DLLM] Add JointThreshold algorithm for joint M2T and T2T decoding #18171
ispobock merged 2 commits into sgl-project:main
Conversation
Summary of Changes: Hello @edwardzjl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request integrates a new JointThreshold algorithm for joint Mask-to-Token (M2T) and Token-to-Token (T2T) decoding.
Activity
Code Review
This pull request introduces the JointThreshold algorithm for joint Mask-to-Token and Token-to-Token decoding. The implementation is clear and follows the described logic. My review includes suggestions to improve performance by vectorizing parts of the code, a correction to a type hint for better code quality, and a recommendation to change a default parameter value to prevent potentially harmful behavior for users, as noted in the pull request description.
```python
    ):
        super().__init__(config)
        self.threshold = config.algorithm_config.get("threshold", 0.5)
        self.edit_threshold = config.algorithm_config.get("edit_threshold", 0)
```
The default value for edit_threshold is 0. Since p is a probability from softmax, it will always be greater than 0. This means the Token-to-Token (T2T) editing is enabled by default and will aggressively replace any generated token that is not the most likely one according to the current logits. The PR description warns that T2T can harm generation quality for models not trained for it. A safer default would be 1.0, which would effectively disable T2T unless explicitly configured by the user with a lower value.
```diff
-        self.edit_threshold = config.algorithm_config.get("edit_threshold", 0)
+        self.edit_threshold = config.algorithm_config.get("edit_threshold", 1.0)
```
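For intuition, a quick check (hypothetical values) of why a default of 1.0 disables T2T: `p` comes from a softmax, so it always lies strictly inside (0, 1).

```python
import torch

# p stands for the softmax probability of each argmax token (always in (0, 1)).
p = torch.tensor([0.3, 0.7, 0.999])
print(p > 0.0)  # tensor([True, True, True])    -> with edit_threshold=0, every differing token is rewritten
print(p > 1.0)  # tensor([False, False, False]) -> with edit_threshold=1.0, T2T never fires
```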
```python
        self,
        model_runner: ModelRunner,
        forward_batch: ForwardBatch,
    ) -> tuple[LogitsProcessorOutput | torch.Tensor, torch.Tensor | None, bool]:
```
The return type hint for the run method's second element is torch.Tensor | None, but the implementation returns a list of tensors ([] or next_token_ids_list). This should be updated to list[torch.Tensor] for correctness and consistency with other algorithms like LowConfidence.
```diff
-    ) -> tuple[LogitsProcessorOutput | torch.Tensor, torch.Tensor | None, bool]:
+    ) -> tuple[LogitsProcessorOutput | torch.Tensor, list[torch.Tensor], bool]:
```
```python
        start_list = []
        prompt_masks = []
        for i in range(batch_size):
            block_start = i * self.block_size
            block_end = block_start + self.block_size
            block_input_ids = forward_batch.input_ids[block_start:block_end]

            prompt_mask = block_input_ids != self.mask_id
            prompt_masks.append(prompt_mask)
            start_list.append(prompt_mask.sum().item())
```
This loop to compute start_list and prompt_masks can be vectorized for better performance. By reshaping forward_batch.input_ids and using batched tensor operations, you can avoid iterating over the batch size in Python.
```python
        reshaped_input_ids = forward_batch.input_ids.view(batch_size, self.block_size)
        prompt_masks_tensor = reshaped_input_ids != self.mask_id
        prompt_masks = list(torch.unbind(prompt_masks_tensor))
        start_list = prompt_masks_tensor.sum(dim=1).tolist()
```
```python
            for i in range(batch_size):
                if finished[i]:
                    continue

                block_start = i * self.block_size
                block_end = block_start + self.block_size

                curr_input_ids = forward_batch.input_ids[block_start:block_end]
                curr_logits = logits_output.full_logits[block_start:block_end]
                curr_prompt_mask = prompt_masks[i]

                x = torch.argmax(curr_logits, dim=-1)
                p = torch.squeeze(
                    torch.gather(
                        F.softmax(curr_logits, dim=-1),
                        dim=-1,
                        index=torch.unsqueeze(x, -1),
                    ),
                    -1,
                )

                mask_index = curr_input_ids == self.mask_id
                has_mask = mask_index.any()

                # Mask to token (M2T)
                mask_transfer_index = torch.zeros_like(mask_index)
                if has_mask:
                    confidence = torch.where(mask_index, p, -np.inf)
                    mask_transfer_index = confidence > self.threshold

                    if not mask_transfer_index.any():
                        _, select_index = torch.topk(confidence, k=1)
                        mask_transfer_index[select_index] = True
                else:
                    post_edit_steps[i] += 1
                    if post_edit_steps[i] > self.max_post_edit_steps:
                        finished[i] = True
                        continue

                # Token to token (T2T)
                edit_mask = ~mask_index & ~curr_prompt_mask
                edit_transfer_index = (
                    (p > self.edit_threshold) & (curr_input_ids != x) & edit_mask
                )

                transfer_index = mask_transfer_index | edit_transfer_index
                if not transfer_index.any():
                    finished[i] = True
                    continue

                curr_input_ids[transfer_index] = x[transfer_index]
                any_changed_in_last_step = True
```
The main logic inside the decoding loop iterates over each sequence in the batch individually. This can be a performance bottleneck. Most of the operations within this loop are tensor-based and could be vectorized to operate on the entire batch at once. This would involve reshaping inputs and using masks to handle per-sequence conditional logic. While more complex to implement, it would significantly improve efficiency.
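For illustration, here is a minimal sketch of what such a batch-level step could look like, assuming the block tensors are reshaped to `(batch_size, block_size)`; the helper name and the simplified bookkeeping (no `post_edit_steps` counter, no per-row `finished` updates) are assumptions, not the PR's implementation:

```python
import torch
import torch.nn.functional as F

def joint_threshold_step_batched(
    input_ids: torch.Tensor,    # (B, L) current block tokens
    full_logits: torch.Tensor,  # (B, L, V) logits for every position
    prompt_mask: torch.Tensor,  # (B, L) True where the token comes from the prompt
    finished: torch.Tensor,     # (B,) True for sequences that are already done
    mask_id: int,
    threshold: float,
    edit_threshold: float,
) -> torch.Tensor:
    """One vectorized M2T + T2T step over the whole batch (illustrative only)."""
    probs = F.softmax(full_logits, dim=-1)
    p, x = probs.max(dim=-1)  # (B, L): confidence of, and index of, the argmax token

    mask_index = input_ids == mask_id  # (B, L)

    # M2T: accept masked positions whose confidence clears the threshold.
    neg_inf = torch.full_like(p, float("-inf"))
    confidence = torch.where(mask_index, p, neg_inf)
    mask_transfer = confidence > threshold

    # Rows that still contain masks but accepted nothing force-accept their top-1 position.
    rows = (mask_index.any(dim=1) & ~mask_transfer.any(dim=1)).nonzero(as_tuple=True)[0]
    mask_transfer[rows, confidence[rows].argmax(dim=1)] = True

    # T2T: reconsider already-generated (non-mask, non-prompt) tokens.
    edit_mask = ~mask_index & ~prompt_mask
    edit_transfer = (p > edit_threshold) & (input_ids != x) & edit_mask

    transfer = (mask_transfer | edit_transfer) & ~finished[:, None]
    return torch.where(transfer, x, input_ids)
```

This keeps all conditional logic in boolean masks, so the whole batch is handled by a handful of tensor ops; the trade-off is that per-sequence early exits (finished flags, post-edit limits) must be folded into masks rather than `continue` statements.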
Please add unit tests for the edit pattern, and include accuracy and performance benchmarks. An analysis using the GSM8K dataset would be sufficient.
Force-pushed from 8839f5e to 15b1b1b
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Signed-off-by: Junlin Zhou <zhoujunlin.zjl@antgroup.com>
Force-pushed from 15b1b1b to 4909972
/tag-and-rerun-ci
/rerun-failed-ci

/rerun-failed-ci
Hi, is it possible to have the accuracy tests and benchmark results on GSM8K with LLaDA2.1? Thanks!
[DLLM] Add JointThreshold algorithm for joint M2T and T2T decoding (sgl-project#18171) Signed-off-by: Junlin Zhou <zhoujunlin.zjl@antgroup.com> Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Motivation
This PR introduces the JointThreshold algorithm, which enables the model to simultaneously fill in masks and refine previously generated tokens in a single iterative loop.
Modifications
Added JointThreshold Algorithm: Implemented a new DLLM algorithm that supports both Mask-to-Token (M2T) and Token-to-Token (T2T) decoding strategies.
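For illustration, a configuration enabling both modes might look like the sketch below; the two keys come from the constructor shown in the review above, while the concrete values and surrounding dict are assumptions:

```python
# Hypothetical values; "threshold" and "edit_threshold" are the keys read by
# the JointThreshold constructor reviewed above.
algorithm_config = {
    "threshold": 0.9,        # M2T: unmask a position once its confidence exceeds 0.9
    "edit_threshold": 0.95,  # T2T: rewrite an already-decoded token once p > 0.95
}
```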
Note on Model Compatibility:
The effectiveness of the T2T (Token-to-Token) refinement depends heavily on the underlying model. This algorithm is most effective when used with models specifically trained or fine-tuned for token-to-token editing tasks. For standard DLLM models without such training, the T2T component may even harm final generation quality.
Accuracy Tests
Since the compatible models specifically trained for token-to-token editing are still awaiting public release, unit tests for this commit are currently omitted. I will add test cases once these models are available.
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci