
Commit bd9d6a9

[Trainer] GSPO support (#120)
This PR adds support for [Group Sequence Policy Optimization (GSPO)](https://arxiv.org/abs/2507.18071), the hotness du jour from Alibaba Qwen. The implementation in this PR is loosely based on [this one](huggingface/trl#3775) from TRL. It adds a `gspo` option for the `policy_loss_type` config, which switches importance sampling from the token level (PPO/GRPO) to the sequence level (GSPO).

I ran a short/small GSM8k run with Qwen2.5-0.5B and the loss curves look okay:

<img width="314" height="240" alt="image" src="https://github.com/user-attachments/assets/f52d7c64-416c-4419-aa96-4a03c9048007" />

However, I had to hack a few things to get this to run on Datadog's cloud infra (including changing some dependency versions), so I'd encourage one of the maintainers to reproduce these results locally before merging.
1 parent 582ffc4 commit bd9d6a9

File tree

4 files changed: +222 −4 lines


skyrl-train/docs/configuration/config.rst

Lines changed: 9 additions & 3 deletions

@@ -290,7 +290,7 @@ Algorithm Configuration
 # this adds training batch level normalization to advantages
 advantage_batch_normalize: false
 value_head_prefix: "value_head"
-policy_loss_type: "regular" # "regular", "dual_clip", or customizable with PolicyLossRegistry
+policy_loss_type: "regular" # "regular", "dual_clip", "gspo", or customizable with PolicyLossRegistry
 loss_reduction: "token_mean" # "token_mean", "sequence_mean"

 # GAE parameters
@@ -315,8 +315,14 @@ Algorithm Configuration
 - ``algorithm.kl_loss_coef``: Coefficient for the KL divergence loss.
 - ``algorithm.advantage_batch_normalize``: Whether to normalize advantages by the (global) batch mean and standard deviation.
 - ``algorithm.value_head_prefix``: The name used to identify the value head in the critic model.
-- ``algorithm.policy_loss_type``: Type of PPO loss to use. Currently, we implement ``regular`` and ``dual_clip``, where ``regular`` is the vanilla PPO loss, while ``dual_clip`` is the dual clip PPO loss proposed in `this paper <https://arxiv.org/pdf/1912.09729>`_. Custom policy losses can be registered with the ``PolicyLossRegistry``.
-- ``algorithm.loss_reduction``: Type of PPO loss reduction to use. Currently, we support ``token_mean`` and ``sequence_mean``. ``token_mean`` matches token-level loss introduced by `DAPO <https://dapo-sia.github.io/>`_. ``sequence_mean`` computes per-sequence avg token loss, then averages over the batch.
+- ``algorithm.policy_loss_type``: Type of policy loss to use. Options include:
+
+  - ``regular``: Vanilla PPO loss with token-level importance sampling
+  - ``dual_clip``: Dual clip PPO loss proposed in `this paper <https://arxiv.org/pdf/1912.09729>`_
+  - ``gspo``: `Group Sequence Policy Optimization <https://arxiv.org/abs/2507.18071>`_ with sequence-level importance sampling for improved training stability. Implements the "GSPO-token" variant from the paper.
+  - Custom policy losses can be registered with the ``PolicyLossRegistry``
+
+- ``algorithm.loss_reduction``: Type of loss reduction to use. Options are ``token_mean`` and ``sequence_mean``. ``token_mean`` matches the token-level loss introduced by `DAPO <https://dapo-sia.github.io/>`_. ``sequence_mean`` computes the per-sequence average token loss, then averages over the batch.
 - ``algorithm.lambd``: Lambda parameter for GAE.
 - ``algorithm.gamma``: Gamma parameter for GAE.
 - ``algorithm.eps_clip_low``: Lower bound for PPO clipping.
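
To make the two reduction modes concrete, here is a plain-Python sketch (toy numbers; the real implementation is ``reduce_loss`` in ``ppo_utils.py`` — the helper names here are made up for illustration):

```python
# losses[i][t] is the per-token loss; mask[i][t] marks valid (non-padding) tokens.
losses = [[2.0, 4.0, 0.0], [6.0, 0.0, 0.0]]
mask = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]

def token_mean(losses, mask):
    # One global average over all valid tokens (DAPO-style token-level loss).
    total = sum(l * m for row_l, row_m in zip(losses, mask) for l, m in zip(row_l, row_m))
    count = sum(m for row in mask for m in row)
    return total / count

def sequence_mean(losses, mask):
    # Average tokens within each sequence first, then average over sequences.
    per_seq = [
        sum(l * m for l, m in zip(row_l, row_m)) / sum(row_m)
        for row_l, row_m in zip(losses, mask)
    ]
    return sum(per_seq) / len(per_seq)

print(token_mean(losses, mask))     # (2 + 4 + 6) / 3 = 4.0
print(sequence_mean(losses, mask))  # ((2 + 4)/2 + 6/1) / 2 = 4.5
```

Note how ``token_mean`` weights long sequences more heavily (they contribute more tokens), while ``sequence_mean`` gives each sequence equal weight regardless of length.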

skyrl-train/skyrl_train/config/ppo_base_config.yaml

Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ trainer:
 # this adds training batch level normalization to advantages
 advantage_batch_normalize: false
 value_head_prefix: "value_head"
-policy_loss_type: "regular" # "regular", "dual_clip", or customizable with PolicyLossRegistry
+policy_loss_type: "regular" # "regular", "dual_clip", "gspo", or customizable with PolicyLossRegistry
 loss_reduction: "token_mean" # "token_mean", "sequence_mean"
 # GAE parameters
 lambd: 1.0

skyrl-train/skyrl_train/utils/ppo_utils.py

Lines changed: 60 additions & 0 deletions

@@ -397,6 +397,7 @@ class AdvantageEstimatorRegistry(BaseFunctionRegistry):
 class PolicyLossType(StrEnum):
     REGULAR = "regular"
     DUAL_CLIP = "dual_clip"
+    GSPO = "gspo"


 class PolicyLossRegistry(BaseFunctionRegistry):
@@ -483,6 +484,65 @@ def ppo_policy_loss(
     return loss, clip_ratio


+@register_policy_loss(PolicyLossType.GSPO)
+def gspo_policy_loss(
+    log_probs: torch.Tensor,
+    old_log_probs: torch.Tensor,
+    advantages: torch.Tensor,
+    config: DictConfig,
+    loss_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, float]:
+    """
+    GSPO (Group Sequence Policy Optimization) policy loss function,
+    as proposed in https://arxiv.org/abs/2507.18071.
+
+    This implements sequence-level importance sampling instead of token-level importance sampling.
+    The key difference is that importance weights are computed at the sequence level and then
+    applied uniformly across all tokens in the sequence. This can lead to more stable training
+    dynamics by reducing the variance in clipping behavior within sequences.
+
+    The variant of GSPO used here is GSPO-token, a generalization which allows for token-level
+    advantages [equations 14 and 15 in the paper].
+    """
+    # GSPO is expected to use sequence_mean reduction; warn otherwise
+    loss_reduction = config.loss_reduction
+    if loss_reduction != "sequence_mean":
+        # The GSPO paper uses sequence_mean reduction; there's no reason
+        # why a user couldn't use token_mean reduction, but
+        # it's not clear whether it would be stable or not.
+        from loguru import logger as logger_  # lazy import to avoid a pickling error
+
+        logger_.warning(f"With GSPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}")
+
+    # Compute log ratios
+    log_ratio = log_probs - old_log_probs
+
+    # Key GSPO innovation: sequence-level importance sampling.
+    # Instead of using per-token ratios, compute sequence-averaged ratios.
+    log_importance_weights = masked_mean(log_ratio, loss_mask, dim=-1).unsqueeze(-1)
+
+    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
+    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_probs - sg[log_probs]
+    # note: we put the addition at the end to avoid precision issues,
+    # per https://github.com/volcengine/verl/pull/2775#discussion_r2241500280
+    log_token_importance_weights = log_probs - log_probs.detach() + log_importance_weights.detach()
+    # clip to avoid overflow
+    log_token_importance_weights = torch.clamp(log_token_importance_weights, max=10)
+    ratio = torch.exp(log_token_importance_weights)
+
+    # Standard PPO surrogate objective with sequence-level importance weights
+    surr1 = ratio * advantages
+    surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages
+    loss = -torch.min(surr1, surr2)
+
+    # Compute clipping ratio for monitoring
+    clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()
+
+    loss = reduce_loss(loss, loss_mask, loss_reduction)
+
+    return loss, clip_ratio
+
+
 def reduce_loss(
     loss: torch.Tensor, loss_mask: Optional[torch.Tensor], loss_reduction: Literal["token_mean", "sequence_mean"]
 ) -> torch.Tensor:
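
The sequence-level weighting that `gspo_policy_loss` computes can be illustrated without torch. This toy sketch (illustrative numbers; the stop-gradient trick is omitted because it only changes gradients, not forward values) shows the arithmetic:

```python
import math

# Toy log-probabilities for one sequence of 3 valid tokens (illustrative only).
log_probs = [-0.5, -1.5, -1.0]
old_log_probs = [-1.0, -1.0, -1.0]

# Token-level ratios, as in vanilla PPO/GRPO: one independent ratio per token.
token_ratios = [math.exp(lp - olp) for lp, olp in zip(log_probs, old_log_probs)]

# GSPO: average the log-ratios over the sequence, exponentiate once, and share
# that single weight across every token of the sequence.
mean_log_ratio = sum(lp - olp for lp, olp in zip(log_probs, old_log_probs)) / len(log_probs)
seq_ratio = math.exp(mean_log_ratio)
gspo_ratios = [seq_ratio] * len(log_probs)

print([round(r, 3) for r in token_ratios])  # [1.649, 0.607, 1.0] — varies token to token
print(gspo_ratios)                          # [1.0, 1.0, 1.0] — uniform within the sequence
```

Because every token shares one weight, a single outlier token can no longer trigger clipping on its own; the whole sequence is either inside or outside the trust region together.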

skyrl-train/tests/cpu/algorithms/test_losses.py

Lines changed: 152 additions & 0 deletions

@@ -8,6 +8,7 @@
 import torch
 from omegaconf import DictConfig
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry
+from skyrl_train.utils import masked_mean


 # Adapted a good test from NeMO-RL
@@ -214,3 +215,154 @@ def test_policy_loss_reduction_edge_cases():
     # Should handle zero mask gracefully (due to +1e-8 in denominator)
     assert torch.isfinite(loss_token_masked)
     assert torch.isfinite(loss_seq_masked)
+
+
+def test_gspo_importance_sampling_levels():
+    """Tests GSPO policy loss function with sequence-level importance sampling.
+
+    This test focuses on GSPO's key benefit: stabilizing clipping behavior through sequence-level
+    importance sampling, which should lead to more consistent training dynamics compared to
+    token-level importance sampling in standard PPO.
+    """
+
+    device = "cpu"
+
+    clip_eps_low = 0.2
+    clip_eps_high = 0.2
+
+    # Create test data with varied sequence lengths and extreme ratios to test clipping stability
+    # GSPO's benefit is most apparent with sequences of different lengths and high variance
+    advantages = torch.tensor(
+        [
+            [1.5, 2.0, 1.0, 0.8, 0.5, 0.0, 0.0, 0.0],  # long sequence: 5 valid tokens
+            [3.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # short sequence: 2 valid tokens
+            [0.5, 0.8, 1.2, 2.5, 0.0, 0.0, 0.0, 0.0],  # medium sequence: 4 valid tokens
+        ],
+        device=device,
+    )
+
+    old_log_probs = torch.tensor(
+        [
+            [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+            [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+            [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+        ],
+        device=device,
+    )
+
+    # Create extreme log probability ratios to trigger significant clipping
+    # This tests GSPO's stability benefits under conditions that would cause unstable clipping
+    log_probs = torch.tensor(
+        [
+            [0.2, -2.5, -0.3, 0.1, -1.8, -1.0, -1.0, -1.0],  # high variance within sequence
+            [0.8, -0.2, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],  # extreme ratios (exp(1.8)≈6.0, exp(0.8)≈2.2)
+            [-0.5, 0.3, -1.7, 0.4, -1.0, -1.0, -1.0, -1.0],  # mixed extreme values
+        ],
+        device=device,
+    )
+
+    # Create masks for different sequence lengths (key for testing length normalization)
+    loss_mask = torch.tensor(
+        [
+            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],  # 5 tokens
+            [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 2 tokens
+            [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # 4 tokens
+        ],
+        device=device,
+    )
+
+    # Test standard PPO (token-level importance sampling)
+    ppo_config = DictConfig(
+        {
+            "eps_clip_low": clip_eps_low,
+            "eps_clip_high": clip_eps_high,
+            "clip_ratio_c": 3.0,
+            "policy_loss_type": "regular",
+            "loss_reduction": "token_mean",
+        }
+    )
+    ppo_loss_fn = PolicyLossRegistry.get("regular")
+    loss_token, _ = ppo_loss_fn(log_probs, old_log_probs, advantages, ppo_config, loss_mask)
+
+    # Test GSPO (sequence-level importance sampling)
+    gspo_config = DictConfig(
+        {
+            "eps_clip_low": clip_eps_low,
+            "eps_clip_high": clip_eps_high,
+            "clip_ratio_c": 3.0,
+            "policy_loss_type": "gspo",
+            "loss_reduction": "sequence_mean",  # GSPO recommended reduction
+        }
+    )
+    gspo_loss_fn = PolicyLossRegistry.get("gspo")
+    loss_sequence, _ = gspo_loss_fn(log_probs, old_log_probs, advantages, gspo_config, loss_mask)
+
+    # Manual calculation for token-level (standard PPO)
+    log_ratio = log_probs - old_log_probs
+    ratio_token = log_ratio.exp()
+    surr1_token = ratio_token * advantages
+    surr2_token = ratio_token.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
+    loss_per_token_token = -torch.min(surr1_token, surr2_token)
+    expected_token = (loss_per_token_token * loss_mask).sum() / (loss_mask.sum() + 1e-8)
+
+    # Calculate token-level clipping ratio
+    is_clipped_token = (-surr2_token > -surr1_token) & (loss_mask.bool())
+    clip_ratio_token = is_clipped_token.float().sum() / loss_mask.sum()
+
+    # Manual calculation for sequence-level (GSPO)
+    # First compute sequence-level importance weights (key GSPO innovation)
+    log_importance_weights_seq = masked_mean(log_ratio, loss_mask, dim=-1).unsqueeze(-1)
+
+    # GSPO uses stop gradients: s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
+    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_probs - sg[log_probs]
+    ratio_sequence = torch.exp(log_importance_weights_seq.detach() + log_probs - log_probs.detach())
+    surr1_sequence = ratio_sequence * advantages
+    surr2_sequence = ratio_sequence.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
+    loss_per_token_sequence = -torch.min(surr1_sequence, surr2_sequence)
+    # GSPO uses sequence_mean reduction
+    expected_sequence = masked_mean(loss_per_token_sequence, loss_mask, dim=-1).mean()
+
+    # Calculate sequence-level clipping ratio
+    is_clipped_sequence = (-surr2_sequence > -surr1_sequence) & (loss_mask.bool())
+    clip_ratio_sequence = is_clipped_sequence.float().sum() / loss_mask.sum()
+
+    # Verify loss calculations
+    torch.testing.assert_close(loss_token, expected_token, rtol=1e-5, atol=1e-8)
+    torch.testing.assert_close(loss_sequence, expected_sequence, rtol=1e-5, atol=1e-8)
+
+    # Core GSPO benefit test: Different clipping behavior
+    # GSPO should produce different clipping patterns due to sequence-level importance sampling
+    assert not torch.allclose(
+        clip_ratio_token, clip_ratio_sequence, rtol=1e-2
+    ), f"Clipping ratios should differ: token={clip_ratio_token:.4f} vs sequence={clip_ratio_sequence:.4f}"
+
+    # Test stability: sequence-level should smooth out extreme per-token variations
+    # Check that sequence-level ratios have lower variance within each sequence
+    token_ratio_variance = torch.var(ratio_token * loss_mask, dim=-1).mean()
+    sequence_ratio_variance = torch.var(ratio_sequence * loss_mask, dim=-1).mean()
+
+    # The key insight: GSPO should reduce within-sequence variance by using sequence-averaged ratios
+    assert sequence_ratio_variance < token_ratio_variance, (
+        f"GSPO should reduce ratio variance: sequence={sequence_ratio_variance:.4f} < "
+        f"token={token_ratio_variance:.4f}"
+    )
+
+    # Token-level and sequence-level should give different results due to different importance weighting
+    assert not torch.allclose(
+        loss_token, loss_sequence, rtol=1e-3
+    ), f"Loss values should differ: token={loss_token:.6f} vs sequence={loss_sequence:.6f}"
+
+    # Test length normalization effect: sequences with different lengths should be handled more uniformly
+    # This is a key stability benefit of GSPO mentioned in the paper
+    seq_lengths = loss_mask.sum(dim=-1)  # [5, 2, 4]
+
+    # In GSPO, the sequence-level importance weights should be the same across all tokens in a sequence
+    # This should make the treatment more uniform across different sequence lengths
+    for seq_idx in range(log_importance_weights_seq.shape[0]):
+        seq_len = int(seq_lengths[seq_idx])
+        if seq_len > 1:
+            # All importance weights within a sequence should be identical (GSPO property)
+            seq_weights = log_importance_weights_seq[seq_idx, :seq_len]
+            assert torch.allclose(
+                seq_weights, seq_weights[0], rtol=1e-6
+            ), f"GSPO should have uniform importance weights within sequence {seq_idx}"
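
For reference, the clipping logic shared by both surrogate objectives (and mirrored in the test's monitoring ratio) reduces to a few lines. This is a hypothetical scalar sketch, not the tensorized SkyRL code:

```python
def clipped_objective(ratio, advantage, eps_low=0.2, eps_high=0.2):
    """PPO-style clipped surrogate for a single token (scalar sketch)."""
    surr1 = ratio * advantage
    # Clamp the importance ratio to [1 - eps_low, 1 + eps_high] before weighting.
    clamped = min(max(ratio, 1 - eps_low), 1 + eps_high)
    surr2 = clamped * advantage
    # Pessimistic objective: the loss is the negative of the smaller surrogate.
    loss = -min(surr1, surr2)
    # A token counts as "clipped" when the clamped branch is the binding one.
    clipped = (-surr2) > (-surr1)
    return loss, clipped

# An extreme ratio (e.g. exp(1.8) ≈ 6.0, as in the test data) gets clipped hard:
print(clipped_objective(6.0, 1.0))  # (-1.2, True): capped at (1 + eps_high) * advantage
print(clipped_objective(1.0, 1.0))  # (-1.0, False): inside the trust region
```

Under GSPO, `ratio` is the shared sequence-level weight, so `clipped` flips for all tokens of a sequence at once rather than token by token.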
