diff --git a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb index 9ff3accc0..847ab0db8 100644 --- a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb +++ b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb @@ -416,23 +416,14 @@ " allowed_special (set): A set of special tokens to include.\n", " \"\"\"\n", "\n", - " # Preprocess: Replace spaces with \"Ġ\"\n", - " # Note that Ġ is a particularity of the GPT-2 BPE implementation\n", - " # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n", - " # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n", - " processed_text = []\n", - " for i, char in enumerate(text):\n", - " if char == \" \" and i != 0:\n", - " processed_text.append(\"Ġ\")\n", - " if char != \" \":\n", - " processed_text.append(char)\n", - " processed_text = \"\".join(processed_text)\n", + " # Pre-tokenize training text using the same boundary rules as encode()\n", + " tokens = self.pretokenize_text(text)\n", "\n", " # Initialize vocab with unique characters, including \"Ġ\" if present\n", " # Start with the first 256 ASCII characters\n", " unique_chars = [chr(i) for i in range(256)]\n", " unique_chars.extend(\n", - " char for char in sorted(set(processed_text))\n", + " char for char in sorted({char for token in tokens for char in token})\n", " if char not in unique_chars\n", " )\n", " if \"Ġ\" not in unique_chars:\n", @@ -449,15 +440,18 @@ " self.vocab[new_id] = token\n", " self.inverse_vocab[token] = new_id\n", "\n", - " # Tokenize the processed_text into token IDs\n", - " token_ids = [self.inverse_vocab[char] for char in processed_text]\n", + " # Tokenize each pre-token into character IDs\n", + " token_id_sequences = [\n", + " [self.inverse_vocab[char] for char in token]\n", + " for token in tokens\n", + " ]\n", "\n", " # BPE steps 1-3: Repeatedly find and replace frequent pairs\n", " for new_id in range(len(self.vocab), vocab_size):\n", - " pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n", + " pair_id = self.find_freq_pair(token_id_sequences, mode=\"most\")\n", " if pair_id is None:\n", " break\n", - " token_ids = self.replace_pair(token_ids, pair_id, new_id)\n", + " token_id_sequences = self.replace_pair(token_id_sequences, pair_id, new_id)\n", " self.bpe_merges[pair_id] = new_id\n", "\n", " # Build the vocabulary with merged tokens\n", @@ -581,43 +575,7 @@ "\n", " \n", " # ---- Newline and carriage return handling ----\n", - " tokens = []\n", - " parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n", - " for part in parts:\n", - " if part == \"\":\n", - " continue\n", - " if part == \"\\r\\n\":\n", - " tokens.append(\"\\r\")\n", - " tokens.append(\"\\n\")\n", - " continue\n", - " if part == \"\\r\":\n", - " tokens.append(\"\\r\")\n", - " continue\n", - " if part == \"\\n\":\n", - " tokens.append(\"\\n\")\n", - " continue\n", - " \n", - " # Normal chunk without line breaks:\n", - " # - If spaces precede a word, prefix the first word with 'Ġ' and\n", - " # add standalone 'Ġ' for additional spaces\n", - " # - If spaces trail the chunk (e.g., before a newline) add\n", - " # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n", - " pending_spaces = 0\n", - " for m in re.finditer(r'( +)|(\\S+)', part):\n", - " if m.group(1) is not None:\n", - " pending_spaces += len(m.group(1))\n", - " else:\n", - " word = m.group(2)\n", - " if pending_spaces > 0:\n", - " for _ in range(pending_spaces - 1):\n", - " tokens.append(\"Ġ\") # remaining spaces as standalone\n", - " tokens.append(\"Ġ\" + word) # one leading space\n", - " pending_spaces = 0\n", - " else:\n", - " tokens.append(word)\n", - " # Trailing spaces (no following word): add standalone 'Ġ' tokens\n", - " for _ in range(pending_spaces):\n", - " tokens.append(\"Ġ\")\n", + " tokens = self.pretokenize_text(text)\n", " # ---------------------------------------------------------------\n", " \n", " # Map tokens -> ids (BPE if needed)\n", @@ -786,8 +744,53 @@ " return self.inverse_vocab.get(token, None)\n", "\n", " @staticmethod\n", - " def find_freq_pair(token_ids, mode=\"most\"):\n", - " pairs = Counter(zip(token_ids, token_ids[1:]))\n", + " def pretokenize_text(text):\n", + " tokens = []\n", + " parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n", + " for part in parts:\n", + " if part == \"\":\n", + " continue\n", + " if part == \"\\r\\n\":\n", + " tokens.append(\"\\r\")\n", + " tokens.append(\"\\n\")\n", + " continue\n", + " if part == \"\\r\":\n", + " tokens.append(\"\\r\")\n", + " continue\n", + " if part == \"\\n\":\n", + " tokens.append(\"\\n\")\n", + " continue\n", + "\n", + " # Normal chunk without line breaks:\n", + " # - If spaces precede a word, prefix the first word with 'Ġ' and\n", + " # add standalone 'Ġ' for additional spaces\n", + " # - If spaces trail the chunk (e.g., before a newline) add\n", + " # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n", + " pending_spaces = 0\n", + " for m in re.finditer(r'( +)|(\\S+)', part):\n", + " if m.group(1) is not None:\n", + " pending_spaces += len(m.group(1))\n", + " else:\n", + " word = m.group(2)\n", + " if pending_spaces > 0:\n", + " for _ in range(pending_spaces - 1):\n", + " tokens.append(\"Ġ\") # remaining spaces as standalone\n", + " tokens.append(\"Ġ\" + word) # one leading space\n", + " pending_spaces = 0\n", + " else:\n", + " tokens.append(word)\n", + " # Trailing spaces (no following word): add standalone 'Ġ' tokens\n", + " for _ in range(pending_spaces):\n", + " tokens.append(\"Ġ\")\n", + " return tokens\n", + "\n", + " @staticmethod\n", + " def find_freq_pair(token_id_sequences, mode=\"most\"):\n", + " pairs = Counter(\n", + " pair\n", + " for token_ids in token_id_sequences\n", + " for pair in zip(token_ids, token_ids[1:])\n", + " )\n", "\n", " if not pairs:\n", " return None\n", @@ -800,20 +803,25 @@ " raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n", "\n", " @staticmethod\n", - " def replace_pair(token_ids, pair_id, new_id):\n", - " dq = deque(token_ids)\n", - " replaced = []\n", - "\n", - " while dq:\n", - " current = dq.popleft()\n", - " if dq and (current, dq[0]) == pair_id:\n", - " replaced.append(new_id)\n", - " # Remove the 2nd token of the pair, 1st was already removed\n", - " dq.popleft()\n", - " else:\n", - " replaced.append(current)\n", + " def replace_pair(token_id_sequences, pair_id, new_id):\n", + " replaced_sequences = []\n", + "\n", + " for token_ids in token_id_sequences:\n", + " dq = deque(token_ids)\n", + " replaced = []\n", + "\n", + " while dq:\n", + " current = dq.popleft()\n", + " if dq and (current, dq[0]) == pair_id:\n", + " replaced.append(new_id)\n", + " # Remove the 2nd token of the pair, 1st was already removed\n", + " dq.popleft()\n", + " else:\n", + " replaced.append(current)\n", + "\n", + " replaced_sequences.append(replaced)\n", "\n", - " return replaced" + " return replaced_sequences" ] }, { diff --git a/ch02/05_bpe-from-scratch/tests.py b/ch02/05_bpe-from-scratch/tests.py index 8383842ba..cea966a5d 100644 --- a/ch02/05_bpe-from-scratch/tests.py +++ b/ch02/05_bpe-from-scratch/tests.py @@ -88,8 +88,14 @@ def test_tokenizer_training(imported_module, verdict_file): assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch." input_text = "Jack embraced beauty through art and life." + invalid_whitespace_tokens = [ + tok for tok in tokenizer.vocab.values() + if "Ġ" in tok and tok != "Ġ" and not tok.startswith("Ġ") + ] + assert not invalid_whitespace_tokens, "Training should not learn tokens with non-leading Ġ markers." + token_ids = tokenizer.encode(input_text) - assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output." + assert token_ids == [74, 361, 310, 109, 98, 420, 397, 100, 300, 428, 116, 121, 519, 699, 299, 808, 534], "Token IDs do not match expected output." assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."