|
8 | 8 | import torch |
9 | 9 | from omegaconf import DictConfig |
10 | 10 | from skyrl_train.utils.ppo_utils import PolicyLossRegistry |
| 11 | +from skyrl_train.utils import masked_mean |
11 | 12 |
|
12 | 13 |
|
13 | 14 | # Adapted a good test from NeMO-RL |
@@ -214,3 +215,154 @@ def test_policy_loss_reduction_edge_cases(): |
214 | 215 | # Should handle zero mask gracefully (due to +1e-8 in denominator) |
215 | 216 | assert torch.isfinite(loss_token_masked) |
216 | 217 | assert torch.isfinite(loss_seq_masked) |
| 218 | + |
| 219 | + |
def test_gspo_importance_sampling_levels():
    """Tests GSPO policy loss function with sequence-level importance sampling.

    This test focuses on GSPO's key benefit: stabilizing clipping behavior through sequence-level
    importance sampling, which should lead to more consistent training dynamics compared to
    token-level importance sampling in standard PPO.
    """

    device = "cpu"

    clip_eps_low = 0.2
    clip_eps_high = 0.2

    # Create test data with varied sequence lengths and extreme ratios to test clipping stability
    # GSPO's benefit is most apparent with sequences of different lengths and high variance
    advantages = torch.tensor(
        [
            [1.5, 2.0, 1.0, 0.8, 0.5, 0.0, 0.0, 0.0],  # long sequence: 5 valid tokens
            [3.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # short sequence: 2 valid tokens
            [0.5, 0.8, 1.2, 2.5, 0.0, 0.0, 0.0, 0.0],  # medium sequence: 4 valid tokens
        ],
        device=device,
    )

    old_log_probs = torch.tensor(
        [
            [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
            [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
            [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
        ],
        device=device,
    )

    # Create extreme log probability ratios to trigger significant clipping
    # This tests GSPO's stability benefits under conditions that would cause unstable clipping
    log_probs = torch.tensor(
        [
            [0.2, -2.5, -0.3, 0.1, -1.8, -1.0, -1.0, -1.0],  # high variance within sequence
            [0.8, -0.2, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0],  # extreme ratios (exp(1.8)≈6.0, exp(0.8)≈2.2)
            [-0.5, 0.3, -1.7, 0.4, -1.0, -1.0, -1.0, -1.0],  # mixed extreme values
        ],
        device=device,
    )

    # Create masks for different sequence lengths (key for testing length normalization)
    loss_mask = torch.tensor(
        [
            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],  # 5 tokens
            [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 2 tokens
            [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # 4 tokens
        ],
        device=device,
    )

    # Test standard PPO (token-level importance sampling)
    ppo_config = DictConfig(
        {
            "eps_clip_low": clip_eps_low,
            "eps_clip_high": clip_eps_high,
            "clip_ratio_c": 3.0,
            "policy_loss_type": "regular",
            "loss_reduction": "token_mean",
        }
    )
    ppo_loss_fn = PolicyLossRegistry.get("regular")
    loss_token, _ = ppo_loss_fn(log_probs, old_log_probs, advantages, ppo_config, loss_mask)

    # Test GSPO (sequence-level importance sampling)
    gspo_config = DictConfig(
        {
            "eps_clip_low": clip_eps_low,
            "eps_clip_high": clip_eps_high,
            "clip_ratio_c": 3.0,
            "policy_loss_type": "gspo",
            "loss_reduction": "sequence_mean",  # GSPO recommended reduction
        }
    )
    gspo_loss_fn = PolicyLossRegistry.get("gspo")
    loss_sequence, _ = gspo_loss_fn(log_probs, old_log_probs, advantages, gspo_config, loss_mask)

    # Manual calculation for token-level (standard PPO)
    log_ratio = log_probs - old_log_probs
    ratio_token = log_ratio.exp()
    surr1_token = ratio_token * advantages
    surr2_token = ratio_token.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
    loss_per_token_token = -torch.min(surr1_token, surr2_token)
    expected_token = (loss_per_token_token * loss_mask).sum() / (loss_mask.sum() + 1e-8)

    # Calculate token-level clipping ratio
    is_clipped_token = (-surr2_token > -surr1_token) & (loss_mask.bool())
    clip_ratio_token = is_clipped_token.float().sum() / loss_mask.sum()

    # Manual calculation for sequence-level (GSPO)
    # First compute sequence-level importance weights (key GSPO innovation)
    log_importance_weights_seq = masked_mean(log_ratio, loss_mask, dim=-1).unsqueeze(-1)

    # GSPO uses stop gradients: s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_probs - sg[log_probs]
    # Note: ratio_sequence broadcasts to (batch, seq_len); numerically each row is constant
    # since log_probs - log_probs.detach() == 0 in value (it only carries gradient).
    ratio_sequence = torch.exp(log_importance_weights_seq.detach() + log_probs - log_probs.detach())
    surr1_sequence = ratio_sequence * advantages
    surr2_sequence = ratio_sequence.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
    loss_per_token_sequence = -torch.min(surr1_sequence, surr2_sequence)
    # GSPO uses sequence_mean reduction
    expected_sequence = masked_mean(loss_per_token_sequence, loss_mask, dim=-1).mean()

    # Calculate sequence-level clipping ratio
    is_clipped_sequence = (-surr2_sequence > -surr1_sequence) & (loss_mask.bool())
    clip_ratio_sequence = is_clipped_sequence.float().sum() / loss_mask.sum()

    # Verify loss calculations
    torch.testing.assert_close(loss_token, expected_token, rtol=1e-5, atol=1e-8)
    torch.testing.assert_close(loss_sequence, expected_sequence, rtol=1e-5, atol=1e-8)

    # Core GSPO benefit test: Different clipping behavior
    # GSPO should produce different clipping patterns due to sequence-level importance sampling
    assert not torch.allclose(
        clip_ratio_token, clip_ratio_sequence, rtol=1e-2
    ), f"Clipping ratios should differ: token={clip_ratio_token:.4f} vs sequence={clip_ratio_sequence:.4f}"

    # Test stability: sequence-level should smooth out extreme per-token variations
    # Check that sequence-level ratios have lower variance within each sequence
    # (masked positions contribute zeros to both sides, so the comparison is consistent)
    token_ratio_variance = torch.var(ratio_token * loss_mask, dim=-1).mean()
    sequence_ratio_variance = torch.var(ratio_sequence * loss_mask, dim=-1).mean()

    # The key insight: GSPO should reduce within-sequence variance by using sequence-averaged ratios
    assert sequence_ratio_variance < token_ratio_variance, (
        f"GSPO should reduce ratio variance: sequence={sequence_ratio_variance:.4f} < "
        f"token={token_ratio_variance:.4f}"
    )

    # Token-level and sequence-level should give different results due to different importance weighting
    assert not torch.allclose(
        loss_token, loss_sequence, rtol=1e-3
    ), f"Loss values should differ: token={loss_token:.6f} vs sequence={loss_sequence:.6f}"

    # Test length normalization effect: sequences with different lengths should be handled more uniformly
    # This is a key stability benefit of GSPO mentioned in the paper
    seq_lengths = loss_mask.sum(dim=-1)  # [5, 2, 4]

    # In GSPO, the sequence-level importance ratio should be identical across all tokens
    # in a sequence. Check the broadcast per-token ratios (shape [batch, seq_len]);
    # BUG FIX: slicing log_importance_weights_seq (shape [batch, 1]) with [:seq_len]
    # only ever yielded a single element, making the original assertion vacuous.
    for seq_idx in range(ratio_sequence.shape[0]):
        seq_len = int(seq_lengths[seq_idx])
        if seq_len > 1:
            # All importance ratios within a sequence should be identical (GSPO property)
            seq_ratios = ratio_sequence[seq_idx, :seq_len]
            assert torch.allclose(
                seq_ratios, seq_ratios[0], rtol=1e-6
            ), f"GSPO should have uniform importance weights within sequence {seq_idx}"
0 commit comments