Tokenizer redesign for better model-specific feature support #1082
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1082
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 93028cf with merge base 95ccf40.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
kartikayk
left a comment
I want to express my uncertainty about the TokenEncoding naming. It seems a bit unintuitive, since I'd imagine an "encoding" to refer to something like UTF-8.
Not the best, but should we consider Tokenizer (so SentencePieceTokenizer derives from Tokenizer) and ModelTokenizer (so Llama3Tokenizer derives from ModelTokenizer)? I don't love ModelTokenizer, but with the right docstrings I think it's passable.
Overall I like the proposal. Still going through all the code, but the division of … Also:
I understand that this is a useful thing to do, but it makes reviewing this code harder and is kinda logically distinct from the important stuff in this PR (imo). With this move, I have to look at the tokenizers in their entirety; without it I can actually see the diff.
Yeah, I agree I don't like the TokenEncoding naming either, but naming SP and TT as Tokenizer and everything else as ModelTokenizer is more confusing imo... there should be a clear distinction between a model tokenizer and the base tokenizer. Maybe we could call it …
@joecummings Huh, I guess Gemma does have special tokens, but our current tokenizer does not use them. What do you think about either punting the upgrade to later, or parsing the special tokens from the HF json for now and then adding them in `tokenize_messages` as appropriate later? Same question for Mistral.
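For concreteness, "parsing the special tokens from the HF json" could look roughly like the sketch below. This is not code from this PR; the file name and the resulting dict are illustrative, and it assumes a downloaded fast-tokenizer `tokenizer.json` whose `added_tokens` entries carry `content`, `id`, and `special` fields.

```python
import json

# Hypothetical sketch: collect special tokens from a downloaded HF tokenizer.json.
with open("tokenizer.json") as f:
    data = json.load(f)

special_tokens = {
    tok["content"]: tok["id"]
    for tok in data.get("added_tokens", [])
    if tok.get("special", False)
}
print(special_tokens)  # e.g. {"<bos>": 2, "<eos>": 1, ...}
```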
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files
@@ Coverage Diff @@
## main #1082 +/- ##
===========================================
+ Coverage 26.74% 67.45% +40.71%
===========================================
Files 183 191 +8
Lines 8362 8498 +136
===========================================
+ Hits 2236 5732 +3496
+ Misses 6126 2766 -3360
☔ View full report in Codecov by Sentry.
I opened an issue to address this later: #1118.
ebsmothers
left a comment
A handful more comments, but no huge concerns from my side. Accepting to unblock.
joecummings
left a comment
This is amazing work, but I see a few things that definitely need to be documented before landing.
This PR makes resuming dataset iteration from a checkpoint fast again. This performance regression comes from pytorch/torchtitan#838. In that PR, `.skip` was removed for both map-style and iterable-style datasets for correctness reasons. However, `.skip` works as expected for map-style datasets, so the change can be reverted for that case. On the other hand, for iterable-style datasets, calling `.skip` after `split_dataset_by_node` splits the number of elements to skip **across the ranks** (e.g. calling `.skip(10)` after `split_dataset_by_node(<rank>, 2)` effectively skips 5 (`10 // 2 = 5`) elements on each rank), which isn't what we want or expect, so removing `.skip` was justified there. Still, we can make the whole thing much faster by using the [`state_dict` API](https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration) for iterable-style datasets, which avoids re-iterating past shards/files when resuming.
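As a reference, here is a minimal sketch of the `state_dict`-based resumption flow for a streaming dataset. The dataset name and checkpoint index are placeholders, and this is not code from either PR; it just follows the documented `IterableDataset.state_dict()` / `load_state_dict()` pattern (datasets >= 2.18).

```python
from datasets import load_dataset

# Placeholder dataset; any iterable-style (streaming) dataset works the same way.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

state_dict = None
for idx, example in enumerate(ds):
    if idx == 41:                      # pretend the trainer checkpoints here
        state_dict = ds.state_dict()   # records shard index + in-shard offset
        break

# ...persist state_dict alongside the training checkpoint, then on restart:
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
ds.load_state_dict(state_dict)         # resumes without re-iterating earlier shards
next(iter(ds))                         # continues from where iteration stopped
```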
Motivation
- Models rely on a shared `tokenize_messages`, but this is not configurable for model-specific special tokens

Design proposal
Intuition: separate the two core APIs (`encode`/`decode` and `tokenize_messages`) that operate at different levels of abstraction:
We can achieve the above with `BaseTokenizer` and `ModelTokenizer`:
- `BaseTokenizer` is the base abstract interface for any base tokenization model (SP or TT) that implements `encode` and `decode`
- `ModelTokenizer` is the base abstract interface for any model-specific tokenizer that implements `tokenize_messages`. All models will implement their own Tokenizer class based on this interface so they can control `tokenize_messages` logic

This means the SentencePieceTokenizer and TikTokenTokenizer will be refactored to separate out their `encode`/`decode` and `tokenize_messages` logic.
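A minimal sketch of what these two interfaces could look like; the method signatures, return types, and use of `Protocol` are assumptions for illustration, not the PR's exact code.

```python
from typing import Any, List, Mapping, Protocol, Tuple


class BaseTokenizer(Protocol):
    """Base abstract interface for a base tokenization model (SP or TT)."""

    def encode(self, text: str, **kwargs: Any) -> List[int]:
        ...

    def decode(self, token_ids: List[int], **kwargs: Any) -> str:
        ...


class ModelTokenizer(Protocol):
    """Base abstract interface for a model-specific tokenizer that applies
    special tokens while tokenizing a list of messages."""

    def tokenize_messages(
        self, messages: List[Mapping[str, Any]], **kwargs: Any
    ) -> Tuple[List[int], List[bool]]:
        ...
```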
And any model tokenizers would compose with the above classes:
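Continuing the sketch above, composition (rather than inheritance) could look roughly like this. The message fields, the lack of encode keyword arguments, and the class name are illustrative assumptions.

```python
from typing import Any, List, Mapping, Tuple


class MyModelTokenizer:
    """Satisfies the ModelTokenizer interface by delegating raw text encoding
    to a BaseTokenizer instance (e.g. a SentencePiece-backed one)."""

    def __init__(self, base_tokenizer: "BaseTokenizer"):
        self._base = base_tokenizer

    def tokenize_messages(
        self, messages: List[Mapping[str, Any]], **kwargs: Any
    ) -> Tuple[List[int], List[bool]]:
        tokens: List[int] = []
        mask: List[bool] = []
        for message in messages:
            # Model-specific special-token handling (prompt templates,
            # BOS/EOS placement, etc.) belongs here; the base tokenizer
            # only maps text to token ids.
            ids = self._base.encode(message["content"])
            tokens.extend(ids)
            mask.extend([message.get("masked", False)] * len(ids))
        return tokens, mask
```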
Changelog
- Add `Llama2Tokenizer` and `Llama3Tokenizer` for SP and TT, respectively
- Add `GemmaTokenizer` and `MistralTokenizer`, which leverage `SentencePieceBaseTokenizer`. The `tokenize_messages` logic is identical to `Llama2Tokenizer` (for now; Mistral needs to be updated with v3)
- … `SentencePieceBaseTokenizer` and retain special token logic
- Update `DummyTokenizer`, since it previously inherited from `SentencePieceTokenizer`
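A rough usage sketch after these changes; the builder name, import paths, and `Message` fields are assumptions based on torchtune's public API and were not verified against this PR.

```python
from torchtune.data import Message
from torchtune.models.llama2 import llama2_tokenizer

# Build the model tokenizer; the SentencePiece base tokenizer is an internal detail.
tokenizer = llama2_tokenizer("/tmp/llama2/tokenizer.model")

messages = [
    Message(role="user", content="Hello!", masked=True),
    Message(role="assistant", content="Hi there."),
]

# Model-specific special-token logic is applied inside tokenize_messages.
tokens, mask = tokenizer.tokenize_messages(messages)
```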
Test plan
- `pytest tests` runs successfully

Planned follow-ups