Skip to content

Commit 04f30bb

Browse files
authored
fix: Fixed max seqlen not respected correctly (#299)
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
1 parent daac5d9 commit 04f30bb

2 files changed

Lines changed: 43 additions & 1 deletion

File tree

nemo_rl/experience/rollouts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ def run_multi_turn_rollout(
311311
>= max_seq_len
312312
):
313313
# truncate
314-
tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]]
314+
tokenized_obs = tokenized_obs[
315+
: max_seq_len - (len(generated_ids[i]) + active_input_lengths[i])
316+
]
315317
truncation_mask[i] = True
316318
# Record truncation
317319
sample_truncated[active_indices[i]] = True

tests/unit/experience/test_rollouts.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
from transformers import AutoTokenizer
2222

23+
from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message
2324
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
2425
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
2526
from nemo_rl.environments.games.sliding_puzzle import (
@@ -440,6 +441,45 @@ def test_run_multi_step_calculator_vllm(multi_step_setup_vllm):
440441
print("\nMulti-Step Calculator VLLM Test assertions passed.")
441442

442443

444+
@pytest.mark.skipif(
445+
not torch.cuda.is_available() or torch.cuda.device_count() < 1,
446+
reason="VLLM test requires at least 1 GPU",
447+
)
448+
def test_max_seqlen_respected(multi_step_setup_vllm):
449+
"""Tests multi-step calculator rollout with VllmGeneration."""
450+
vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = (
451+
multi_step_setup_vllm
452+
)
453+
max_rollout_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1
454+
max_seq_len = 290
455+
456+
print("\nRunning multi-step calculator rollout (VLLM)...")
457+
vllm_generation.prepare_for_generation()
458+
final_batch, rollout_metrics = run_multi_turn_rollout(
459+
policy_generation=vllm_generation,
460+
input_batch=initial_batch,
461+
tokenizer=rollout_tokenizer,
462+
task_to_env=task_to_env,
463+
max_seq_len=max_seq_len,
464+
max_rollout_turns=max_rollout_turns,
465+
)
466+
vllm_generation.finish_generation()
467+
print("Multi-step calculator rollout complete (VLLM).")
468+
469+
# --- Assertions ---
470+
assert isinstance(final_batch, BatchedDataDict)
471+
assert "message_log" in final_batch
472+
assert "total_reward" in final_batch
473+
assert len(final_batch["message_log"]) == len(initial_batch["message_log"])
474+
flattened_message_log, _ = batched_message_log_to_flat_message(
475+
final_batch["message_log"]
476+
)
477+
# Check that the sequence length is respected by flattening the message log and checking the length
478+
assert len(flattened_message_log["token_ids"][0]) == max_seq_len, (
479+
f"Sequence length {len(flattened_message_log['token_ids'][0])} is not equal to max_seq_len {max_seq_len}"
480+
)
481+
482+
443483
# --- Fixture for Sliding Puzzle Environment ---
444484
@pytest.fixture(scope="function")
445485
def sliding_puzzle_environment(rollout_cluster):

0 commit comments

Comments
 (0)