
Add support for MaxRL #5026

Open
catherinelee274 wants to merge 15 commits into huggingface:main from catherinelee274:clee_maxrl

Conversation


@catherinelee274 catherinelee274 commented Feb 9, 2026

What does this PR do?

Adds MaxRL, a variant of GRPO with p-normalization.
Fixes #5025

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@catherinelee274 catherinelee274 changed the title "Add support for MaxRL" → "[WIP] Add support for MaxRL" Feb 17, 2026
@catherinelee274 catherinelee274 marked this pull request as ready for review February 17, 2026 06:25
@LeonEricsson
Collaborator

Doesn't MaxRL reduce to simply changing the advantage normalization denominator from std(r) to mean(r)?

# GRPO
A_i = (r_i - mean(r)) / (std(r) + eps)

# MaxRL
A_i = (r_i - mean(r)) / (mean(r) + eps)

If so, this fits naturally as a flag in the existing GRPO trainer rather than a dedicated experimental module.
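If that reading is right, the entire change can be sketched as a single branch on the denominator. A minimal sketch (the `grouped_advantages` helper and its `scale` argument are illustrative names, not the actual TRL API):

```python
import torch

def grouped_advantages(rewards: torch.Tensor, num_generations: int,
                       scale: str = "std", eps: float = 1e-4) -> torch.Tensor:
    """Per-group advantages: scale='std' matches GRPO, scale='mean' matches MaxRL."""
    grouped = rewards.view(-1, num_generations)        # (num_groups, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)
    if scale == "std":
        denom = grouped.std(dim=1, keepdim=True)       # GRPO: divide by std(r)
    else:
        denom = mean                                   # MaxRL: divide by mean(r)
    return ((grouped - mean) / (denom + eps)).view(-1)
```

With binary rewards `[1, 0, 0]` and `scale="mean"`, this yields `(r_i - 1/3) / (1/3 + eps)`, matching the formula above.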

Collaborator

@LeonEricsson LeonEricsson left a comment


Would you mind writing a paper index section for MaxRL as well?

Comment on lines +1398 to +1424
def test_maxrl_advantage_normalization(self):
    """Unit test: MaxRL uses A_i = (r_i - mean(r)) / (mean(r) + eps), not std(r)."""
    # rewards for two groups of 3
    rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 0.0])
    num_generations = 3

    mean_grouped = rewards.view(-1, num_generations).mean(dim=1)
    mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)

    eps = 1e-4
    advantages = (rewards - mean_grouped) / (mean_grouped + eps)

    # group 0: mean=1/3, advantages = (r - 1/3) / (1/3 + eps)
    # group 1: mean=2/3, advantages = (r - 2/3) / (2/3 + eps)
    mean0 = torch.tensor(1.0 / 3.0)
    mean1 = torch.tensor(2.0 / 3.0)
    expected = torch.tensor(
        [
            (1.0 - mean0) / (mean0 + eps),
            (0.0 - mean0) / (mean0 + eps),
            (0.0 - mean0) / (mean0 + eps),
            (1.0 - mean1) / (mean1 + eps),
            (1.0 - mean1) / (mean1 + eps),
            (0.0 - mean1) / (mean1 + eps),
        ]
    )
    torch.testing.assert_close(advantages, expected)
Collaborator

Suggested change: delete this test.

This is fine as an offline test for the PR but doesn't need to be a part of the test suite

Author

Removed.

Comment on lines +1394 to +1396
# ------------------------------------------------------------------
# MaxRL tests (scale_rewards="mean")
# ------------------------------------------------------------------
Collaborator

Suggested change: delete this comment block.

Author

Removed

Comment on lines +1469 to +1493
def test_maxrl_training_with_eval(self):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=8,
        scale_rewards="mean",
        eval_strategy="steps",
        eval_steps=2,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )

    trainer.train()

    assert trainer.state.log_history[-1]["train_loss"] is not None
Collaborator

Suggested change: delete this test.

Author

Removed.

Comment on lines +1426 to +1439
def test_maxrl_advantage_zero_mean(self):
    """When all rewards in a group are 0, advantages should be 0 (not NaN)."""
    rewards = torch.tensor([0.0, 0.0, 0.0])
    num_generations = 3

    mean_grouped = rewards.view(-1, num_generations).mean(dim=1)
    mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)

    eps = 1e-4
    advantages = (rewards - mean_grouped) / (mean_grouped + eps)

    # numerator is 0 for all, denominator is eps → advantages all 0
    assert not torch.isnan(advantages).any(), "advantages must not contain NaN"
    torch.testing.assert_close(advantages, torch.zeros(3))
Collaborator

Suggested change: delete this test.

Same for this one. These tests are tautological: they reimplement the behavior inline rather than exercising TRL code.

Author

Removed.

Comment on lines +1495 to +1527
def test_maxrl_training_multiple_reward_funcs(self):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def reward_func1(completions, **kwargs):
        return [1.0] * len(completions)

    def reward_func2(completions, **kwargs):
        return [len(c) * 0.01 for c in completions]

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=8,
        scale_rewards="mean",
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=[reward_func1, reward_func2],
        args=training_args,
        train_dataset=dataset,
    )

    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    assert trainer.state.log_history[-1]["train_loss"] is not None
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
Collaborator

This seems like a duplicate of test_maxrl_training, since we're not actually verifying anything related to the multiple reward functions.

Author

Removed.

Comment on lines +1555 to +1581
def test_maxrl_training_conversational(self):
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=8,
        scale_rewards="mean",
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        args=training_args,
        train_dataset=dataset,
    )

    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    assert trainer.state.log_history[-1]["train_loss"] is not None
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
Collaborator

Similar comment to test_maxrl_training_multiple_reward_funcs. I also don't see why the combination of MaxRL and conversational training needs a dedicated test; isn't this already covered by the existing conversational tests?

Author

Removed. Yes, perhaps MaxRL is too small a change to warrant new tests.

Comment on lines +1530 to +1553
def test_maxrl_training_peft(self):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=8,
        scale_rewards="mean",
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        args=training_args,
        train_dataset=dataset,
        peft_config=LoraConfig(task_type="CAUSAL_LM"),
    )

    trainer.train()

    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert isinstance(trainer.model, PeftModel)
Collaborator

Same thoughts as test_maxrl_training_multiple_reward_funcs and test_maxrl_training_conversational.

Author

Removed.

Remove test_maxrl_advantage_normalization and test_maxrl_advantage_zero_mean as they do not test TRL code
Remove test_maxrl_training_conversational
Collaborator

@LeonEricsson LeonEricsson left a comment


Final comments, then I'm satisfied. Needs a maintainer's approval before merging.

- Remove the # MaxRL: A_i = (r_i - mean(r)) / (mean(r) + eps) comment

- Remove the comment in grpo_trainer, since we already have a paper index


Development

Successfully merging this pull request may close these issues.

Maximum Likelihood Reinforcement Learning
