Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,17 @@ def update_operator_states(topology: Topology) -> None:
"""Update operator states accordingly for newly completed tasks.
Should be called after `process_completed_tasks()`."""

# Call inputs_done() on ops where no more inputs are coming.
for op, op_state in topology.items():
# Drain upstream output queue if current operator is execution finished.
# This is needed when the limit is reached, and `mark_execution_finished`
# is called manually.
if op.execution_finished():
for idx, dep in enumerate(op.input_dependencies):
upstream_state = topology[dep]
# Drain upstream output queue
upstream_state.output_queue.clear()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Downstream Completion Triggers Premature Queue Clearing

Unconditionally clearing an upstream operator's output queue when any downstream operator finishes execution can lead to data loss. In fan-out scenarios, this starves other active downstream operators that still depend on that data.

Fix in Cursor Fix in Web


# Call inputs_done() on ops where no more inputs are coming.
if op_state.inputs_done_called:
continue
all_inputs_done = True
Expand Down
56 changes: 56 additions & 0 deletions python/ray/data/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest

import ray
Expand Down Expand Up @@ -40,7 +42,9 @@
from ray.data._internal.execution.operators.task_pool_map_operator import (
TaskPoolMapOperator,
)
from ray.data._internal.execution.streaming_executor import StreamingExecutor
from ray.data._internal.execution.util import make_ref_bundles
from ray.data._internal.logical.optimizers import get_execution_plan
from ray.data._internal.output_buffer import OutputBlockSizeOption
from ray.data._internal.stats import Timer
from ray.data.block import Block, BlockAccessor
Expand Down Expand Up @@ -826,6 +830,58 @@ def test_limit_operator(ray_start_regular_shared):
assert limit_op.completed(), limit


def test_limit_operator_memory_leak_fix(ray_start_regular_shared, tmp_path):
    """Test that LimitOperator properly drains upstream output queues.

    This test verifies the memory leak fix by directly using StreamingExecutor
    to access the actual topology and check queued blocks after execution.
    """
    # Write 100 small parquet files (5 rows each) so the read stage produces
    # many blocks; only the first block is needed to satisfy limit(5), so the
    # rest would sit in the upstream output queue without the drain fix.
    for i in range(100):
        data = [{"id": i * 5 + j, "value": f"row_{i * 5 + j}"} for j in range(5)]
        table = pa.Table.from_pydict(
            {"id": [row["id"] for row in data], "value": [row["value"] for row in data]}
        )
        parquet_file = tmp_path / f"test_data_{i}.parquet"
        pq.write_table(table, str(parquet_file))

    parquet_files = [str(tmp_path / f"test_data_{i}.parquet") for i in range(100)]

    ds = (
        ray.data.read_parquet(parquet_files, override_num_blocks=100)
        .limit(5)
        .map(lambda x: x)
    )

    execution_plan = ds._plan
    physical_plan = get_execution_plan(execution_plan._logical_plan)

    # Use StreamingExecutor directly to have access to the actual topology.
    executor = StreamingExecutor(DataContext.get_current())
    output_iterator = executor.execute(physical_plan.dag)

    # Collect all results and count rows.
    total_rows = 0
    for bundle in output_iterator:
        for block_ref in bundle.block_refs:
            block = ray.get(block_ref)
            total_rows += block.num_rows
    assert (
        total_rows == 5
    ), f"Expected exactly 5 rows after limit(5), but got {total_rows}"

    # Find the read operator's OpState. Match on the "Read" prefix rather
    # than an exact substring: depending on the Ray Data version, the
    # operator is named e.g. "ReadParquet" or "Read(Parquet)".
    topology = executor._topology
    read_parquet_op_state = None
    for op, op_state in topology.items():
        if op.name.startswith("Read"):
            read_parquet_op_state = op_state
            break

    # Fail with a clear message (instead of an AttributeError below) if the
    # read operator could not be located in the topology.
    assert read_parquet_op_state is not None, (
        f"Could not find read operator in topology: "
        f"{[op.name for op in topology]}"
    )

    # Check the output queue size: leftover blocks here would be the leak.
    output_queue_size = len(read_parquet_op_state.output_queue)
    assert output_queue_size == 0, f"Expected 0 items, but got {output_queue_size}."


def _get_bundles(bundle: RefBundle):
output = []
for block_ref in bundle.block_refs:
Expand Down
45 changes: 45 additions & 0 deletions python/ray/data/tests/test_streaming_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,51 @@ def test_process_completed_tasks(sleep_task_ref):
o2.mark_execution_finished.assert_called_once()


def test_update_operator_states_drains_upstream():
    """Test that update_operator_states drains upstream output queues when
    execution_finished() is called on a downstream operator.
    """
    ctx = DataContext.get_current()
    bundles = make_ref_bundles([[x] for x in range(10)])

    # Build a three-stage pipeline: input -> negate -> double.
    source = InputDataBuffer(ctx, bundles)
    negate = MapOperator.create(
        make_map_transformer(lambda block: [b * -1 for b in block]),
        source,
        ctx,
    )
    double = MapOperator.create(
        make_map_transformer(lambda block: [b * 2 for b in block]),
        negate,
        ctx,
    )
    topo, _ = build_streaming_topology(double, ExecutionOptions(verbose_progress=True))

    # Populate the upstream output queues by processing some tasks first.
    process_completed_tasks(topo, [], 0)
    update_operator_states(topo)

    # The source operator should now have queued output; remember its size.
    queued_before = len(topo[source].output_queue)
    assert queued_before > 0, "Upstream operator should have output in queue"

    # Simulate limit-operator behavior by manually finishing the middle op.
    negate.mark_execution_finished()
    assert negate.execution_finished(), "o2 should be execution finished"

    # This call should drain the finished operator's upstream output queue.
    update_operator_states(topo)

    queued_after = len(topo[source].output_queue)
    assert queued_after == 0, (
        f"Upstream operator o1 output queue should be drained when downstream o2 is execution finished. "
        f"Expected 0, got {queued_after}. "
        f"Initial size was {queued_before}"
    )


def test_get_eligible_operators_to_run():
opts = ExecutionOptions()
inputs = make_ref_bundles([[x] for x in range(1)])
Expand Down