|
20 | 20 | import torch |
21 | 21 | from transformers import AutoTokenizer |
22 | 22 |
|
| 23 | +from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message |
23 | 24 | from nemo_rl.distributed.batched_data_dict import BatchedDataDict |
24 | 25 | from nemo_rl.distributed.virtual_cluster import RayVirtualCluster |
25 | 26 | from nemo_rl.environments.games.sliding_puzzle import ( |
@@ -440,6 +441,45 @@ def test_run_multi_step_calculator_vllm(multi_step_setup_vllm): |
440 | 441 | print("\nMulti-Step Calculator VLLM Test assertions passed.") |
441 | 442 |
|
442 | 443 |
|
| 444 | +@pytest.mark.skipif( |
| 445 | + not torch.cuda.is_available() or torch.cuda.device_count() < 1, |
| 446 | + reason="VLLM test requires at least 1 GPU", |
| 447 | +) |
| 448 | +def test_max_seqlen_respected(multi_step_setup_vllm): |
| 449 | + """Tests multi-step calculator rollout with VllmGeneration.""" |
| 450 | + vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( |
| 451 | + multi_step_setup_vllm |
| 452 | + ) |
| 453 | + max_rollout_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1 |
| 454 | + max_seq_len = 290 |
| 455 | + |
| 456 | + print("\nRunning multi-step calculator rollout (VLLM)...") |
| 457 | + vllm_generation.prepare_for_generation() |
| 458 | + final_batch, rollout_metrics = run_multi_turn_rollout( |
| 459 | + policy_generation=vllm_generation, |
| 460 | + input_batch=initial_batch, |
| 461 | + tokenizer=rollout_tokenizer, |
| 462 | + task_to_env=task_to_env, |
| 463 | + max_seq_len=max_seq_len, |
| 464 | + max_rollout_turns=max_rollout_turns, |
| 465 | + ) |
| 466 | + vllm_generation.finish_generation() |
| 467 | + print("Multi-step calculator rollout complete (VLLM).") |
| 468 | + |
| 469 | + # --- Assertions --- |
| 470 | + assert isinstance(final_batch, BatchedDataDict) |
| 471 | + assert "message_log" in final_batch |
| 472 | + assert "total_reward" in final_batch |
| 473 | + assert len(final_batch["message_log"]) == len(initial_batch["message_log"]) |
| 474 | + flattened_message_log, _ = batched_message_log_to_flat_message( |
| 475 | + final_batch["message_log"] |
| 476 | + ) |
| 477 | + # Check that the sequence length is respected by flattening the message log and checking the length |
| 478 | + assert len(flattened_message_log["token_ids"][0]) == max_seq_len, ( |
| 479 | + f"Sequence length {len(flattened_message_log['token_ids'][0])} is not equal to max_seq_len {max_seq_len}" |
| 480 | + ) |
| 481 | + |
| 482 | + |
443 | 483 | # --- Fixture for Sliding Puzzle Environment --- |
444 | 484 | @pytest.fixture(scope="function") |
445 | 485 | def sliding_puzzle_environment(rollout_cluster): |
|
0 commit comments