diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py
index 606fd8464b..0cde920c01 100644
--- a/nemo_rl/distributed/model_utils.py
+++ b/nemo_rl/distributed/model_utils.py
@@ -105,16 +105,37 @@ def backward(
     ) -> tuple[torch.Tensor, None, None, None, None, None, None]:
         grad_output = grad_outputs[0]
         softmax, target_mask, masked_target = ctx.saved_tensors
-        partition_vocab_size = softmax.size(-1)
-
-        # 1 if it's the chosen log prob, 0 otherwise
-        is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
-            masked_target, num_classes=partition_vocab_size
-        )
-
-        grad_input = is_chosen.float().sub_(softmax)
+        if softmax.ndim == 3:
+            B, S, V = softmax.shape
-
-        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+            # build flat indices of the chosen tokens, skipping `torch.nn.functional.one_hot`
+            row = (
+                torch.arange(B, device=softmax.device)
+                .view(-1, 1)
+                .expand(-1, S)
+                .reshape(-1)
+            )
+            col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1)
+            flat_idx = (row * S + col) * V
+
+            flat_chosen = flat_idx.masked_select(
+                ~target_mask.reshape(-1)
+            ) + masked_target.masked_select(~target_mask)
+
+            # out-of-place `neg` allocates a fresh buffer we can then mutate in place
+            grad_input = softmax.neg()
+            grad_input.mul_(grad_output.unsqueeze(-1))
+
+            grad_output_selected = grad_output.masked_select(~target_mask)
+            grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)
+        else:
+            V = softmax.size(-1)
+            # 1 if it's the chosen log prob, 0 otherwise
+            is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
+                masked_target, num_classes=V
+            )
+            grad_input = is_chosen.float().sub_(softmax)
+            grad_input.mul_(grad_output.unsqueeze(-1))
 
         # if you add an argument to the forward method, then you must add a corresponding None here
         return grad_input, None, None, None, None, None, None
diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py
index cee92c49b0..2f8ef2011a 100644
--- a/tests/unit/distributed/test_model_utils.py
+++ b/tests/unit/distributed/test_model_utils.py
@@ -18,6 +18,8 @@
 import torch
 
 from nemo_rl.distributed.model_utils import (
+    DistributedLogprob,
+    _compute_distributed_log_softmax,
     _get_tokens_on_this_cp_rank,
     allgather_cp_sharded_tensor,
     from_parallel_logits_to_logprobs,
@@ -422,3 +424,344 @@ def test_allgather_cp_sharded_tensor(register_allgather_cp_test_actor, cp_size):
     finally:
         cluster.shutdown()
+
+
+@ray.remote(num_gpus=1)
+class DistributedLogprobTestActor:
+    def __init__(self, tp_size):
+        self.tp_size = tp_size
+        self.env_vars = dict(os.environ)
+        torch.distributed.init_process_group(backend="nccl")
+        self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))
+
+    def _torch_baseline_logprob(self, full_logits, target):
+        """Single-GPU PyTorch baseline implementation for comparison."""
+        # Compute log softmax using standard PyTorch
+        log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1)
+
+        # Gather log probabilities for the target tokens; negative targets
+        # (e.g. -1 padding) are masked out
+        target_mask = target >= 0
+        log_probs = torch.gather(log_softmax, -1, target.unsqueeze(-1)).squeeze(-1)
+        log_probs = log_probs * target_mask.float()
+
+        return log_probs
+
+    def test_distributed_logprob_forward_and_backward(self):
+        """Test DistributedLogprob forward and backward passes against a PyTorch baseline."""
+        rank = int(os.environ["RANK"])
+
+        # Test parameters
+        batch_size = 4
+        seq_len = 8
+        full_vocab_size = 1024
+        vocab_part_size = full_vocab_size // self.tp_size
+
+        # Calculate the vocab partition owned by this rank
+        vocab_start_index = rank * vocab_part_size
+        vocab_end_index = (rank + 1) * vocab_part_size
+
+        # Create test data with a fixed seed for reproducibility (same across all ranks)
+        torch.manual_seed(42)
+
+        # Create full logits (same on all ranks for a fair comparison)
+        full_logits = torch.randn(
+            batch_size, seq_len, full_vocab_size, device="cuda", requires_grad=True
+        )
+
+        # Extract this rank's vocab partition
+        vocab_parallel_logits = (
+            full_logits[:, :, vocab_start_index:vocab_end_index]
+            .clone()
+            .detach()
+            .requires_grad_(True)
+        )
+
+        # Use a different fixed seed for targets so they span all vocab partitions
+        torch.manual_seed(43)
+        target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda")
+
+        # === FORWARD PASS TEST ===
+        # Use the same full logits for the baseline (no gradient tracking needed here)
+        baseline_logits_forward = full_logits.clone().detach()
+        baseline_log_probs_forward = self._torch_baseline_logprob(
+            baseline_logits_forward, target
+        )
+
+        # Compute using DistributedLogprob (forward only first)
+        distributed_log_probs_inference = DistributedLogprob.apply(
+            vocab_parallel_logits.clone().detach(),  # clone to avoid affecting the backward test
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            self.tp_group,
+            True,  # inference_only=True for the forward test
+        )
+
+        # Compare forward results
+        torch.testing.assert_close(
+            distributed_log_probs_inference,
+            baseline_log_probs_forward,
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+        forward_max_diff = torch.max(
+            torch.abs(distributed_log_probs_inference - baseline_log_probs_forward)
+        ).item()
+
+        # === BACKWARD PASS TEST ===
+        # Compute baseline gradients using full_logits with gradient tracking
+        baseline_log_probs = self._torch_baseline_logprob(full_logits, target)
+        baseline_loss = torch.sum(baseline_log_probs)
+        baseline_loss.backward()
+        baseline_grad = full_logits.grad[
+            :, :, vocab_start_index:vocab_end_index
+        ].clone()
+
+        # Reset full_logits grad for a clean comparison
+        full_logits.grad = None
+
+        # Compute distributed gradients
+        distributed_log_probs = DistributedLogprob.apply(
+            vocab_parallel_logits,
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            self.tp_group,
+            False,  # inference_only=False to enable backward
+        )
+
+        distributed_loss = torch.sum(distributed_log_probs)
+        distributed_loss.backward()
+        distributed_grad = vocab_parallel_logits.grad
+
+        # Compare gradients
+        torch.testing.assert_close(
+            distributed_grad, baseline_grad, rtol=1e-4, atol=1e-4
+        )
+
+        # Compare log probs again (should match the forward test)
+        torch.testing.assert_close(
+            distributed_log_probs, baseline_log_probs, rtol=1e-4, atol=1e-4
+        )
+
+        grad_max_diff = torch.max(torch.abs(distributed_grad - baseline_grad)).item()
+        logprob_max_diff = torch.max(
+            torch.abs(distributed_log_probs - baseline_log_probs)
+        ).item()
+
+        return {
+            "forward_max_diff": forward_max_diff,
+            "grad_max_diff": grad_max_diff,
+            "logprob_max_diff": logprob_max_diff,
+        }
+
+    def test_distributed_log_softmax(self):
+        """Test the _compute_distributed_log_softmax function."""
+        rank = int(os.environ["RANK"])
+
+        # Test parameters
+        batch_size = 3
+        seq_len = 5
+        full_vocab_size = 256
+        vocab_part_size = full_vocab_size // self.tp_size
+
+        # Calculate the vocab partition owned by this rank
+        vocab_start_index = rank * vocab_part_size
+        vocab_end_index = (rank + 1) * vocab_part_size
+
+        # Create test data with a fixed seed
+        torch.manual_seed(42)
+
+        # Create full logits (same on all ranks for comparison)
+        full_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda")
+
+        # Extract this rank's vocab partition
+        vocab_parallel_logits = full_logits[
+            :, :, vocab_start_index:vocab_end_index
+        ].clone()
+
+        # 1. Compute the baseline log softmax
+        baseline_log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1)
+        expected_log_softmax = baseline_log_softmax[
+            :, :, vocab_start_index:vocab_end_index
+        ]
+
+        # 2. Compute the distributed log softmax
+        distributed_log_softmax = _compute_distributed_log_softmax(
+            vocab_parallel_logits, self.tp_group
+        )
+
+        # 3. Compare results
+        torch.testing.assert_close(
+            distributed_log_softmax, expected_log_softmax, rtol=1e-5, atol=1e-5
+        )
+
+        max_diff = torch.max(
+            torch.abs(distributed_log_softmax - expected_log_softmax)
+        ).item()
+
+        return {"max_diff": max_diff}
+
+    def test_edge_cases(self):
+        """Test edge cases: extreme logit values and degenerate targets."""
+        rank = int(os.environ["RANK"])
+
+        # Test parameters
+        batch_size = 2
+        seq_len = 3
+        full_vocab_size = 64
+        vocab_part_size = full_vocab_size // self.tp_size
+
+        vocab_start_index = rank * vocab_part_size
+        vocab_end_index = (rank + 1) * vocab_part_size
+
+        # Test 1: very large logits (test numerical stability)
+        torch.manual_seed(42)
+        large_logits = (
+            torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
+        )
+        vocab_parallel_logits = large_logits[
+            :, :, vocab_start_index:vocab_end_index
+        ].clone()
+
+        torch.manual_seed(43)  # consistent seed for targets
+        target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda")
+
+        # Should not produce NaN or Inf
+        log_probs = DistributedLogprob.apply(
+            vocab_parallel_logits,
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            self.tp_group,
+            True,
+        )
+
+        assert not torch.isnan(log_probs).any(), "Log probs contain NaN"
+        assert not torch.isinf(log_probs).any(), "Log probs contain Inf"
+
+        # Test 2: all targets point to vocab index 0, so only rank 0 owns
+        # them, yet every rank must still participate in the collectives
+        zero_target = torch.full(
+            (batch_size, seq_len), 0, device="cuda"
+        )
+
+        log_probs_zero = DistributedLogprob.apply(
+            vocab_parallel_logits,
+            zero_target,
+            vocab_start_index,
+            vocab_end_index,
+            self.tp_group,
+            True,
+        )
+
+        # Compute the baseline for comparison; all ranks regenerate the same
+        # full logits, so reset the seed used above
+        torch.manual_seed(42)
+        baseline_large_logits = (
+            torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
+        )
+        baseline_log_probs = self._torch_baseline_logprob(
+            baseline_large_logits, zero_target
+        )
+
+        # The distributed result should match the baseline
+        torch.testing.assert_close(
+            log_probs_zero, baseline_log_probs, rtol=1e-4, atol=1e-4
+        )
+
+
+DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN = (
+    f"{DistributedLogprobTestActor.__module__}.DistributedLogprobTestActor"
+)
+
+
+@pytest.fixture
+def register_distributed_logprob_test_actor():
+    """Register the DistributedLogprobTestActor for use in tests."""
+    original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
+        DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN
+    )
+    ACTOR_ENVIRONMENT_REGISTRY[DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN] = (
+        PY_EXECUTABLES.SYSTEM
+    )
+
+    yield DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN
+
+    # Clean up the registry
+    if DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
+        if original_registry_value is None:
+            del ACTOR_ENVIRONMENT_REGISTRY[DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN]
+        else:
+            ACTOR_ENVIRONMENT_REGISTRY[DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN] = (
+                original_registry_value
+            )
+
+
+@pytest.mark.parametrize("tp_size", [1, 2])
+def test_distributed_logprob_all_tests(
+    register_distributed_logprob_test_actor, tp_size
+):
+    """Test all DistributedLogprob functionality for a given TP size."""
+    # Skip if there are not enough GPUs
+    if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size:
+        pytest.skip(
+            f"Not enough GPUs available. Need {tp_size}, got {torch.cuda.device_count()}"
+        )
+
+    cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True)
+
+    try:
+        actor_fqn = register_distributed_logprob_test_actor
+
+        # Create the sharding for TP
+        sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"])
+        builder = RayWorkerBuilder(actor_fqn, tp_size)
+
+        worker_group = RayWorkerGroup(
+            cluster=cluster,
+            remote_worker_builder=builder,
+            workers_per_node=None,
+            sharding_annotations=sharding,
+        )
+
+        # Test 1: combined forward and backward pass
+        print(f"\n=== Testing TP={tp_size}: Forward & Backward Pass ===")
+        futures = worker_group.run_all_workers_single_data(
+            "test_distributed_logprob_forward_and_backward"
+        )
+        results = ray.get(futures)
+        for i, result in enumerate(results):
+            if "forward_max_diff" in result:
+                print(f"Worker {i} forward max diff: {result['forward_max_diff']:.2e}")
+            if "grad_max_diff" in result and "logprob_max_diff" in result:
+                print(
+                    f"Worker {i} gradient max diff: {result['grad_max_diff']:.2e}, "
+                    f"logprob max diff: {result['logprob_max_diff']:.2e}"
+                )
+
+        # Test 2: log softmax function
+        print(f"\n=== Testing TP={tp_size}: Log Softmax ===")
+        futures = worker_group.run_all_workers_single_data(
+            "test_distributed_log_softmax"
+        )
+        results = ray.get(futures)
+        for i, result in enumerate(results):
+            if "max_diff" in result:
+                print(
+                    f"Worker {i} log softmax max difference: {result['max_diff']:.2e}"
+                )
+
+        # Test 3: edge cases (only for TP=2)
+        if tp_size == 2:
+            print(f"\n=== Testing TP={tp_size}: Edge Cases ===")
+            futures = worker_group.run_all_workers_single_data("test_edge_cases")
+            results = ray.get(futures)
+            print("Edge cases test completed successfully")
+
+        worker_group.shutdown(force=True)
+
+    finally:
+        cluster.shutdown()
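
Reviewer note (not part of the diff): the core change in `backward` replaces the one-hot formulation with a flat-index `scatter_add_`. A minimal CPU sketch of that equivalence, with shapes and names mirroring the diff (`B`, `S`, `V`, `softmax`, `target_mask`, `masked_target`); the snippet is illustrative only.

import torch

# Shapes mirror the diff: batch B, sequence S, vocab-partition size V.
B, S, V = 2, 3, 5
torch.manual_seed(0)
softmax = torch.softmax(torch.randn(B, S, V), dim=-1)
grad_output = torch.randn(B, S)
masked_target = torch.randint(0, V, (B, S))
target_mask = torch.rand(B, S) > 0.7  # True = target owned by another rank

# Removed path: materialize a (B, S, V) one-hot tensor.
is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
    masked_target, num_classes=V
)
expected = is_chosen.float().sub_(softmax).mul_(grad_output.unsqueeze(-1))

# New path: scatter grad_output at the flat indices of the chosen tokens.
keep = ~target_mask
row = torch.arange(B).view(-1, 1).expand(-1, S).reshape(-1)
col = torch.arange(S).expand(B, -1).reshape(-1)
flat_idx = (row * S + col) * V  # base offset of each (b, s) row in view(-1)
flat_chosen = flat_idx.masked_select(keep.reshape(-1)) + masked_target.masked_select(keep)

grad_input = softmax.neg().mul_(grad_output.unsqueeze(-1))
grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output.masked_select(keep))

torch.testing.assert_close(grad_input, expected)
print("scatter_add_ gradient matches the one_hot gradient")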
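Also for reviewers: `test_distributed_log_softmax` checks the sharded computation against `torch.nn.functional.log_softmax` on the full vocab. Below is a single-process sketch of the reduction pattern a vocab-sharded log-softmax typically uses (a MAX all-reduce for stability, then a SUM all-reduce of the exponentials, simulated here as max/sum over two shards); the actual `_compute_distributed_log_softmax` may differ in details.

import torch

# Two simulated TP ranks, each owning half of the vocab dimension.
torch.manual_seed(0)
B, S, V = 2, 3, 8
full = torch.randn(B, S, V)
shard0, shard1 = full.chunk(2, dim=-1)

# Step 1: MAX "all-reduce" of per-shard maxima (numerical stability).
global_max = torch.maximum(
    shard0.amax(dim=-1, keepdim=True), shard1.amax(dim=-1, keepdim=True)
)
# Step 2: SUM "all-reduce" of per-shard sum(exp(x - max)).
sum_exp = (shard0 - global_max).exp().sum(dim=-1, keepdim=True) + (
    shard1 - global_max
).exp().sum(dim=-1, keepdim=True)
# Step 3: each rank finishes locally on its own shard.
local_log_softmax = shard0 - global_max - sum_exp.log()

torch.testing.assert_close(
    local_log_softmax, torch.log_softmax(full, dim=-1)[..., : V // 2]
)
print("sharded log-softmax matches log_softmax on the full vocab")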