
Yield per-document RoPE position ids from dataset #2560

Open
joecummings wants to merge 1 commit into pytorch:main from joecummings:fix-pos-id

Conversation


@joecummings joecummings commented Mar 12, 2026

Fixes #2559

HuggingFaceTextDataset now tracks a _position_buffer alongside the existing _token_buffer. Each document's tokens get positions [0, 1, ..., doc_len-1], resetting at every document boundary. Positions are yielded as {"input": input, "positions": positions} and flow through the trainer's extra_inputs into Decoder.forward(positions=...) automatically.
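The buffer mechanics described above can be sketched as follows. This is a simplified illustration, not the actual torchtitan implementation: the class name and `_next_sample` helper are hypothetical, while `_token_buffer` and `_position_buffer` follow the names in this PR.

```python
import torch


class PackedTextDatasetSketch:
    """Illustrative sketch of per-document position tracking while packing."""

    def __init__(self, seq_len: int):
        self.seq_len = seq_len
        self._token_buffer: list[int] = []
        self._position_buffer: list[int] = []

    def _add_document(self, tokens: list[int]) -> None:
        # Positions restart at 0 for every document, so RoPE encodes
        # within-document offsets rather than offsets into the packed sequence.
        self._token_buffer.extend(tokens)
        self._position_buffer.extend(range(len(tokens)))

    def _next_sample(self) -> dict[str, torch.Tensor]:
        # Emit one fixed-length packed sample and drop consumed entries;
        # leftover tokens stay buffered for the next sample.
        input_ids = torch.tensor(self._token_buffer[: self.seq_len])
        positions = torch.tensor(self._position_buffer[: self.seq_len])
        del self._token_buffer[: self.seq_len]
        del self._position_buffer[: self.seq_len]
        return {"input": input_ids, "positions": positions}
```

Packing two short documents into one sample yields positions that reset at the document boundary, e.g. `[0, 1, 2, 0, 1]` for a 3-token then a 2-token document.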

Checkpoint state_dict/load_state_dict are updated to persist the position buffer (backward compatible via .get(), so older checkpoints without the buffer still load).
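A minimal sketch of that backward-compatible checkpoint handling; the class and key names here are illustrative, not the actual torchtitan ones.

```python
class BufferStateSketch:
    """Illustrative state_dict round-trip with a BC fallback via .get()."""

    def __init__(self):
        self._token_buffer: list[int] = []
        self._position_buffer: list[int] = []

    def state_dict(self) -> dict:
        return {
            "token_buffer": self._token_buffer,
            "position_buffer": self._position_buffer,
        }

    def load_state_dict(self, state: dict) -> None:
        self._token_buffer = state["token_buffer"]
        # .get() keeps checkpoints saved before this change loadable:
        # they have no "position_buffer" key, so fall back to empty.
        self._position_buffer = state.get("position_buffer", [])
```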

Longer-term consideration

Right now there are two concerns for packed datasets: attention masks and position IDs. Attention masks are computed in the post_dataloading_process and, in this PR, position IDs are built in the dataset. Constructing masks purely from the EOS token id is fragile, especially with post-training multi-turn sequences, where models could co-opt that token for end of sequence versus end of document.

The right long-term approach for torchtitan is for datasets to yield seq_lens metadata alongside tokens (rather than position_ids directly), and to derive both positions and attention masks from that single source of truth in post-processing. This would retire the EOS-based get_document_mask_mod path entirely and co-locate both computations in one place.

# In dataloader
def _iter_greedy_packed(self):
    for sample in self._get_data_iter():
        input_ids = self._tokenize(sample)
        self._pack_buffer_input.extend(input_ids)
        self._pack_seq_lens.append(len(input_ids))  # just track the length

# In post-dataloading processing
if "seq_lens" in extra_inputs:
    seq_lens = extra_inputs.pop("seq_lens")
    extra_inputs["positions"] = positions_from_seq_lens(seq_lens)
    extra_kwargs["attention_masks"] = mask_from_seq_lens(seq_lens)

This doesn't change how Decoder works.

Resources:
- https://github.com/NVIDIA/NeMo/blob/v2.7.0/nemo/collections/llm/gpt/data/core.py
- https://github.com/pytorch/torchtune/blob/d0f63bb33d00b8bd3905a010b71d8c6324c2e980/torchtune/datasets/_packed.py#L108-L143

Test plan

Unit tests pass

Also, for fun, a comparison between runs WITH position ids and WITHOUT. The loss is definitely different, but not by a ton:
[Screenshot: loss comparison WITH vs. WITHOUT position ids]

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 12, 2026
@joecummings joecummings changed the title Yield per-document RoPE position IDs from HuggingFaceTextDataset Yield per-document RoPE position ids from dataset Mar 12, 2026
Fixes pytorch#2559. The dataloader now tracks a position buffer alongside the
token buffer, resetting positions to 0 at each document boundary. This
ensures RoPE encodes within-document positions correctly when
block_causal attention is used.
@joecummings (Member Author)

cc @tianyu-l @francesco-bertolotti: I did the fix that was discussed in #2559, but the "longer term fix" is also pretty simple. I might suggest we just do that in this PR, unless you have objections, b/c that would technically change the behavior of the attention mask construction. Could be a follow-up.

@joecummings joecummings marked this pull request as ready for review March 12, 2026 21:06
@tianyu-l (Contributor)

@joecummings
The long term fix sounds reasonable. It can also replace varlen metadata creation https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/common/attention.py#L322

co-locate both computations in one place

So you are suggesting putting it in dataloading. But then for more complicated, model-specific mask generation (e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/model.py#L209), there still needs to be this post_dataloading_processing https://github.com/pytorch/torchtitan/blob/main/torchtitan/trainer.py#L608, right?


rakkit commented Mar 13, 2026

Also for fun, comparison between WITH position ids and WITHOUT. Definitely different in the loss, but not by a ton:

I think this is expected for RoPE.


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.


Development

Successfully merging this pull request may close these issues.

RoPE positions are never set

3 participants