Integrate Multi-Token Prediction (MTP) Training objective #1837
Merged
copybara-service[bot] merged 1 commit into main on Jul 14, 2025
Conversation
gobbleturk approved these changes on Jul 10, 2025

gobbleturk (Collaborator) left a comment:

Awesome tests! Overall it looks great; I left a few nits and code-style suggestions.
RissyRan (Collaborator) reviewed on Jul 10, 2025 and left a comment:
Shall we add general MTP support to the announcement?
Author (Collaborator) replied:

That is a great suggestion. I have updated the announcement section.
Dependency: This PR depends on, and must be merged after, the refactoring in "Refactor: Decouple Core Transformer Blocks" (#1852).
PR: Multi-Token Prediction (MTP) Integration
TL;DR

This PR adds a `MultiTokenPredictionBlock` that runs after the main decoder stack during training. It computes an additional loss term which is added to the main loss.

Detailed Description
Background and Motivation
Standard language models are trained on a next-token prediction objective. Multi-Token Prediction (MTP) enhances this by adding an auxiliary task: from each position in a sequence, the model also learns to predict several tokens into the future. This encourages the model to develop richer internal representations and can lead to significant improvements in sample efficiency and final model performance.
This implementation follows the sequential prediction model, where the prediction of token `t+k+1` is causally dependent on the layer that predicted token `t+k`.

Architectural Changes
To integrate this feature cleanly and robustly, several key architectural changes were made:
The MTP Module (`layers/multi_token_prediction.py`)

A new file was created to house all MTP-specific logic:

- `MultiTokenPredictionLayer`: A single block responsible for one step of future prediction. It normalizes its inputs, projects them, and processes them through a standard transformer layer.
- `MultiTokenPredictionBlock`: This module orchestrates the entire MTP process. It contains a `for` loop that runs for `mtp_num_layers` steps, instantiating a unique `MultiTokenPredictionLayer` for each step and maintaining the sequential flow of the hidden state.
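The sequential chaining can be sketched in plain NumPy. This is an illustrative toy, not the MaxText code: `rms_norm` and `make_mtp_layer` are simplified stand-ins (the real layer is a norm + projection + full transformer layer), and the shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

def rms_norm(x):
    # Simplified RMS normalization stand-in for the real norm layers.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + 1e-6)

def make_mtp_layer(d_model):
    # Each prediction step owns its own parameters; here a single matmul
    # stands in for the real norm + projection + transformer layer.
    w = rng.normal(size=(2 * d_model, d_model)) / np.sqrt(2 * d_model)
    def layer(hidden, future_embeds):
        fused = np.concatenate([rms_norm(hidden), rms_norm(future_embeds)], axis=-1)
        return fused @ w
    return layer

d_model, seq_len, mtp_num_layers = 8, 5, 3
hidden = rng.normal(size=(seq_len, d_model))  # decoder's final hidden state
embeds = rng.normal(size=(seq_len, d_model))  # embeddings of shifted target tokens

for k, layer in enumerate([make_mtp_layer(d_model) for _ in range(mtp_num_layers)], 1):
    # The hidden state that predicted token t+k feeds the layer predicting t+k+1.
    hidden = layer(hidden, embeds)
    # ...in the real block: compute logits via the shared embedding, record the loss

print(hidden.shape)  # (5, 8)
```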
Integration with `Transformer` (`layers/models.py`)

The main `Transformer` model was modified to facilitate the MTP "side-car":

- The `Decoder`'s `__call__` method now returns both the `main_logits` and the raw `final_hidden_state` (pre-normalization). This makes the dependency explicit.
- The `Transformer`'s `setup` method now instantiates the `MultiTokenPredictionBlock`, passing it the correct `DecoderLayer` blueprint to ensure architectural consistency.
- The `Transformer`'s `__call__` method calls the `MTPBlock` only during training (`model_mode == MODEL_MODE_TRAIN`), explicitly passing it the dependencies it needs (`final_hidden_state`, `shared_embedding`, etc.).
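The control flow amounts to the following schematic (stub callables with hypothetical names, not the real `layers/models.py` signatures):

```python
MODEL_MODE_TRAIN = "train"

def transformer_call(decoder, mtp_block, tokens, model_mode):
    # The decoder returns both outputs, making the MTP dependency explicit.
    main_logits, final_hidden_state = decoder(tokens)
    if model_mode == MODEL_MODE_TRAIN and mtp_block is not None:
        # "Side-car": sows its losses internally, never changes the return value.
        mtp_block(final_hidden_state, tokens)
    return main_logits

# Stub demonstration: the MTP block is invoked only in training mode.
calls = []
decoder = lambda toks: ("logits", "hidden")
mtp_block = lambda hidden, toks: calls.append(hidden)

transformer_call(decoder, mtp_block, [1, 2, 3], MODEL_MODE_TRAIN)
transformer_call(decoder, mtp_block, [1, 2, 3], "autoregressive")
print(len(calls))  # 1
```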
Loss Calculation (`train.py`)

The auxiliary loss is aggregated without changing the `Transformer`'s return signature by using Flax's `sow` mechanism:

- The `MultiTokenPredictionBlock` calls `self.sow('mtp_losses', 'losses', ...)` for each layer's calculated loss. This is guarded by an `if not self.is_initializing()` check so that nothing is sown during model initialization.
- The `loss_fn` in `train.py` is now responsible for "reaping" these values by making the `'mtp_losses'` collection mutable during the training `.apply` call.
- The sown losses are retrieved with the `maxtext_utils.get_nested_value` utility and the explicit path (`mtp_losses`, `mtp_block`, `losses`).
- The aggregated auxiliary loss is scaled by `mtp_loss_scaling_factor` and added to the main loss before backpropagation. The `mtp_loss` is also added to the training and evaluation metrics for logging.
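Numerically, the final aggregation step amounts to the following. This is an illustrative helper, not the actual `loss_fn` code; `combine_losses` and its inputs are hypothetical, with the list argument standing in for the values reaped from the `'mtp_losses'` collection.

```python
import numpy as np

def combine_losses(main_loss, sown_mtp_losses, mtp_loss_scaling_factor=0.1):
    # `sown_mtp_losses` plays the role of the reaped collection:
    # one scalar loss per MTP layer.
    if not sown_mtp_losses:
        return main_loss, 0.0  # MTP disabled (mtp_num_layers == 0)
    mtp_loss = float(np.mean(sown_mtp_losses))
    return main_loss + mtp_loss_scaling_factor * mtp_loss, mtp_loss

total, mtp = combine_losses(2.0, [0.5, 0.7], mtp_loss_scaling_factor=0.1)
# total is the main loss plus 0.1 * mean([0.5, 0.7]) = 2.0 + 0.06
```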
Configuration

This feature is controlled via two new parameters in `base.yml`:

- `mtp_num_layers` (int, default: `0`): The number of auxiliary prediction layers to use. Set to a positive integer to enable MTP.
- `mtp_loss_scaling_factor` (float, default: `0.1`): The weighting factor for the final MTP loss.

How to Use: To enable MTP with 4 prediction heads and a 15% loss contribution, add the following to your config file:
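A sketch of that config fragment, using the two parameters documented above with the example values from the sentence (4 layers, 0.15 scaling):

```yaml
mtp_num_layers: 4
mtp_loss_scaling_factor: 0.15
```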
Testing Strategy
A new test file, `MaxText/tests/multi_token_prediction_test.py`, has been added with a new test class, `MultiTokenPredictionBlockTest`, to ensure the implementation is robust. The testing follows these key principles:

- A minimal `MTPBlockTestModel` is used to wrap the `MultiTokenPredictionBlock` and its dependencies. This allows Flax's `.init()` to handle all parameter creation automatically and correctly, a robust pattern seen in other MaxText tests.
- Sow verification (`test_sow_functionality`): This primary test verifies that when the block is run in training mode (`mutable=['mtp_losses']`), it correctly sows the expected number of losses and weights.
- Init-time behavior (`test_no_sow_during_init`): This test confirms that no losses are sown during the `.init()` call, by leveraging the `if not self.is_initializing()` check in the application code.
- A test for the `roll_and_mask` utility was also added to `maxtext_utils_test.py`.
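As an illustration of what that utility does, here is a hypothetical NumPy re-implementation; the real `roll_and_mask` in MaxText may differ in signature and padding convention.

```python
import numpy as np

def roll_and_mask(tokens, shift, pad_id=0):
    # Shift a [batch, seq] token array left by `shift` positions so that
    # position t holds the token originally at t + shift, then mask the
    # vacated tail, where no valid future target exists. Assumes shift > 0.
    rolled = np.roll(tokens, -shift, axis=-1)
    rolled[..., -shift:] = pad_id
    return rolled

tokens = np.array([[1, 2, 3, 4, 5]])
print(roll_and_mask(tokens, 2))  # [[3 4 5 0 0]]
```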
Checklist

Before submitting this PR, please make sure (put X in square brackets):