diff --git a/olmo/tokenizer.py b/olmo/tokenizer.py
index 7f5026302..885594999 100644
--- a/olmo/tokenizer.py
+++ b/olmo/tokenizer.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import inspect
 from pathlib import Path
 from typing import List, Optional, Union
 
@@ -180,7 +181,12 @@ def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> Li
         if truncate_to is not None and add_special_tokens:
             truncate_to -= self.num_special_tokens_to_add(False)
 
-        batch_encoding = self.base_tokenizer.encode_batch(inputs)
+        # Check if the base tokenizer's encode_batch method supports add_special_tokens parameter
+        if 'add_special_tokens' in inspect.signature(self.base_tokenizer.encode_batch).parameters:
+            batch_encoding = self.base_tokenizer.encode_batch(inputs, add_special_tokens=add_special_tokens)
+        else:
+            # Fallback to original behavior if the parameter isn't supported
+            batch_encoding = self.base_tokenizer.encode_batch(inputs)
 
         all_input_ids = []
         for encoding in batch_encoding: