|
18 | 18 | from transformers import AutoTokenizer |
19 | 19 |
|
20 | 20 | from nemo_rl.data.llm_message_utils import ( |
| 21 | + _validate_tensor_consistency, |
21 | 22 | message_log_to_flat_messages, |
22 | 23 | get_keys_from_message_log, |
23 | 24 | batched_message_log_to_flat_message, |
@@ -405,6 +406,21 @@ def test_get_formatted_message_log_qwen( |
405 | 406 | assert actual_text == expected_text |
406 | 407 |
|
407 | 408 |
|
def test_formatted_message_log_empty_message():
    """Regression test: an empty message must not break tensor consistency.

    Builds a two-message log whose user message has empty content, formats it
    with the chat template, flattens it, and checks that
    ``_validate_tensor_consistency`` accepts every tensor field without
    raising.
    """
    message_log = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": ""},
    ]
    # NOTE(review): gated HF model — assumes HF auth is configured in the test
    # environment; confirm CI has access.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    task_data_spec = TaskDataSpec(task_name="test")
    result = get_formatted_message_log(message_log, tokenizer, task_data_spec)
    flat_result = message_log_to_flat_messages(result)
    # Only the values are inspected, so iterate .values() directly instead of
    # keys + re-indexing.
    for value in flat_result.values():
        if isinstance(value, torch.Tensor):
            # make sure _validate_tensor_consistency does not raise an error
            # when one of the messages is empty
            _validate_tensor_consistency([value])
| 423 | + |
408 | 424 | def test_add_loss_mask_to_chat_message_log( |
409 | 425 | tokenized_chat_message_log: LLMMessageLogType, |
410 | 426 | ): |
|
0 commit comments