Conversation
|
Doesn't MaxRL reduce to simply changing the advantage normalization denominator from std(r) to mean(r)? If so, this fits naturally as a flag in the existing GRPO trainer rather than a dedicated experimental module. |
LeonEricsson
left a comment
There was a problem hiding this comment.
Would you mind writing a paper index section for MaxRL as well?
tests/test_grpo_trainer.py
Outdated
| def test_maxrl_advantage_normalization(self): | ||
| """Unit test: MaxRL uses A_i = (r_i - mean(r)) / (mean(r) + eps), not std(r).""" | ||
| # rewards for two groups of 3 | ||
| rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 0.0]) | ||
| num_generations = 3 | ||
|
|
||
| mean_grouped = rewards.view(-1, num_generations).mean(dim=1) | ||
| mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0) | ||
|
|
||
| eps = 1e-4 | ||
| advantages = (rewards - mean_grouped) / (mean_grouped + eps) | ||
|
|
||
| # group 0: mean=1/3, advantages = (r - 1/3) / (1/3 + eps) | ||
| # group 1: mean=2/3, advantages = (r - 2/3) / (2/3 + eps) | ||
| mean0 = torch.tensor(1.0 / 3.0) | ||
| mean1 = torch.tensor(2.0 / 3.0) | ||
| expected = torch.tensor( | ||
| [ | ||
| (1.0 - mean0) / (mean0 + eps), | ||
| (0.0 - mean0) / (mean0 + eps), | ||
| (0.0 - mean0) / (mean0 + eps), | ||
| (1.0 - mean1) / (mean1 + eps), | ||
| (1.0 - mean1) / (mean1 + eps), | ||
| (0.0 - mean1) / (mean1 + eps), | ||
| ] | ||
| ) | ||
| torch.testing.assert_close(advantages, expected) |
There was a problem hiding this comment.
| def test_maxrl_advantage_normalization(self): | |
| """Unit test: MaxRL uses A_i = (r_i - mean(r)) / (mean(r) + eps), not std(r).""" | |
| # rewards for two groups of 3 | |
| rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 0.0]) | |
| num_generations = 3 | |
| mean_grouped = rewards.view(-1, num_generations).mean(dim=1) | |
| mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0) | |
| eps = 1e-4 | |
| advantages = (rewards - mean_grouped) / (mean_grouped + eps) | |
| # group 0: mean=1/3, advantages = (r - 1/3) / (1/3 + eps) | |
| # group 1: mean=2/3, advantages = (r - 2/3) / (2/3 + eps) | |
| mean0 = torch.tensor(1.0 / 3.0) | |
| mean1 = torch.tensor(2.0 / 3.0) | |
| expected = torch.tensor( | |
| [ | |
| (1.0 - mean0) / (mean0 + eps), | |
| (0.0 - mean0) / (mean0 + eps), | |
| (0.0 - mean0) / (mean0 + eps), | |
| (1.0 - mean1) / (mean1 + eps), | |
| (1.0 - mean1) / (mean1 + eps), | |
| (0.0 - mean1) / (mean1 + eps), | |
| ] | |
| ) | |
| torch.testing.assert_close(advantages, expected) |
This is fine as an offline test for the PR but doesn't need to be a part of the test suite
tests/test_grpo_trainer.py
Outdated
| # ------------------------------------------------------------------ | ||
| # MaxRL tests (scale_rewards="mean") | ||
| # ------------------------------------------------------------------ |
There was a problem hiding this comment.
| # ------------------------------------------------------------------ | |
| # MaxRL tests (scale_rewards="mean") | |
| # ------------------------------------------------------------------ |
tests/test_grpo_trainer.py
Outdated
| def test_maxrl_training_with_eval(self): | ||
| dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") | ||
|
|
||
| training_args = GRPOConfig( | ||
| output_dir=self.tmp_dir, | ||
| learning_rate=0.1, | ||
| per_device_train_batch_size=3, | ||
| num_generations=3, | ||
| max_completion_length=8, | ||
| scale_rewards="mean", | ||
| eval_strategy="steps", | ||
| eval_steps=2, | ||
| report_to="none", | ||
| ) | ||
| trainer = GRPOTrainer( | ||
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | ||
| reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| eval_dataset=dataset, | ||
| ) | ||
|
|
||
| trainer.train() | ||
|
|
||
| assert trainer.state.log_history[-1]["train_loss"] is not None |
There was a problem hiding this comment.
| def test_maxrl_training_with_eval(self): | |
| dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") | |
| training_args = GRPOConfig( | |
| output_dir=self.tmp_dir, | |
| learning_rate=0.1, | |
| per_device_train_batch_size=3, | |
| num_generations=3, | |
| max_completion_length=8, | |
| scale_rewards="mean", | |
| eval_strategy="steps", | |
| eval_steps=2, | |
| report_to="none", | |
| ) | |
| trainer = GRPOTrainer( | |
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | |
| reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", | |
| args=training_args, | |
| train_dataset=dataset, | |
| eval_dataset=dataset, | |
| ) | |
| trainer.train() | |
| assert trainer.state.log_history[-1]["train_loss"] is not None |
tests/test_grpo_trainer.py
Outdated
| def test_maxrl_advantage_zero_mean(self): | ||
| """When all rewards in a group are 0, advantages should be 0 (not NaN).""" | ||
| rewards = torch.tensor([0.0, 0.0, 0.0]) | ||
| num_generations = 3 | ||
|
|
||
| mean_grouped = rewards.view(-1, num_generations).mean(dim=1) | ||
| mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0) | ||
|
|
||
| eps = 1e-4 | ||
| advantages = (rewards - mean_grouped) / (mean_grouped + eps) | ||
|
|
||
| # numerator is 0 for all, denominator is eps → advantages all 0 | ||
| assert not torch.isnan(advantages).any(), "advantages must not contain NaN" | ||
| torch.testing.assert_close(advantages, torch.zeros(3)) |
There was a problem hiding this comment.
| def test_maxrl_advantage_zero_mean(self): | |
| """When all rewards in a group are 0, advantages should be 0 (not NaN).""" | |
| rewards = torch.tensor([0.0, 0.0, 0.0]) | |
| num_generations = 3 | |
| mean_grouped = rewards.view(-1, num_generations).mean(dim=1) | |
| mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0) | |
| eps = 1e-4 | |
| advantages = (rewards - mean_grouped) / (mean_grouped + eps) | |
| # numerator is 0 for all, denominator is eps → advantages all 0 | |
| assert not torch.isnan(advantages).any(), "advantages must not contain NaN" | |
| torch.testing.assert_close(advantages, torch.zeros(3)) |
same for this. These tests are tautological, they reimplement behavior inline rather than testing trl code.
tests/test_grpo_trainer.py
Outdated
| def test_maxrl_training_multiple_reward_funcs(self): | ||
| dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") | ||
|
|
||
| def reward_func1(completions, **kwargs): | ||
| return [1.0] * len(completions) | ||
|
|
||
| def reward_func2(completions, **kwargs): | ||
| return [len(c) * 0.01 for c in completions] | ||
|
|
||
| training_args = GRPOConfig( | ||
| output_dir=self.tmp_dir, | ||
| learning_rate=0.1, | ||
| per_device_train_batch_size=3, | ||
| num_generations=3, | ||
| max_completion_length=8, | ||
| scale_rewards="mean", | ||
| report_to="none", | ||
| ) | ||
| trainer = GRPOTrainer( | ||
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | ||
| reward_funcs=[reward_func1, reward_func2], | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| ) | ||
|
|
||
| previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} | ||
|
|
||
| trainer.train() | ||
|
|
||
| assert trainer.state.log_history[-1]["train_loss"] is not None | ||
| for n, param in previous_trainable_params.items(): | ||
| new_param = trainer.model.get_parameter(n) | ||
| assert not torch.equal(param, new_param), f"Parameter {n} has not changed." |
There was a problem hiding this comment.
this seems like a duplicate of test_maxrl_training since we're not actually verifying anything related to the multiple reward functions.
tests/test_grpo_trainer.py
Outdated
| def test_maxrl_training_conversational(self): | ||
| dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") | ||
|
|
||
| training_args = GRPOConfig( | ||
| output_dir=self.tmp_dir, | ||
| learning_rate=0.1, | ||
| per_device_train_batch_size=3, | ||
| num_generations=3, | ||
| max_completion_length=8, | ||
| scale_rewards="mean", | ||
| report_to="none", | ||
| ) | ||
| trainer = GRPOTrainer( | ||
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | ||
| reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| ) | ||
|
|
||
| previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} | ||
|
|
||
| trainer.train() | ||
|
|
||
| assert trainer.state.log_history[-1]["train_loss"] is not None | ||
| for n, param in previous_trainable_params.items(): | ||
| new_param = trainer.model.get_parameter(n) | ||
| assert not torch.equal(param, new_param), f"Parameter {n} has not changed." |
There was a problem hiding this comment.
similar comment to test_maxrl_training_multiple_reward_funcs. I also don't see why the combination of MaxRL and conversational training needs a dedicated test — isn't this already covered by the existing conversational tests?
There was a problem hiding this comment.
removed. Yes perhaps maxrl is too small of a new change to warrant new tests.
tests/test_grpo_trainer.py
Outdated
| def test_maxrl_training_peft(self): | ||
| dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") | ||
|
|
||
| training_args = GRPOConfig( | ||
| output_dir=self.tmp_dir, | ||
| learning_rate=0.1, | ||
| per_device_train_batch_size=3, | ||
| num_generations=3, | ||
| max_completion_length=8, | ||
| scale_rewards="mean", | ||
| report_to="none", | ||
| ) | ||
| trainer = GRPOTrainer( | ||
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | ||
| reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| peft_config=LoraConfig(task_type="CAUSAL_LM"), | ||
| ) | ||
|
|
||
| trainer.train() | ||
|
|
||
| assert trainer.state.log_history[-1]["train_loss"] is not None | ||
| assert isinstance(trainer.model, PeftModel) |
There was a problem hiding this comment.
same thoughts as test_maxrl_training_multiple_reward_funcs and test_maxrl_training_conversational
Remove test_maxrl_advantage_normalization and test_maxrl_advantage_zero_mean, as they do not test TRL code; also remove test_maxrl_training_conversational.
LeonEricsson
left a comment
There was a problem hiding this comment.
final comments. then i'm satisfied.
needs a maintainers approval before merging.
- Remove the `# MaxRL: A_i = (r_i - mean(r)) / (mean(r) + eps)` comment — removed the comment in grpo_trainer, since we already have a paper index entry for it.
What does this PR do?
Adds MaxRL, which is a variant of GRPO with p-normalization (mean-based advantage scaling).
Fixes #5025
Before submitting
Pull Request section?
to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.