Skip to content

Commit 970dd10

Browse files
committed
fix: change format messages to out of place
Signed-off-by: KiddoZhu <zhaochengz@nvidia.com>
1 parent 6a324e8 commit 970dd10

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

nemo_reinforcer/data/llm_message_utils.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -353,14 +353,13 @@ def get_formatted_message_log(
353353
Returns:
354354
The message log with updated 'token_ids' and 'content' fields.
355355
"""
356-
cu_message = []
356+
new_message_log = []
357357
prev_formatted_message = ""
358358
template = task_data_spec.custom_template
359359

360360
for i, message in enumerate(message_log):
361-
cu_message.append(message.copy())
362361
formatted_message = tokenizer.apply_chat_template(
363-
cu_message,
362+
message_log[: i + 1],
364363
chat_template=template,
365364
add_generation_prompt=False,
366365
tokenize=False,
@@ -383,10 +382,14 @@ def get_formatted_message_log(
383382
message_chunk = message_chunk.rstrip("\n")
384383
if not message_chunk.endswith(tokenizer.eos_token):
385384
message_chunk += tokenizer.eos_token
386-
message["token_ids"] = tokenizer(
385+
386+
new_message = message.copy()
387+
new_message["token_ids"] = tokenizer(
387388
message_chunk, return_tensors="pt", add_special_tokens=False
388389
)["input_ids"][0]
389-
message["content"] = message_chunk
390+
new_message["content"] = message_chunk
391+
new_message_log.append(new_message)
392+
390393
prev_formatted_message = formatted_message
391394

392-
return message_log
395+
return new_message_log

0 commit comments

Comments
 (0)