
[RL] Changes to enable compilation for trainer#2568

Open
Lucaskabela wants to merge 2 commits into pytorch:main from Lucaskabela:lucaskabela/enable_trainer_compile_03_10

Conversation

@Lucaskabela (Contributor) commented Mar 13, 2026

Summary

In this PR, we enable naive, JIT-style torch.compile for the RL policy trainer. This is the first step toward speeding up the trainer model. The changes are:

  1. Wiring through compilation config:
  • Added a TrainerCompileConfig dataclass with enable (bool) and backend (str, default "eager") fields
  • Added a compile field to PolicyTrainer.Config, supporting the eager and aot_eager backends
  • Added a _compile_model() method that calls .compile(backend=..., fullgraph=True) on each transformer
    layer -> this per-layer granularity is critical, as applying torch.compile() to the whole model results in logit changes
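The compile wiring described above can be sketched roughly as follows. Note this is a minimal illustration, not the PR's actual code: the name compile_model_layers and the exact field layout are assumptions, and nn.Module.compile is used as the in-place counterpart of torch.compile.

```python
from dataclasses import dataclass


@dataclass
class TrainerCompileConfig:
    """Sketch of the config fields described in the PR."""
    enable: bool = False
    backend: str = "eager"


def compile_model_layers(layers, config):
    """Compile each transformer layer in place.

    Compiling per layer (rather than the whole model) matches the
    PR's approach; fullgraph=True makes any untraceable construct
    fail loudly instead of silently falling back with graph breaks.
    """
    if not config.enable:
        return
    for layer in layers:
        # For a torch.nn.Module, .compile(...) is the in-place
        # equivalent of module = torch.compile(module, ...).
        layer.compile(backend=config.backend, fullgraph=True)
```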
  2. config_registry.py — Enable compile by default in configs
  • Both rl_grpo_qwen3_0_6b and rl_grpo_qwen3_debug configs now set
    compile=TrainerCompileConfig(enable=True). Default backend is 'eager'
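In the registry, enabling compile for a config then looks roughly like this (a hypothetical fragment mirroring the description above, with the other Config fields elided):

```python
# hypothetical config_registry.py entry; other Config fields elided
rl_grpo_qwen3_0_6b = PolicyTrainer.Config(
    ...,
    compile=TrainerCompileConfig(enable=True),  # backend defaults to "eager"
)
```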
  3. vllm_compat/models/attention.py — Make flash-attention compile-compatible
  • Moved the FlashAttnWithBackward autograd function out of the forward() method (nested classes
    can't be traced by the compiler) into a module-level FlashAttnVarlenFunction
  • Registered the flash-attention forward as a torch.library.custom_op (rl::flash_attn_varlen_fwd)
    with a fake implementation, so AOT Autograd can trace through it with FakeTensors
  • Simplified the call site in VLLMCompatibleFlashAttention.forward() to use the new function

Test Plan

python torchtitan/experiments/rl/unified/simple_grpo_sum_digits.py --module rl.unified --config rl_grpo_qwen3_0_6b --hf_assets_path=torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B

This produces the same losses as on main; the cumulative timing is now:

Main

[2026-03-13 12:13:11] INFO simple_grpo_sum_digits.py:401: [actor=<root>] Cumulative Timing | Generator: 22.6s | Optimizer: 0.1s | Trainer: 148.6s | WeightSync: 119.4s | Total: 290.7s

Changes

[2026-03-13 12:03:26] INFO simple_grpo_sum_digits.py:401: [actor=<root>] Cumulative Timing | Generator: 22.1s | Optimizer: 0.1s | Trainer: 103.2s | WeightSync: 119.2s | Total: 244.6s

So we save ~46s of runtime (trainer: 148.6s → 103.2s).

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 13, 2026
@Lucaskabela Lucaskabela marked this pull request as ready for review March 13, 2026 19:19
if we go with pytorch varlen, do we still need to worry about this file? cc @wwwjn

@Lucaskabela Lucaskabela requested a review from wwwjn March 13, 2026 20:39
@Lucaskabela Lucaskabela force-pushed the lucaskabela/enable_trainer_compile_03_10 branch from 8a62589 to be75f04 Compare March 13, 2026 22:14
@Lucaskabela Lucaskabela force-pushed the lucaskabela/enable_trainer_compile_03_10 branch from be75f04 to 520d314 Compare March 13, 2026 23:50
@Lucaskabela Lucaskabela requested a review from tianyu-l March 14, 2026 00:28