From 5669216b22822803f6e8dbd21bb498a979805d37 Mon Sep 17 00:00:00 2001
From: Parth Chadha
Date: Wed, 2 Apr 2025 11:19:00 -0700
Subject: [PATCH 1/2] fix: ensure that we check for pad_token and not assume pad_token==eos_token

Signed-off-by: Parth Chadha
---
 nemo_reinforcer/algorithms/grpo.py | 6 +++---
 nemo_reinforcer/algorithms/sft.py  | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py
index e08c848522..d49dc32418 100644
--- a/nemo_reinforcer/algorithms/grpo.py
+++ b/nemo_reinforcer/algorithms/grpo.py
@@ -472,7 +472,7 @@ def grpo_train(
         # Convert LLMMessageLogType to FlatMessagesType for generation
         batched_flat, input_lengths = batched_message_log_to_flat_message(
             repeated_batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.eos_token_id},
+            pad_value_dict={"token_ids": tokenizer.pad_token_id},
         )
         input_ids = batched_flat["token_ids"]
         # Create generation-specific input structure
@@ -547,7 +547,7 @@
         # Convert updated LLMMessageLogType to FlatMessagesType for training
         flat_messages, input_lengths = batched_message_log_to_flat_message(
             repeated_batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.eos_token_id},
+            pad_value_dict={"token_ids": tokenizer.pad_token_id},
         )

         # Create training data from flattened messages
@@ -704,7 +704,7 @@ def validate(
         # Convert LLMMessageLogType to FlatMessagesType for generation
         batched_flat, input_lengths = batched_message_log_to_flat_message(
             val_batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.eos_token_id},
+            pad_value_dict={"token_ids": tokenizer.pad_token_id},
         )
         # Extract input IDs
         input_ids = batched_flat["token_ids"]
diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py
index 5ff77e11e9..8f9e34f9da 100644
--- a/nemo_reinforcer/algorithms/sft.py
+++ b/nemo_reinforcer/algorithms/sft.py
@@ -243,7 +243,7 @@ def validate(

         cat_and_padded, input_lengths = batched_message_log_to_flat_message(
             val_batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.eos_token_id},
+            pad_value_dict={"token_ids": tokenizer.pad_token_id},
         )

         val_data: BatchedDataDict = BatchedDataDict(
@@ -356,7 +356,7 @@ def sft_train(

         cat_and_padded, input_lengths = batched_message_log_to_flat_message(
             batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.eos_token_id},
+            pad_value_dict={"token_ids": tokenizer.pad_token_id},
         )

         train_data: BatchedDataDict = BatchedDataDict(

From 2983ca75b6428e3f7e7acacf770a39eb184cba3e Mon Sep 17 00:00:00 2001
From: Parth Chadha
Date: Wed, 2 Apr 2025 11:34:01 -0700
Subject: [PATCH 2/2] Add an assertion for checking if tokenizer has pad_token_id assigned

Signed-off-by: Parth Chadha
---
 nemo_reinforcer/models/generation/interfaces.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py
index 138b70fbc1..da7e737784 100644
--- a/nemo_reinforcer/models/generation/interfaces.py
+++ b/nemo_reinforcer/models/generation/interfaces.py
@@ -44,6 +44,11 @@ def verify_right_padding(
            f"data must be a BatchedDataDict, got type: {type(data)}"
        )

+    assert pad_value is not None, (
+        "Tokenizer does not have a pad token assigned.\n"
+        "If your tokenizer lacks a pad token, you can reuse the EOS token by setting tokenizer.pad_token = tokenizer.eos_token."
+    )
+
    # Determine which type of data we're dealing with
    if "input_ids" in data and "input_lengths" in data:
        # GenerationDatumSpec
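
---
Note (not part of the patch): a minimal sketch of the caller-side setup this change assumes. After these commits, padding uses tokenizer.pad_token_id, so the tokenizer must have a real pad token before any of the batching paths run; the new assertion in verify_right_padding enforces this. The use of Hugging Face AutoTokenizer and the "gpt2" checkpoint below are illustrative assumptions, not something the patch prescribes.

    from transformers import AutoTokenizer

    # gpt2 is an example of a tokenizer that ships without a pad token.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if tokenizer.pad_token is None:
        # Follow the suggestion in the new assertion message: reuse the EOS
        # token as the pad token so that tokenizer.pad_token_id is not None.
        tokenizer.pad_token = tokenizer.eos_token

    # This is the condition the patched verify_right_padding asserts on.
    assert tokenizer.pad_token_id is not None

For tokenizers that define a dedicated pad token, this is a no-op; the fix only matters where pad_token was silently absent and the old code papered over it by padding with eos_token_id.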