Integrate Multi-Token Prediction (MTP) Training objective#1837

Merged
copybara-service[bot] merged 1 commit into main from parambole/maxtext_mtp_training_obective
Jul 14, 2025
Conversation


@parambole parambole commented Jun 16, 2025

Dependency: This PR depends on and must be merged after the refactoring in Refactor: Decouple Core Transformer Blocks #1852.

PR: Multi-Token Prediction (MTP) Integration

TL;DR

  • What: This PR integrates the Multi-Token Prediction (MTP) auxiliary training objective into MaxText.
  • Why: To improve model performance and training efficiency by densifying training signals, based on the architecture described in the DeepSeek-V3 paper.
  • How: By adding a MultiTokenPredictionBlock that runs after the main decoder stack during training. It computes an additional loss term which is added to the main loss.

Detailed Description

Background and Motivation

Standard language models are trained on a next-token prediction objective. Multi-Token Prediction (MTP) enhances this by adding an auxiliary task: from each position in a sequence, the model also learns to predict several tokens into the future. This encourages the model to develop richer internal representations and can lead to significant improvements in sample efficiency and final model performance.

This implementation follows the sequential prediction model, where the prediction of token t+k+1 is causally dependent on the layer that predicted token t+k.
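To make this target shifting concrete, the sketch below mimics the presumed semantics of a roll-and-mask helper. The name `roll_and_mask`, its signature, and the padding convention here are assumptions for illustration, not the actual MaxText implementation: at each additional MTP depth, the labels are rolled one more position into the future and the vacated tail is masked out so it contributes no loss.

```python
import numpy as np

def roll_and_mask(x: np.ndarray, shift: int = -1, pad_value: int = 0) -> np.ndarray:
    """Hypothetical sketch: shift labels |shift| positions into the future
    along the sequence axis and mask the positions that wrapped around."""
    rolled = np.roll(x, shift, axis=-1)
    rolled[..., shift:] = pad_value  # wrapped-around tail carries no valid target
    return rolled

labels = np.array([[11, 12, 13, 14, 15]])
# Depth-1 MTP: each position now targets the token one step further ahead
# than the standard next-token objective.
print(roll_and_mask(labels, shift=-1).tolist())  # [[12, 13, 14, 15, 0]]
```

The PR's tests cover a `roll_and_mask` utility in `maxtext_utils`; the version above is only a stand-in to make the sequential-target idea visible.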


Architectural Changes

To integrate this feature cleanly and robustly, several key architectural changes were made:

  1. The MTP Module (layers/multi_token_prediction.py): A new file houses all MTP-specific logic:

    • MultiTokenPredictionLayer: A single block responsible for one step of future prediction. It normalizes its inputs, projects them, and processes them through a standard transformer layer.
    • MultiTokenPredictionBlock: This module orchestrates the entire MTP process. It contains a for loop that runs for mtp_num_layers, instantiating a unique MultiTokenPredictionLayer for each step and maintaining the sequential flow of the hidden state.
  2. Integration with Transformer (layers/models.py): The main Transformer model was modified to facilitate the MTP "side-car":

    • The Decoder's __call__ method now returns both the main_logits and the raw final_hidden_state (pre-normalization). This makes the dependency explicit.
    • The Transformer's setup method now instantiates the MultiTokenPredictionBlock, passing it the correct DecoderLayer blueprint to ensure architectural consistency.
    • The Transformer's __call__ method calls the MTPBlock only during training (model_mode == MODEL_MODE_TRAIN), explicitly passing it the dependencies it needs (final_hidden_state, shared_embedding, etc.).
  3. Loss Calculation (train.py): The auxiliary loss is aggregated without changing the Transformer's return signature, using Flax's sow mechanism:

    • The MultiTokenPredictionBlock calls self.sow('mtp_losses', 'losses', ...) for each layer's calculated loss. This is guarded by an if not self.is_initializing() check to prevent sowing during model initialization.
    • The main loss_fn in train.py is now responsible for "reaping" these values by making the 'mtp_losses' collection mutable during the training .apply call.
    • It then retrieves the tuple of sown losses and weights using the existing maxtext_utils.get_nested_value utility and the explicit path (mtp_losses, mtp_block, losses).
    • Finally, it computes the average MTP loss, scales it by mtp_loss_scaling_factor, and adds it to the main loss before backpropagation. The mtp_loss is also added to the training and evaluation metrics for logging.
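The reap-and-aggregate step can be pictured with a dependency-free sketch. The dictionary below mirrors the shape of the mutated-collection pytree Flax returns from .apply(..., mutable=['mtp_losses']); the loss values and the minimal re-implementation of get_nested_value are illustrative assumptions, not the MaxText code:

```python
def get_nested_value(tree, path, default=None):
    """Illustrative stand-in for maxtext_utils.get_nested_value: walk a
    nested dict along `path`, returning `default` if any key is missing."""
    for key in path:
        if not isinstance(tree, dict) or key not in tree:
            return default
        tree = tree[key]
    return tree

# Shape of the collection sown by MultiTokenPredictionBlock (values are made up).
mutated = {"mtp_losses": {"mtp_block": {"losses": (2.0, 1.0), "weights": (1.0, 1.0)}}}

losses = get_nested_value(mutated, ("mtp_losses", "mtp_block", "losses"), ())
weights = get_nested_value(mutated, ("mtp_losses", "mtp_block", "weights"), ())

# Weight-average the per-layer losses, scale, and add to the main loss.
mtp_loss = sum(l * w for l, w in zip(losses, weights)) / max(sum(weights), 1.0)
mtp_loss_scaling_factor = 0.1  # base.yml default
total_loss = 3.0 + mtp_loss_scaling_factor * mtp_loss  # 3.0 stands in for the main loss
print(total_loss)
```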

Configuration

This feature is controlled via two new parameters in base.yml:

  • mtp_num_layers: (int, default: 0) The number of auxiliary prediction layers to use. Set to a positive integer to enable MTP.
  • mtp_loss_scaling_factor: (float, default: 0.1) The weighting factor for the final MTP loss.

How to Use: To enable MTP with 4 prediction heads and a 15% loss contribution, add the following to your config file:

YAML

mtp_num_layers: 4
mtp_loss_scaling_factor: 0.15


Testing Strategy

A new test file, MaxText/tests/multi_token_prediction_test.py, adds a MultiTokenPredictionBlockTest class to ensure the implementation is robust.

The testing follows these key principles:

  • Wrapper Model: A lightweight MTPBlockTestModel is used to wrap the MultiTokenPredictionBlock and its dependencies. This allows Flax's .init() to handle all parameter creation automatically and correctly, which is a robust pattern seen in other MaxText tests.
  • Core Functionality Test (test_sow_functionality): This primary test verifies that when the block is run in training mode (mutable=['mtp_losses']), it correctly sows the expected number of losses and weights.
  • Initialization Test (test_no_sow_during_init): This test confirms that no losses are sown during the .init() call by leveraging the if not self.is_initializing() check in the application code.
  • A new test for the roll_and_mask utility was also added to maxtext_utils_test.py.
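The two block-level checks can be summarized with a dependency-free mimic of Flax's sow/is_initializing mechanics. FakeMTPBlock below is a teaching stand-in under assumed semantics, not the actual test code:

```python
class FakeMTPBlock:
    """Stand-in mimicking the sow behaviour the tests assert on."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers

    def __call__(self, collection: dict, is_initializing: bool = False):
        # Mirrors the `if not self.is_initializing()` guard around self.sow(...).
        if is_initializing:
            return
        for _ in range(self.num_layers):
            collection.setdefault("losses", []).append(0.5)  # placeholder loss

block = FakeMTPBlock(num_layers=3)

train_collection = {}  # analogue of running with mutable=['mtp_losses']
block(train_collection)
print(len(train_collection["losses"]))  # 3: one sown loss per MTP layer

init_collection = {}   # analogue of the .init() call
block(init_collection, is_initializing=True)
print("losses" in init_collection)  # False: nothing sown during init
```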

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch from da99edb to affadb8 on June 16, 2025 18:30
@parambole parambole marked this pull request as ready for review June 16, 2025 18:31
@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch 6 times, most recently from 59ddf3f to 27ae66f on June 17, 2025 02:34
@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch from ccebc8c to b1b2d95 on June 19, 2025 19:23
@parambole parambole changed the base branch from main to parambole/mtp_refactor June 19, 2025 19:24
@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch 3 times, most recently from cfc3476 to d1be656 on July 10, 2025 03:20

@gobbleturk gobbleturk left a comment

Awesome tests!

Overall it looks great, I left a few nits/code style changes


@RissyRan RissyRan left a comment

Shall we add general MTP support into announcement?

@parambole

> Shall we add general MTP support into announcement?

That is a great suggestion. I have updated the announcement section.

@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch 5 times, most recently from cec106a to f6967f5 on July 10, 2025 23:43

@RissyRan RissyRan left a comment

Thank you! LGTM

@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch from 4e14e74 to 4c39e27 on July 11, 2025 00:31

@gagika gagika left a comment

Thanks

@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch 7 times, most recently from 614af6d to f7fced1 on July 11, 2025 23:59
@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch from f7fced1 to 4d9c654 on July 14, 2025 17:07
@parambole parambole force-pushed the parambole/maxtext_mtp_training_obective branch from 4d9c654 to 441e788 on July 14, 2025 17:11
@copybara-service copybara-service bot merged commit 35143b6 into main Jul 14, 2025
17 checks passed
@copybara-service copybara-service bot deleted the parambole/maxtext_mtp_training_obective branch July 14, 2025 18:43