Skip to content

Commit 18ce2fc

Browse files
committed
add a unit test
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent 13cee5b commit 18ce2fc

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

tests/unit/data/test_llm_message_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from transformers import AutoTokenizer
1919

2020
from nemo_rl.data.llm_message_utils import (
21+
_validate_tensor_consistency,
2122
message_log_to_flat_messages,
2223
get_keys_from_message_log,
2324
batched_message_log_to_flat_message,
@@ -405,6 +406,21 @@ def test_get_formatted_message_log_qwen(
405406
assert actual_text == expected_text
406407

407408

409+
def test_formatted_message_log_empty_message():
410+
message_log = [
411+
{"role": "system", "content": "You are a helpful assistant."},
412+
{"role": "user", "content": ""},
413+
]
414+
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
415+
task_data_spec = TaskDataSpec(task_name="test")
416+
result = get_formatted_message_log(message_log, tokenizer, task_data_spec)
417+
flat_result = message_log_to_flat_messages(result)
418+
for k in flat_result.keys():
419+
if isinstance(flat_result[k], torch.Tensor):
420+
# make sure validate_tensor_consistency does not raise an error when one of the messages is empty
421+
_validate_tensor_consistency([flat_result[k]])
422+
423+
408424
def test_add_loss_mask_to_chat_message_log(
409425
tokenized_chat_message_log: LLMMessageLogType,
410426
):

0 commit comments

Comments
 (0)