Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 76 additions & 68 deletions ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -416,23 +416,14 @@
" allowed_special (set): A set of special tokens to include.\n",
" \"\"\"\n",
"\n",
" # Preprocess: Replace spaces with \"Ġ\"\n",
" # Note that Ġ is a particularity of the GPT-2 BPE implementation\n",
" # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n",
" # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n",
" processed_text = []\n",
" for i, char in enumerate(text):\n",
" if char == \" \" and i != 0:\n",
" processed_text.append(\"Ġ\")\n",
" if char != \" \":\n",
" processed_text.append(char)\n",
" processed_text = \"\".join(processed_text)\n",
" # Pre-tokenize training text using the same boundary rules as encode()\n",
" tokens = self.pretokenize_text(text)\n",
"\n",
" # Initialize vocab with unique characters, including \"Ġ\" if present\n",
" # Start with the first 256 ASCII characters\n",
" unique_chars = [chr(i) for i in range(256)]\n",
" unique_chars.extend(\n",
" char for char in sorted(set(processed_text))\n",
" char for char in sorted({char for token in tokens for char in token})\n",
" if char not in unique_chars\n",
" )\n",
" if \"Ġ\" not in unique_chars:\n",
Expand All @@ -449,15 +440,18 @@
" self.vocab[new_id] = token\n",
" self.inverse_vocab[token] = new_id\n",
"\n",
" # Tokenize the processed_text into token IDs\n",
" token_ids = [self.inverse_vocab[char] for char in processed_text]\n",
" # Tokenize each pre-token into character IDs\n",
" token_id_sequences = [\n",
" [self.inverse_vocab[char] for char in token]\n",
" for token in tokens\n",
" ]\n",
"\n",
" # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
" for new_id in range(len(self.vocab), vocab_size):\n",
" pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n",
" pair_id = self.find_freq_pair(token_id_sequences, mode=\"most\")\n",
" if pair_id is None:\n",
" break\n",
" token_ids = self.replace_pair(token_ids, pair_id, new_id)\n",
" token_id_sequences = self.replace_pair(token_id_sequences, pair_id, new_id)\n",
" self.bpe_merges[pair_id] = new_id\n",
"\n",
" # Build the vocabulary with merged tokens\n",
Expand Down Expand Up @@ -581,43 +575,7 @@
"\n",
" \n",
" # ---- Newline and carriage return handling ----\n",
" tokens = []\n",
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
" for part in parts:\n",
" if part == \"\":\n",
" continue\n",
" if part == \"\\r\\n\":\n",
" tokens.append(\"\\r\")\n",
" tokens.append(\"\\n\")\n",
" continue\n",
" if part == \"\\r\":\n",
" tokens.append(\"\\r\")\n",
" continue\n",
" if part == \"\\n\":\n",
" tokens.append(\"\\n\")\n",
" continue\n",
" \n",
" # Normal chunk without line breaks:\n",
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
" # add standalone 'Ġ' for additional spaces\n",
" # - If spaces trail the chunk (e.g., before a newline) add\n",
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
" pending_spaces = 0\n",
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
" if m.group(1) is not None:\n",
" pending_spaces += len(m.group(1))\n",
" else:\n",
" word = m.group(2)\n",
" if pending_spaces > 0:\n",
" for _ in range(pending_spaces - 1):\n",
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
" tokens.append(\"Ġ\" + word) # one leading space\n",
" pending_spaces = 0\n",
" else:\n",
" tokens.append(word)\n",
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
" for _ in range(pending_spaces):\n",
" tokens.append(\"Ġ\")\n",
" tokens = self.pretokenize_text(text)\n",
" # ---------------------------------------------------------------\n",
" \n",
" # Map tokens -> ids (BPE if needed)\n",
Expand Down Expand Up @@ -786,8 +744,53 @@
" return self.inverse_vocab.get(token, None)\n",
"\n",
" @staticmethod\n",
" def find_freq_pair(token_ids, mode=\"most\"):\n",
" pairs = Counter(zip(token_ids, token_ids[1:]))\n",
" def pretokenize_text(text):\n",
" tokens = []\n",
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
" for part in parts:\n",
" if part == \"\":\n",
" continue\n",
" if part == \"\\r\\n\":\n",
" tokens.append(\"\\r\")\n",
" tokens.append(\"\\n\")\n",
" continue\n",
" if part == \"\\r\":\n",
" tokens.append(\"\\r\")\n",
" continue\n",
" if part == \"\\n\":\n",
" tokens.append(\"\\n\")\n",
" continue\n",
"\n",
" # Normal chunk without line breaks:\n",
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
" # add standalone 'Ġ' for additional spaces\n",
" # - If spaces trail the chunk (e.g., before a newline) add\n",
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
" pending_spaces = 0\n",
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
" if m.group(1) is not None:\n",
" pending_spaces += len(m.group(1))\n",
" else:\n",
" word = m.group(2)\n",
" if pending_spaces > 0:\n",
" for _ in range(pending_spaces - 1):\n",
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
" tokens.append(\"Ġ\" + word) # one leading space\n",
" pending_spaces = 0\n",
" else:\n",
" tokens.append(word)\n",
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
" for _ in range(pending_spaces):\n",
" tokens.append(\"Ġ\")\n",
" return tokens\n",
"\n",
" @staticmethod\n",
" def find_freq_pair(token_id_sequences, mode=\"most\"):\n",
" pairs = Counter(\n",
" pair\n",
" for token_ids in token_id_sequences\n",
" for pair in zip(token_ids, token_ids[1:])\n",
" )\n",
"\n",
" if not pairs:\n",
" return None\n",
Expand All @@ -800,20 +803,25 @@
" raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n",
"\n",
" @staticmethod\n",
" def replace_pair(token_ids, pair_id, new_id):\n",
" dq = deque(token_ids)\n",
" replaced = []\n",
"\n",
" while dq:\n",
" current = dq.popleft()\n",
" if dq and (current, dq[0]) == pair_id:\n",
" replaced.append(new_id)\n",
" # Remove the 2nd token of the pair, 1st was already removed\n",
" dq.popleft()\n",
" else:\n",
" replaced.append(current)\n",
" def replace_pair(token_id_sequences, pair_id, new_id):\n",
" replaced_sequences = []\n",
"\n",
" for token_ids in token_id_sequences:\n",
" dq = deque(token_ids)\n",
" replaced = []\n",
"\n",
" while dq:\n",
" current = dq.popleft()\n",
" if dq and (current, dq[0]) == pair_id:\n",
" replaced.append(new_id)\n",
" # Remove the 2nd token of the pair, 1st was already removed\n",
" dq.popleft()\n",
" else:\n",
" replaced.append(current)\n",
"\n",
" replaced_sequences.append(replaced)\n",
"\n",
" return replaced"
" return replaced_sequences"
]
},
{
Expand Down
8 changes: 7 additions & 1 deletion ch02/05_bpe-from-scratch/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,14 @@ def test_tokenizer_training(imported_module, verdict_file):
assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."

input_text = "Jack embraced beauty through art and life."
invalid_whitespace_tokens = [
tok for tok in tokenizer.vocab.values()
if "Ġ" in tok and tok != "Ġ" and not tok.startswith("Ġ")
]
assert not invalid_whitespace_tokens, "Training should not learn tokens with non-leading Ġ markers."

token_ids = tokenizer.encode(input_text)
assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output."
assert token_ids == [74, 361, 310, 109, 98, 420, 397, 100, 300, 428, 116, 121, 519, 699, 299, 808, 534], "Token IDs do not match expected output."

assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."

Expand Down