2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -1224,7 +1224,7 @@ def fit(self):
 batch = batch.union(reward_tensor)

 if self.config.reward_model.launch_reward_fn_async:
-    future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+    future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
 else:
     reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

14 changes: 12 additions & 2 deletions verl/trainer/ppo/reward.py
@@ -160,10 +160,20 @@ def compute_reward(data: DataProto, reward_fn):


 @ray.remote(num_cpus=1)
-def compute_reward_async(data: DataProto, config, tokenizer):
+def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn=None):
     """
     Load the reward manager and compute the reward for a batch of data.
     This is meant to be run in a separate Ray worker.
     """
-    reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}))
+    if reward_fn is None:
+        assert config is not None and tokenizer is not None, (
+            "config and tokenizer must not be None when reward_fn is None"
+        )
+        import warnings
+
+        warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2)
+        reward_fn = load_reward_manager(
+            config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+        )

     return compute_reward(data, reward_fn)
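
For context, a minimal sketch of how a caller would use the updated signature outside the trainer. This is not part of the PR; `batch`, `reward_fn`, `config`, and `tokenizer` are assumed to already exist on the driver (e.g. `reward_fn` built via load_reward_manager or any reward manager with the same call interface).

# Sketch only; assumes a DataProto `batch` and a prebuilt `reward_fn`.
import ray
from verl.trainer.ppo.reward import compute_reward_async

# New path: pass the already-constructed reward_fn, so the remote worker
# does not rebuild the reward manager from config/tokenizer.
future_reward = compute_reward_async.remote(data=batch, reward_fn=reward_fn)

# Deprecated path: still accepted, but the worker emits a warnings.warn
# and reloads the reward manager from config and tokenizer.
# future_reward = compute_reward_async.remote(batch, config, tokenizer)

# compute_reward_async forwards to compute_reward, so the future resolves
# to the same (reward_tensor, reward_extra_infos_dict) tuple.
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)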