Merged
45 changes: 39 additions & 6 deletions docs/guides/grpo.md
@@ -99,19 +99,19 @@ dataset = AllTaskProcessedDataset(

Ensure you provide a mapping of tasks to their processors so the dataset knows which processor to use when handling samples.

### Policy Model
## Policy Model

We define a {py:class}`PolicyInterface <nemo_rl.models.interfaces>` that contains everything you need to train a Policy model.

This Policy object holds a [RayWorkerGroup](../../nemo_rl/distributed/worker_groups.py) of SPMD (one process per GPU) workers that run HF/MCore, all coordinated by this object so it appears to you like a single GPU!

### Fast Generation
## Fast Generation

We currently support vLLM through the [VllmGeneration](../../nemo_rl/models/generation/vllm.py) class.

The function [grpo_train](../../nemo_rl/algorithms/grpo.py) contains the core GRPO training loop.

### Loss
## Loss
We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally,

$$
@@ -141,9 +141,9 @@ where:
usually set as 3 empirically
- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has changed
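
The clipped surrogate above can be sketched in a few lines of PyTorch. This is an illustrative, minimal sketch with a hypothetical function name, not the actual `ClippedPGLossFn`, which also handles token/sample masking, dual-clip, and the KL term:

```python
import torch

def clipped_pg_loss(curr_logprobs: torch.Tensor,
                    prev_logprobs: torch.Tensor,
                    advantages: torch.Tensor,
                    eps: float = 0.2) -> torch.Tensor:
    # r_t(theta) = pi_theta(x) / pi_theta_old(x), computed in log-space
    ratio = torch.exp(curr_logprobs - prev_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    # The surrogate is maximized, so the per-token loss is its negation
    return -torch.min(unclipped, clipped)
```

When the policy has not moved (ratio of 1), the per-token loss reduces to the negated advantage; once the ratio leaves $[1-\varepsilon, 1+\varepsilon]$, the clipped branch caps the incentive to push further.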

#### Improvements to the GRPO loss formulation for stability and accuracy
### Improvements to the GRPO loss formulation for stability and accuracy

#### On-Policy KL Approximation
#### On-Policy KL Approximation (use_on_policy_kl_approximation)

In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive.

Expand All @@ -165,7 +165,7 @@ $$
To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
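
For illustration, the k3 estimator from the linked post can be sketched as follows (a hypothetical helper, not the library's API). With samples $x \sim \pi_\theta$ and $r = \pi_{\text{ref}}(x)/\pi_\theta(x)$, the per-token estimate is $(r - 1) - \log r$, which is unbiased (since $\mathbb{E}[r] = 1$) and pointwise non-negative:

```python
import torch

def kl_k3(curr_logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    # r = pi_ref(x) / pi_theta(x) for samples x ~ pi_theta, in log-space
    log_ratio = ref_logprobs - curr_logprobs
    # (r - 1) - log r: zero when the policies agree, positive otherwise
    return torch.exp(log_ratio) - log_ratio - 1.0
```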


#### Importance Sampling Correction
#### Importance Sampling Correction (use_importance_sampling_correction)
The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding-new-models.md#understand-discrepancies-between-backends), it is possible for the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies (from numerics, precision differences, bugs, etc.), leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to the first term of the loss function.

Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ represent the first term of loss function. Then,
@@ -182,6 +182,39 @@ By multiplying the first term of the loss function by the importance weights $\f

To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
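
The correction amounts to one extra elementwise multiply before averaging. A minimal sketch with a hypothetical helper name (the real logic lives inside `ClippedPGLossFn`), including the same inf/nan guard the implementation uses:

```python
import torch

def importance_corrected_term(f_theta: torch.Tensor,
                              train_logprobs: torch.Tensor,
                              inference_logprobs: torch.Tensor) -> torch.Tensor:
    # w = pi_training(x) / pi_inference(x), computed in log-space
    weights = torch.exp(train_logprobs - inference_logprobs)
    # Guard against inf/nan weights from extreme logprob gaps
    weights = torch.nan_to_num(weights, nan=0.0, posinf=0.0, neginf=0.0)
    return weights * f_theta
```

When the two frameworks agree token-for-token, the weights are exactly 1 and the loss term is unchanged, which is why the correction is safe to leave on.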


## Metrics ({wandb, tb}_name)
We track a few metrics during training for scientific experimentation and to validate correctness as the run progresses.

### Multiplicative Token Probability Error (token_mult_prob_error)
This is equal to the 'Logprob consistency metric' defined in [Adding New Models](../adding-new-models.md#importance-of-log-probability-consistency-in-training-and-inference):

$$
\text{token-mult-prob-error} = \frac{1}{n}\sum_{i=1}^{n}\exp\left(\left|\text{logprob-train-fwk}_i - \text{logprob-inference-fwk}_i\right|\right)
$$

Intuitively, this measures the average multiplicative probability error for sampled tokens, where samples are drawn as $x \sim \pi_{\text{inference-framework}}$. The purpose of this metric is to surface obvious sampling errors or discrepancies between the inference backend and the training framework. If it trends steeply upward over the course of training past $\sim 1-2\%$, there is usually a problem with how your weights are being updated. If it is very spiky, that can indicate a bug in the inference framework or buggy weight refitting.
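
As a rough sketch (hypothetical helper; the real computation is inside `ClippedPGLossFn`, masked to valid tokens), the metric is just the masked mean of the exponentiated absolute logprob gap:

```python
import torch

def token_mult_prob_error(train_logprobs: torch.Tensor,
                          inference_logprobs: torch.Tensor,
                          mask: torch.Tensor) -> float:
    # exp(|delta logprob|) >= 1, with equality iff the frameworks agree
    err = torch.exp(torch.abs(train_logprobs - inference_logprobs))
    return ((err * mask).sum() / mask.sum()).item()
```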

### Sampling Importance Ratio (sampling_importance_ratio)
Not to be confused with the clipped importance ratio in PPO/GRPO, this is the importance ratio between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$.

This is simply $\frac{1}{|T|}\sum_{t \in T}\exp\big(\log \pi_{\text{training}}(t) - \log \pi_{\text{inference}}(t)\big)$, where $T$ is the set of sampled tokens.

Similar to [Multiplicative Token Probability Error](#multiplicative-token-probability-error-token_mult_prob_error), this measures how far your inference backend is from your training framework. However, because it does not take the absolute value of the error, it captures the bias of that error rather than (loosely) its variance. With some noise, it should hover around 1.

This metric is always calculated and the per-token version (without the mean) is used in the loss function when [Importance Sampling Correction](#importance-sampling-correction-use_importance_sampling_correction) is enabled.
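
A sketch of the logged (mean) version, with a hypothetical helper name; the per-token variant simply skips the final reduction:

```python
import torch

def sampling_importance_ratio(train_logprobs: torch.Tensor,
                              inference_logprobs: torch.Tensor,
                              mask: torch.Tensor) -> float:
    # No absolute value: over- and under-estimates can cancel,
    # so this exposes systematic bias rather than spread
    ratio = torch.exp(train_logprobs - inference_logprobs)
    return ((ratio * mask).sum() / mask.sum()).item()
```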

### Entropy (approx_entropy)
We roughly approximate the entropy of the LLM's distribution throughout training by calculating:

$$
\mathbb{E}_{x \sim \pi_{\text{inference}}}\left[-\frac{\pi_{\text{training}}(x)}{\pi_{\text{inference}}(x)}\log \pi_{\text{training}}(x)\right]
$$
using the rollouts in each training global batch as Monte Carlo samples. The policy ratio in the formula importance-corrects for the mismatch between the training policy, which drifts within a single GRPO step, and the inference framework that generated the samples.

We use this to track whether our models are entropy-collapsing too quickly during training (as is quite common). This is a rough Monte Carlo approximation, so we do not recommend using it directly as an entropy bonus or otherwise backpropagating through it. If you are interested, you can look at NeMo-Aligner's [implementation](https://github.com/NVIDIA/NeMo-Aligner/blob/main/nemo_aligner/utils/distributed.py#L351) of a full entropy calculation (an efficient calculation in NeMo-RL is a work in progress).
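
The estimator can be sketched as follows (hypothetical helper name; the in-tree version uses `masked_mean` inside `ClippedPGLossFn`), wrapped in `no_grad` to emphasize that it is monitoring-only:

```python
import torch

def approx_entropy(train_logprobs: torch.Tensor,
                   inference_logprobs: torch.Tensor,
                   mask: torch.Tensor) -> float:
    with torch.no_grad():  # monitoring only; never backpropagate through this
        # Importance weights correct for sampling from pi_inference
        weights = torch.exp(train_logprobs - inference_logprobs)
        ent = -(weights * train_logprobs * mask).sum() / mask.sum()
    return ent.item()
```

When the two frameworks agree, the weights are 1 and this reduces to the plain Monte Carlo entropy estimate $-\frac{1}{n}\sum_i \log \pi_{\text{training}}(x_i)$.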


## Evaluate the Trained Model

Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.
35 changes: 25 additions & 10 deletions nemo_rl/algorithms/loss_functions.py
@@ -125,6 +125,8 @@ def __call__(

mask = token_mask * sample_mask.unsqueeze(-1)

# token_mult_prob_error
# See more details and other metrics in docs/guides/grpo.md#metrics
lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now)
if self.loss_type == LossType.TOKEN_LEVEL:
# average over all tokens in the microbatch
@@ -149,11 +151,11 @@ def __call__(
next_token_logits, data["input_ids"]
)
else:
next_token_logits = next_token_logits[
next_token_logits_wo_last = next_token_logits[
:, :-1
] # Remove last position's logits
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
next_token_logits_wo_last, dim=-1
)
next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
curr_logprobs = next_token_logprobs.gather(
@@ -219,31 +221,40 @@ def __call__(
advantages < 0, torch.min(clip_loss, loss3), clip_loss
)

# See: docs/guides/grpo.md#importance-sampling-correction
actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs)
actor_importance_weights = torch.nan_to_num(
actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
if self.use_importance_sampling_correction:
# See: docs/guides/grpo.md#importance-sampling-correction
actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs)
actor_importance_weights = torch.nan_to_num(
actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
importance_weights_to_use = actor_importance_weights
else:
actor_importance_weights = torch.ones_like(prev_logprobs)
importance_weights_to_use = torch.ones_like(prev_logprobs)

if self.loss_type == LossType.TOKEN_LEVEL:
actor_loss = masked_mean(
actor_importance_weights * clip_loss,
importance_weights_to_use * clip_loss,
mask,
global_normalization_factor=total_valid_tokens_or_seqs,
)
else:
actor_loss = masked_mean(
masked_mean(
actor_importance_weights * clip_loss,
importance_weights_to_use * clip_loss,
token_mask,
dim=-1,
),
sample_mask,
global_normalization_factor=total_valid_tokens_or_seqs,
)

# Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))]
# See more details and other metrics in docs/guides/grpo.md#metrics
with torch.no_grad():
seq_entropy_approx = -masked_mean(
torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs, mask
)

loss = actor_loss + kl
with torch.no_grad():
if self.loss_type == LossType.TOKEN_LEVEL:
@@ -277,7 +288,11 @@ def __call__(
"probs_ratio_clamped": probs_ratio_clamped,
"kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
"token_mult_prob_error": mult_prob_error,
"sampling_importance_ratio": masked_mean(
actor_importance_weights, mask
).item(),
"num_valid_samples": sample_mask.sum().item(),
"approx_entropy": seq_entropy_approx.item(),
},
)

60 changes: 60 additions & 0 deletions tests/unit/algorithms/test_loss_functions.py
@@ -994,3 +994,63 @@ def test_clipped_pg_loss_dual_clip():
dummy_logits, data, torch.sum(data["sample_mask"] * data["token_mask"])
)
torch.testing.assert_close(actual_loss, expected_loss)


def test_clipped_pg_loss_entropy():
"""Tests approximate entropy calculation in ClippedPGLossFn."""
if not torch.cuda.is_available():
pytest.skip("No GPU available")

device = "cuda"
data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)

cfg = {
"ratio_clip_min": 0.2,
"ratio_clip_max": 0.2,
"ratio_clip_c": None,
"reference_policy_kl_penalty": 0.0, # Disable KL for simplicity
"disable_ppo_ratio": False,
"use_on_policy_kl_approximation": False,
"use_importance_sampling_correction": False, # This flag does not affect entropy calculation
"token_level_loss": True,
}
loss_fn = ClippedPGLossFn(cfg)

# Log probs for 3 tokens (default token_mask is [0, 1, 1, 1], so 3 unmasked after slicing)
# curr_lp_masked: log probabilities from the current policy (model output)
# gen_lp_masked: log probabilities from the generation policy (from data)
curr_lp_masked = torch.tensor([[-0.5, -1.0, -1.5]], device=device)
gen_lp_masked = torch.tensor([[-0.6, -1.1, -1.6]], device=device)

# prev_lp_masked is needed for actor loss but not directly for this entropy formula
prev_lp_masked = torch.tensor([[-0.4, -0.9, -1.4]], device=device)

data["prev_logprobs"][0, 1:] = prev_lp_masked
data["generation_logprobs"][0, 1:] = gen_lp_masked
# _create_exact_logits needs input_ids
data["input_ids"] = torch.randint(0, vocab_size, (1, seq_len), device=device)

# seq_entropy_approx = -masked_mean(torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs, mask)
# curr_lp_masked represents curr_logprobs for the hand calculation.
# gen_lp_masked represents generation_logprobs.
importance_weight_factor = torch.exp(curr_lp_masked - gen_lp_masked)
entropy_terms = importance_weight_factor * curr_lp_masked
expected_entropy = -torch.mean(
entropy_terms
) # torch.mean because default mask applies to these 3 terms

dummy_logits = _create_exact_logits(
curr_lp_masked, data["input_ids"], seq_len, vocab_size, device
)
_, metrics = loss_fn(
dummy_logits,
data,
total_valid_tokens_or_seqs=torch.sum(data["sample_mask"] * data["token_mask"]),
)

torch.testing.assert_close(
torch.tensor(metrics["approx_entropy"], device=device),
expected_entropy,
rtol=1e-3,
atol=1e-5,
)