37 changes: 29 additions & 8 deletions nemo_rl/distributed/model_utils.py
@@ -105,16 +105,37 @@ def backward(
) -> tuple[torch.Tensor, None, None, None, None, None, None]:
grad_output = grad_outputs[0]
softmax, target_mask, masked_target = ctx.saved_tensors
partition_vocab_size = softmax.size(-1)

# 1 if it's the chosen log prob, 0 otherwise
is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
masked_target, num_classes=partition_vocab_size
)

grad_input = is_chosen.float().sub_(softmax)
if softmax.ndim == 3:
B, S, V = softmax.shape

grad_input.mul_(grad_output.unsqueeze(dim=-1))
            # skip `torch.nn.functional.one_hot`: compute flat indices of the
            # chosen tokens instead of materializing a (B, S, V) one-hot tensor
row = (
torch.arange(B, device=softmax.device)
.view(-1, 1)
.expand(-1, S)
.reshape(-1)
)
col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1)
flat_idx = (row * S + col) * V
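            # (row * S + col) enumerates the (b, s) positions, so flat_idx holds
            # each position's base offset into the flattened (B * S * V) gradient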

flat_chosen = flat_idx.masked_select(
~target_mask.reshape(-1)
) + masked_target.masked_select(~target_mask)

            # start from -softmax: `neg` allocates one new tensor, but still
            # avoids the one-hot intermediate of the generic path below
grad_input = softmax.neg()
grad_input = grad_input.mul_(grad_output.unsqueeze(-1))

grad_output_selected = grad_output.masked_select(~target_mask)
grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)
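            # net effect: grad_input == (one_hot(target) - softmax) * grad_output,
            # matching the generic path below without the one-hot allocation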
else:
V = softmax.size(-1)
is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
masked_target, num_classes=V
)
grad_input = is_chosen.float().sub_(softmax)
grad_input.mul_(grad_output.unsqueeze(-1))

# if you add an argument to the forward method, then you must add a corresponding None here
return grad_input, None, None, None, None, None, None
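For reviewers who want to sanity-check the fast path outside the autograd Function: a minimal, self-contained sketch (CPU-only, tiny made-up shapes, random mask, no tensor parallelism) showing that the flat-index scatter_add above reproduces the original one-hot formulation. Here `target_mask` marks tokens whose target lies outside the local vocab partition, as in the diff.

import torch

B, S, V = 2, 3, 5
softmax = torch.softmax(torch.randn(B, S, V), dim=-1)
grad_output = torch.randn(B, S)
masked_target = torch.randint(0, V, (B, S))
target_mask = torch.rand(B, S) > 0.5  # True: target not in this partition

# original formulation: materialize a one-hot tensor
is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
    masked_target, num_classes=V
)
ref = is_chosen.float().sub(softmax).mul(grad_output.unsqueeze(-1))

# new formulation: negate softmax, then scatter grad_output at flat indices
row = torch.arange(B).view(-1, 1).expand(-1, S).reshape(-1)
col = torch.arange(S).expand(B, -1).reshape(-1)
flat_chosen = ((row * S + col) * V).masked_select(
    ~target_mask.reshape(-1)
) + masked_target.masked_select(~target_mask)
out = softmax.neg().mul(grad_output.unsqueeze(-1))
out.view(-1).scatter_add_(0, flat_chosen, grad_output.masked_select(~target_mask))

torch.testing.assert_close(out, ref)  # both paths agree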
343 changes: 343 additions & 0 deletions tests/unit/distributed/test_model_utils.py
@@ -18,6 +18,8 @@
import torch

from nemo_rl.distributed.model_utils import (
DistributedLogprob,
_compute_distributed_log_softmax,
_get_tokens_on_this_cp_rank,
allgather_cp_sharded_tensor,
from_parallel_logits_to_logprobs,
@@ -422,3 +424,344 @@ def test_allgather_cp_sharded_tensor(register_allgather_cp_test_actor, cp_size):

finally:
cluster.shutdown()


@ray.remote(num_gpus=1)
class DistributedLogprobTestActor:
def __init__(self, tp_size):
self.tp_size = tp_size
self.env_vars = dict(os.environ)
torch.distributed.init_process_group(backend="nccl")
self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

def _torch_baseline_logprob(self, full_logits, target):
"""Single-GPU PyTorch baseline implementation for comparison."""
# Compute log softmax using standard PyTorch
log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1)

# Gather log probabilities for target tokens
target_mask = target >= 0 # Valid targets (assuming -1 or similar for padding)
log_probs = torch.gather(log_softmax, -1, target.unsqueeze(-1)).squeeze(-1)
log_probs = log_probs * target_mask.float()

return log_probs

def test_distributed_logprob_forward_and_backward(self):
"""Test DistributedLogprob forward and backward passes against PyTorch baseline."""
rank = int(os.environ["RANK"])

# Test parameters
batch_size = 4
seq_len = 8
full_vocab_size = 1024
vocab_part_size = full_vocab_size // self.tp_size

# Calculate vocab partition for this rank
vocab_start_index = rank * vocab_part_size
vocab_end_index = (rank + 1) * vocab_part_size

# Create test data with fixed seed for reproducibility (same across all ranks)
torch.manual_seed(42)

# Create full logits (same on all ranks for fair comparison)
full_logits = torch.randn(
batch_size, seq_len, full_vocab_size, device="cuda", requires_grad=True
)

# Extract this rank's vocab partition
vocab_parallel_logits = (
full_logits[:, :, vocab_start_index:vocab_end_index]
.clone()
.detach()
.requires_grad_(True)
)

        # Create target tokens with a different seed so they span vocab partitions
        torch.manual_seed(43)
target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda")

# === FORWARD PASS TEST ===
# Use the same full logits for baseline computation (without gradient tracking for forward test)
baseline_logits_forward = full_logits.clone().detach()
baseline_log_probs_forward = self._torch_baseline_logprob(
baseline_logits_forward, target
)

# Compute using DistributedLogprob (forward only first)
distributed_log_probs_inference = DistributedLogprob.apply(
vocab_parallel_logits.clone().detach(), # Clone to avoid affecting backward test
target,
vocab_start_index,
vocab_end_index,
self.tp_group,
True, # inference_only=True for forward test
)

# Compare forward results
torch.testing.assert_close(
distributed_log_probs_inference,
baseline_log_probs_forward,
rtol=1e-4,
atol=1e-4,
)

forward_max_diff = torch.max(
torch.abs(distributed_log_probs_inference - baseline_log_probs_forward)
).item()

# === BACKWARD PASS TEST ===
# Compute baseline gradients - use full_logits with gradient tracking
baseline_log_probs = self._torch_baseline_logprob(full_logits, target)
baseline_loss = torch.sum(baseline_log_probs)
baseline_loss.backward()
baseline_grad = full_logits.grad[
:, :, vocab_start_index:vocab_end_index
].clone()

# Reset full_logits grad for clean comparison
full_logits.grad = None

# Compute distributed gradients
distributed_log_probs = DistributedLogprob.apply(
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
self.tp_group,
False, # inference_only=False to enable backward
)

distributed_loss = torch.sum(distributed_log_probs)
distributed_loss.backward()
distributed_grad = vocab_parallel_logits.grad

# Compare gradients
torch.testing.assert_close(
distributed_grad, baseline_grad, rtol=1e-4, atol=1e-4
)

# Compare log probs again (should be same as forward test)
torch.testing.assert_close(
distributed_log_probs, baseline_log_probs, rtol=1e-4, atol=1e-4
)

grad_max_diff = torch.max(torch.abs(distributed_grad - baseline_grad)).item()
logprob_max_diff = torch.max(
torch.abs(distributed_log_probs - baseline_log_probs)
).item()

return {
"forward_max_diff": forward_max_diff,
"grad_max_diff": grad_max_diff,
"logprob_max_diff": logprob_max_diff,
}

def test_distributed_log_softmax(self):
"""Test the _compute_distributed_log_softmax function."""
rank = int(os.environ["RANK"])

# Test parameters
batch_size = 3
seq_len = 5
full_vocab_size = 256
vocab_part_size = full_vocab_size // self.tp_size

# Calculate vocab partition for this rank
vocab_start_index = rank * vocab_part_size
vocab_end_index = (rank + 1) * vocab_part_size

# Create test data with fixed seed
torch.manual_seed(42)

# Create full logits (same on all ranks for comparison)
full_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda")

# Extract this rank's vocab partition
vocab_parallel_logits = full_logits[
:, :, vocab_start_index:vocab_end_index
].clone()

# 1. Compute baseline log softmax
baseline_log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1)
expected_log_softmax = baseline_log_softmax[
:, :, vocab_start_index:vocab_end_index
]

# 2. Compute distributed log softmax
distributed_log_softmax = _compute_distributed_log_softmax(
vocab_parallel_logits, self.tp_group
)

# 3. Compare results
torch.testing.assert_close(
distributed_log_softmax, expected_log_softmax, rtol=1e-5, atol=1e-5
)

max_diff = torch.max(
torch.abs(distributed_log_softmax - expected_log_softmax)
).item()

return {"max_diff": max_diff}

def test_edge_cases(self):
"""Test edge cases like empty vocab partitions or extreme values."""
rank = int(os.environ["RANK"])

# Test parameters
batch_size = 2
seq_len = 3
full_vocab_size = 64
vocab_part_size = full_vocab_size // self.tp_size

vocab_start_index = rank * vocab_part_size
vocab_end_index = (rank + 1) * vocab_part_size

# Test 1: Very large logits (test numerical stability)
torch.manual_seed(42)
large_logits = (
torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
) # Large values
vocab_parallel_logits = large_logits[
:, :, vocab_start_index:vocab_end_index
].clone()

torch.manual_seed(43) # Consistent seed for targets
target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda")

# Should not produce NaN or Inf
log_probs = DistributedLogprob.apply(
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
self.tp_group,
True,
)

assert not torch.isnan(log_probs).any(), "Log probs contain NaN"
assert not torch.isinf(log_probs).any(), "Log probs contain Inf"

        # Test 2: all targets pointing to vocab index 0 (all ranks must participate)
        zero_target = torch.full((batch_size, seq_len), 0, device="cuda")

        log_probs_zero = DistributedLogprob.apply(
            vocab_parallel_logits,
            zero_target,
            vocab_start_index,
            vocab_end_index,
            self.tp_group,
            True,
        )

# Compute baseline for comparison
# All ranks should see the same full logits for this test
torch.manual_seed(42) # Reset seed to match the logits generation
baseline_large_logits = (
torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
)
        baseline_log_probs = self._torch_baseline_logprob(
            baseline_large_logits, zero_target
        )

# The distributed result should match the baseline
        torch.testing.assert_close(
            log_probs_zero, baseline_log_probs, rtol=1e-4, atol=1e-4
        )


DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN = (
f"{DistributedLogprobTestActor.__module__}.DistributedLogprobTestActor"
)


@pytest.fixture
def register_distributed_logprob_test_actor():
"""Register the DistributedLogprobTestActor for use in tests."""
original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN
)
ACTOR_ENVIRONMENT_REGISTRY[DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN] = (
PY_EXECUTABLES.SYSTEM
)

yield DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN

# Clean up registry
if DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
if original_registry_value is None:
del ACTOR_ENVIRONMENT_REGISTRY[DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN]
else:
ACTOR_ENVIRONMENT_REGISTRY[DISTRIBUTED_LOGPROB_TEST_ACTOR_FQN] = (
original_registry_value
)


@pytest.mark.parametrize("tp_size", [1, 2])
def test_distributed_logprob_all_tests(
register_distributed_logprob_test_actor, tp_size
):
"""Test all DistributedLogprob functionality for a given TP size."""
# Skip if not enough GPUs
if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size:
pytest.skip(
f"Not enough GPUs available. Need {tp_size}, got {torch.cuda.device_count()}"
)

cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True)

try:
actor_fqn = register_distributed_logprob_test_actor

# Create sharding for TP
sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"])
builder = RayWorkerBuilder(actor_fqn, tp_size)

worker_group = RayWorkerGroup(
cluster=cluster,
remote_worker_builder=builder,
workers_per_node=None,
sharding_annotations=sharding,
)

# Test 1: Combined Forward and Backward pass
print(f"\n=== Testing TP={tp_size}: Forward & Backward Pass ===")
futures = worker_group.run_all_workers_single_data(
"test_distributed_logprob_forward_and_backward"
)
results = ray.get(futures)
for i, result in enumerate(results):
if "forward_max_diff" in result:
print(f"Worker {i} forward max diff: {result['forward_max_diff']:.2e}")
if "grad_max_diff" in result and "logprob_max_diff" in result:
print(
f"Worker {i} gradient max diff: {result['grad_max_diff']:.2e}, "
f"logprob max diff: {result['logprob_max_diff']:.2e}"
)

# Test 2: Log softmax function
print(f"\n=== Testing TP={tp_size}: Log Softmax ===")
futures = worker_group.run_all_workers_single_data(
"test_distributed_log_softmax"
)
results = ray.get(futures)
for i, result in enumerate(results):
if "max_diff" in result:
print(
f"Worker {i} log softmax max difference: {result['max_diff']:.2e}"
)

# Test 3: Edge cases (only for TP=2)
if tp_size == 2:
print(f"\n=== Testing TP={tp_size}: Edge Cases ===")
futures = worker_group.run_all_workers_single_data("test_edge_cases")
results = ray.get(futures)
print("Edge cases test completed successfully")

worker_group.shutdown(force=True)

finally:
cluster.shutdown()
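To exercise the new tests locally (a sketch; the exact invocation depends on the repo's pytest and Ray setup), note that the TP=2 case needs at least two visible GPUs and skips itself otherwise:

pytest tests/unit/distributed/test_model_utils.py -k distributed_logprob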