Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions nemo_reinforcer/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def grpo_train(
# Convert LLMMessageLogType to FlatMessagesType for generation
batched_flat, input_lengths = batched_message_log_to_flat_message(
repeated_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.eos_token_id},
pad_value_dict={"token_ids": tokenizer.pad_token_id},
Comment thread
terrykong marked this conversation as resolved.
)
input_ids = batched_flat["token_ids"]
# Create generation-specific input structure
Expand Down Expand Up @@ -547,7 +547,7 @@ def grpo_train(
# Convert updated LLMMessageLogType to FlatMessagesType for training
flat_messages, input_lengths = batched_message_log_to_flat_message(
repeated_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.eos_token_id},
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)

# Create training data from flattened messages
Expand Down Expand Up @@ -704,7 +704,7 @@ def validate(
# Convert LLMMessageLogType to FlatMessagesType for generation
batched_flat, input_lengths = batched_message_log_to_flat_message(
val_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.eos_token_id},
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
# Extract input IDs
input_ids = batched_flat["token_ids"]
Expand Down
4 changes: 2 additions & 2 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def validate(

cat_and_padded, input_lengths = batched_message_log_to_flat_message(
val_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.eos_token_id},
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)

val_data: BatchedDataDict = BatchedDataDict(
Expand Down Expand Up @@ -356,7 +356,7 @@ def sft_train(

cat_and_padded, input_lengths = batched_message_log_to_flat_message(
batch["message_log"],
pad_value_dict={"token_ids": tokenizer.eos_token_id},
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)

train_data: BatchedDataDict = BatchedDataDict(
Expand Down
5 changes: 5 additions & 0 deletions nemo_reinforcer/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def verify_right_padding(
f"data must be a BatchedDataDict, got type: {type(data)}"
)

assert pad_value is not None, (
"Tokenizer does not have a pad token assigned. \n"
"If the default tokenizer does not have a pad token, you can assign it the value of eos token by tokenizer.pad_token = tokenizer.eos_token"
)

# Determine which type of data we're dealing with
if "input_ids" in data and "input_lengths" in data:
# GenerationDatumSpec
Expand Down