Yield per-document RoPE position ids from dataset #2560
joecummings wants to merge 1 commit into pytorch:main
Conversation
Fixes pytorch#2559. The dataloader now tracks a position buffer alongside the token buffer, resetting positions to 0 at each document boundary. This ensures RoPE encodes within-document positions correctly when block_causal attention is used.
cc @tianyu-l @francesco-bertolotti: I did the fix that was discussed in #2559, but the "longer term fix" is also pretty simple. I might suggest we just do that in this PR, unless you have objections, b/c that would technically change the behavior of the attention mask construction. Could be a follow-up.
@joecummings
So you are suggesting putting it in dataloading. But then for more complicated, model-specific mask generation (e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/model.py#L209), there still needs to be this post_dataloading_processing https://github.com/pytorch/torchtitan/blob/main/torchtitan/trainer.py#L608, right?
I think this is expected for RoPE.
Fixes #2559
`HuggingFaceTextDataset` now tracks a `_position_buffer` alongside the existing `_token_buffer`. Each document's tokens get positions `[0, 1, ..., doc_len - 1]`, resetting at every document boundary. Positions are yielded as `{"input": input, "positions": positions}` and flow through the trainer's `extra_inputs` into `Decoder.forward(positions=...)` automatically. Checkpoint `state_dict`/`load_state_dict` are updated to persist the position buffer (BC via `.get()`).
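The buffering logic described above can be sketched roughly as follows. This is an illustrative standalone version, not the actual torchtitan code; the names `pack_with_positions`, `token_buffer`, and `position_buffer` are made up to mirror the description.

```python
def pack_with_positions(documents, seq_len):
    """Pack documents into fixed-length samples, tracking per-document
    RoPE positions that reset to 0 at every document boundary."""
    token_buffer, position_buffer = [], []
    for doc in documents:
        token_buffer.extend(doc)
        # Positions restart at 0 for each document, so RoPE only ever
        # encodes within-document offsets under block-causal attention.
        position_buffer.extend(range(len(doc)))
        while len(token_buffer) >= seq_len:
            yield {
                "input": token_buffer[:seq_len],
                "positions": position_buffer[:seq_len],
            }
            token_buffer = token_buffer[seq_len:]
            position_buffer = position_buffer[seq_len:]
```

For example, packing docs of length 3 and 4 with `seq_len=4` yields one sample whose positions are `[0, 1, 2, 0]`: the last slot starts a new document, so its position resets to 0 instead of continuing to 3.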
Longer-term consideration
Right now there are two pieces of per-document state for packed datasets: attention masks and position IDs. Attention masks are computed in the `post_dataloading_process` step and, with this PR, position IDs are built in the dataset. Constructing masks purely from the EOS token id is fragile, especially with post-training multi-turn sequences, where models could co-opt that token for end of sequence versus end of document.
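To make the fragility concrete, here is a minimal sketch of EOS-based document-id inference (a stand-in for what an EOS-driven mask derivation has to do; `doc_ids_from_eos` and `eos_id` are hypothetical names, not torchtitan's actual code):

```python
def doc_ids_from_eos(tokens, eos_id):
    """Assign a document id to each token, starting a new document
    immediately after every EOS token."""
    ids, current = [], 0
    for tok in tokens:
        ids.append(current)
        if tok == eos_id:
            current += 1  # next token is treated as a new document
    return ids
```

If a multi-turn sample reuses the same token id to mark end of turn, this logic splits one conversation into multiple "documents", which is exactly the failure mode described above.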
The right long-term approach for torchtitan is for datasets to yield `seq_lens` metadata alongside tokens (rather than `position_ids` directly), and for both positions and attention masks to be derived from that single source of truth in post-processing. This would retire the EOS-based `get_document_mask_mod` path entirely and co-locate both computations in one place. It doesn't change how `Decoder` works.
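A sketch of what that single-source-of-truth derivation could look like. This is a plain-Python illustration (a real implementation would build tensors or FlexAttention mask_mods); the function names are invented:

```python
def positions_from_seq_lens(seq_lens):
    # e.g. [3, 2] -> [0, 1, 2, 0, 1]
    return [p for n in seq_lens for p in range(n)]

def block_causal_mask(seq_lens):
    """Dense boolean block-causal mask derived from the same seq_lens."""
    # Document ids: [3, 2] -> [0, 0, 0, 1, 1]; equal ids = same document.
    doc = [i for i, n in enumerate(seq_lens) for _ in range(n)]
    total = len(doc)
    # Query q may attend to key kv iff kv <= q (causal) and both tokens
    # belong to the same document (block-diagonal restriction).
    return [
        [kv <= q and doc[q] == doc[kv] for kv in range(total)]
        for q in range(total)
    ]
```

Because both outputs come from one `seq_lens` list, positions and masks can never disagree about where documents start, which is the point of the proposal.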
Resources: https://github.com/NVIDIA/NeMo/blob/v2.7.0/nemo/collections/llm/gpt/data/core.py, https://github.com/pytorch/torchtune/blob/d0f63bb33d00b8bd3905a010b71d8c6324c2e980/torchtune/datasets/_packed.py#L108-L143
Test plan
Unit tests pass
Also, for fun, a comparison between runs WITH position ids and WITHOUT. The loss is definitely different, but not by a ton:
