Skip to content

Commit 5ff10f6

Browse files
zpqiuSahilJain314
andauthored
fix: prevent division by zero in ClippedPGLossFn calculation (#166)
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com> Signed-off-by: Alex Qiu <alexq@nvidia.com> Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
1 parent 6db2f7a commit 5ff10f6

3 files changed

Lines changed: 26 additions & 11 deletions

File tree

nemo_reinforcer/algorithms/loss_functions.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __call__(
9191
mask = token_mask * sample_mask.unsqueeze(-1)
9292

9393
lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now)
94-
mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item()
94+
mult_prob_error = masked_mean(torch.exp(lp_error), mask).item()
9595

9696
next_token_logits = next_token_logits[:, :-1] # Remove last position's logits
9797
next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
@@ -124,13 +124,8 @@ def __call__(
124124
loss1 = -advantages * ratios
125125
loss2 = -advantages * ratios_clamped
126126

127-
if mask.sum() > 0:
128-
actor_loss = masked_mean(torch.max(loss1, loss2), mask)
129-
loss = actor_loss + kl
130-
else:
131-
# disable this update since there are no valid tokens
132-
loss = loss1.view(-1)[0] * 0
133-
127+
actor_loss = masked_mean(torch.max(loss1, loss2), mask)
128+
loss = actor_loss + kl
134129
with torch.no_grad():
135130
probs_ratio = masked_mean(ratios.detach(), mask).item()
136131
probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()

nemo_reinforcer/algorithms/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,7 @@ def wrapper(*args, **kwargs):
123123
@surpress_user_warnings
124124
def masked_mean(values, mask, dim=None):
125125
"""Masks values with mask, and computes the mean of the values using the masked values."""
126-
if dim is None:
127-
return values[mask.bool()].mean()
128-
return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)
126+
return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)
129127

130128

131129
def set_seed(seed: int):

tests/unit/algorithms/test_loss_functions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,25 @@ def test_clipped_pg_loss_zero_mask():
386386

387387
# Loss should be exactly zero
388388
torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
389+
390+
391+
def test_masked_mean_all_zeros():
392+
"""Test masked_mean function with all zeros mask."""
393+
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
394+
mask = torch.zeros_like(values)
395+
396+
# All zeros mask should return 0
397+
result = masked_mean(values, mask)
398+
print(result)
399+
torch.testing.assert_allclose(result, torch.tensor(0.0))
400+
401+
# With check_zero_mask=False
402+
mask[0] = 1
403+
result = masked_mean(values, mask)
404+
torch.testing.assert_allclose(result, torch.tensor(1.0))
405+
406+
# Case 2: dim is not None
407+
values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
408+
mask = torch.zeros_like(values)
409+
result = masked_mean(values, mask, dim=1)
410+
torch.testing.assert_allclose(result, torch.tensor([0.0, 0.0]))

0 commit comments

Comments
 (0)