Changes DataPreparationActor so that we can configure it into a replay buffer #1583

finbarrtimbers wants to merge 11 commits into `main`

Conversation
Code Review
This pull request introduces a replay buffer mechanism for GRPO training to enable off-policy data reuse, featuring a new ReplayBuffer class with FIFO, uniform, and prioritized sampling strategies. The implementation includes integration with the DataPreparationActor and comprehensive unit tests. Review feedback identifies several significant issues, including a potential RuntimeError during buffer insertion when at capacity, logical inconsistencies in metric reporting that mix current generation and replayed training data, and an incorrect calculation for the packed_ratio. Additionally, there are concerns regarding the performance efficiency of the prioritized sampling fallback and the behavioral implications of the FIFO sampler when paired with FIFO eviction.
```python
def _compute_step_metrics(
    self, result, batch, batch_stats, reward_metrics, scores, advantages, collated_data, generation_idle_wait_time
) -> dict:
    if len(result.responses) == 0:
        return {"time/generation_idle_waiting_for_trainer": generation_idle_wait_time}

    real_num_responses = len(result.responses)
    expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size
    unsolved_num_responses = (scores < self.config.max_possible_score).sum()
    sequence_lengths = np.array([len(response) for response in result.responses])
    sequence_length_solved = (
        np.array([])
        if np.all(scores == 0)
        else np.array(sequence_lengths[scores == self.config.max_possible_score])
    )
    sequence_length_unsolved = (
        np.array([])
        if np.all(scores == self.config.max_possible_score)
        else np.array(sequence_lengths[scores == 0])
    )
    stop_rate = sum(int(fr == "stop") for fr in result.finish_reasons) / len(result.finish_reasons)

    batch_metrics_dict = asdict(batch_stats)
    batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics_dict.items()}

    num_packed = sum(len(cd.query_responses) for cd in collated_data)
    step_metrics = {
        "time/generation_idle_waiting_for_trainer": generation_idle_wait_time,
        "scores": scores.mean(),
        "real_batch_size_ratio": real_num_responses / expected_num_responses,
        "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses,
        "packed_ratio": num_packed / real_num_responses,
```
When using the replay buffer, _compute_step_metrics produces inconsistent metrics because it mixes data from the newly generated rollouts (result) and the sampled training batch (scores, advantages, collated_data).
Specifically:

- `scores` and `advantages` reflect the replayed data being used for training.
- `real_num_responses`, `sequence_lengths`, and `stop_rate` reflect the newly generated data from the current step.
- `packed_ratio` (line 1500) is calculated incorrectly: `num_packed` counts the number of micro-batches (the length of the `query_responses` list) rather than the total number of samples. It should be the sum of the batch sizes of all tensors within the micro-batches.
This inconsistency makes it difficult to track the performance of the model's current generation versus the quality of the training data. Consider separating these into distinct metric categories (e.g., train/ vs gen/) and fixing the packed_ratio calculation.
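A minimal sketch of the corrected `packed_ratio` numerator, assuming each collated entry's `query_responses` is a list of per-micro-batch tensors whose leading dimension is the number of packed samples (the field names mirror the diff; the `CollatedBatch` stand-in type is hypothetical):

```python
from dataclasses import dataclass, field

@dataclass
class CollatedBatch:
    # Hypothetical stand-in for the real collated micro-batch type:
    # `query_responses` holds one tensor-like object per micro-batch,
    # whose leading dimension (here, `len`) is the packed sample count.
    query_responses: list = field(default_factory=list)

def count_packed_samples(collated_data) -> int:
    """Sum the batch sizes of every micro-batch, rather than counting
    the number of micro-batches (the bug called out above)."""
    return sum(len(qr) for cd in collated_data for qr in cd.query_responses)

# Two collated entries, three micro-batches, 4 + 2 + 3 = 9 samples total.
data = [
    CollatedBatch(query_responses=[[0] * 4, [0] * 2]),
    CollatedBatch(query_responses=[[0] * 3]),
]
print(count_packed_samples(data))  # → 9
# The buggy version would instead report 2 + 1 = 3 micro-batches.
```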
```python
def insert(self, groups: list[ReplayGroup]) -> None:
    for group in groups:
        if group.group_id in self._groups:
            self._remove(group.group_id)
        self._groups[group.group_id] = group
        if self._sum_tree is not None:
            self._sum_tree.add(group.group_id, group.priority)
        self._evict_overflow()
```
The insert method will raise a RuntimeError when the buffer is at capacity and a new group is added. This is because self._sum_tree.add (line 133) checks if the tree is full and raises an exception before self._evict_overflow() (line 134) is called to make space.
Additionally, even if _sum_tree is None, self._groups will grow beyond capacity before eviction, which is inconsistent with the intended behavior of a fixed-capacity buffer. Eviction should happen before or during the insertion process to ensure capacity constraints are respected.
Suggested change:

```diff
 def insert(self, groups: list[ReplayGroup]) -> None:
     for group in groups:
         if group.group_id in self._groups:
             self._remove(group.group_id)
+        elif len(self._groups) >= self.capacity:
+            self._remove(next(iter(self._groups)))
         self._groups[group.group_id] = group
         if self._sum_tree is not None:
             self._sum_tree.add(group.group_id, group.priority)
-        self._evict_overflow()
```
```python
keys = list(self._groups.keys())[:k]
return [self._groups[key] for key in keys]
```
The _sample_fifo implementation always returns the oldest k items in the buffer. If eviction_type is set to FIFO (evict on overflow), this sampler will repeatedly return the same items until they are pushed out by new insertions. This effectively prevents the model from training on more recent data until the buffer has completely cycled.
If the goal is to use the buffer as a queue, after_n_samples eviction should be used. If the goal is replay, UNIFORM or PRIORITIZED are more appropriate. Consider clarifying the intended use case for SamplerType.FIFO or implementing it as a sliding window.
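One possible shape for the sliding-window alternative, sketched here with a hypothetical `SlidingWindowSampler` class (not from the PR): keep a cursor so successive calls walk forward through the buffer and wrap around, instead of re-returning the oldest `k` groups every time.

```python
from collections import OrderedDict

class SlidingWindowSampler:
    """Hypothetical sketch of a sliding-window FIFO sampler: each call
    advances a cursor so repeated calls cover the whole buffer rather
    than repeatedly returning the oldest k items."""

    def __init__(self, groups: "OrderedDict[str, object]"):
        self._groups = groups
        self._cursor = 0

    def sample(self, k: int) -> list:
        keys = list(self._groups.keys())
        if not keys:
            return []
        picked = [keys[(self._cursor + i) % len(keys)] for i in range(min(k, len(keys)))]
        self._cursor = (self._cursor + k) % len(keys)
        return [self._groups[key] for key in picked]

buf = OrderedDict((f"g{i}", f"group-{i}") for i in range(5))
sampler = SlidingWindowSampler(buf)
print(sampler.sample(2))  # → ['group-0', 'group-1']
print(sampler.sample(2))  # → ['group-2', 'group-3']
```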
```python
remaining_keys = [key for key in self._groups if key not in sampled_keys]
self.rng.shuffle(remaining_keys)
for key in remaining_keys[: k - len(sampled)]:
    sampled.append(self._groups[key])
```
This fallback logic is inefficient for large buffers: `list(self._groups.keys())` (used here and in other samplers) creates a full list copy on every call, which becomes expensive as the buffer grows.
Consider maintaining a separate list of keys for sampling or using a more efficient approach to handle the case where prioritized sampling fails to find enough unique keys within max_attempts.
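The "maintain a separate list of keys" suggestion could look something like the sketch below (the `KeyedPool` name and methods are illustrative, not from the PR): a key list kept in sync with the dict, using swap-with-last removal so deletion stays O(1) and uniform sampling never has to copy the key set.

```python
import random

class KeyedPool:
    """Sketch of the reviewer's suggestion: keep a key list alongside
    the dict so uniform sampling needs no list(...) copy per call.
    Removal swaps the last key into the freed slot, staying O(1)."""

    def __init__(self, seed: int = 0):
        self._groups: dict = {}
        self._keys: list = []
        self._index: dict = {}  # key -> position in self._keys
        self.rng = random.Random(seed)

    def insert(self, key, group) -> None:
        if key not in self._groups:
            self._index[key] = len(self._keys)
            self._keys.append(key)
        self._groups[key] = group

    def remove(self, key) -> None:
        pos = self._index.pop(key)
        last = self._keys.pop()
        if last != key:           # move the last key into the freed slot
            self._keys[pos] = last
            self._index[last] = pos
        del self._groups[key]

    def sample(self, k: int) -> list:
        chosen = self.rng.sample(self._keys, min(k, len(self._keys)))
        return [self._groups[key] for key in chosen]
```

With this layout the prioritized fallback can draw from `self._keys` directly instead of rebuilding the key list on every sampling attempt.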
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 73ea077270
```python
if self._sum_tree is not None:
    self._sum_tree.add(group.group_id, group.priority)
self._evict_overflow()
```
Evict before adding prioritized groups at capacity
When sampler_type is prioritized, insert() calls self._sum_tree.add(...) before any overflow eviction. If the buffer is already full, SumTree.add raises RuntimeError("SumTree is full"), so the first post-capacity insert crashes replay training instead of evicting according to policy. This is reproducible as soon as a prioritized buffer reaches capacity and receives another group.
```python
sequence_lengths = np.array([len(response) for response in result.responses])
sequence_length_solved = (
    np.array([])
    if np.all(scores == 0)
    else np.array(sequence_lengths[scores == self.config.max_possible_score])
)
```
Keep metric arrays aligned after replay truncation masking
In replay mode with mask_truncated_completions=True, scores is filtered against sampled finish_reasons, but _compute_step_metrics still derives sequence_lengths from unfiltered result.responses. If any sampled completion is truncated, boolean indexing like sequence_lengths[scores == ...] uses a mask length that no longer matches sequence_lengths, which raises an IndexError and stops data preparation for that step.
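A minimal sketch of the fix, assuming the truncation mask is derived from `finish_reasons`: apply the same mask to lengths and scores together, so later boolean indexing stays aligned (function and argument names are illustrative, not the PR's API):

```python
import numpy as np

def aligned_length_stats(responses, finish_reasons, scores, max_score):
    """Filter lengths and scores with the same truncation mask so
    boolean indexing like lengths[scores == max_score] stays aligned."""
    lengths = np.array([len(r) for r in responses])
    keep = np.array([fr != "length" for fr in finish_reasons])  # drop truncated
    lengths, scores = lengths[keep], np.asarray(scores)[keep]
    solved = lengths[scores == max_score]
    unsolved = lengths[scores == 0]
    return solved, unsolved

# One truncated completion ("length") is dropped from both arrays at once,
# so the masks below always match the filtered lengths.
solved, unsolved = aligned_length_stats(
    ["aaaa", "bb", "cccccc"], ["stop", "length", "stop"], [1, 0, 0], max_score=1
)
print(list(solved), list(unsolved))  # → [4] [6]
```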
Inspired by Reverb.