Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions python/ray/data/_internal/execution/operators/hash_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,19 +529,22 @@ def __init__(
input_logical_ops,
)

ray_remote_args = self._get_default_aggregator_ray_remote_args(
num_partitions=target_num_partitions,
num_aggregators=num_aggregators,
total_available_cluster_resources=total_available_cluster_resources,
estimated_dataset_bytes=estimated_dataset_bytes,
)

if aggregator_ray_remote_args_override is not None:
# Apply the override on top of the defaults, so that options absent from the override keep their default values
ray_remote_args.update(aggregator_ray_remote_args_override)

self._aggregator_pool: AggregatorPool = AggregatorPool(
num_partitions=target_num_partitions,
num_aggregators=num_aggregators,
aggregation_factory=partition_aggregation_factory,
aggregator_ray_remote_args=(
aggregator_ray_remote_args_override
or self._get_default_aggregator_ray_remote_args(
num_partitions=target_num_partitions,
num_aggregators=num_aggregators,
total_available_cluster_resources=total_available_cluster_resources,
estimated_dataset_bytes=estimated_dataset_bytes,
)
),
aggregator_ray_remote_args=ray_remote_args,
data_context=data_context,
)

Expand Down Expand Up @@ -1080,6 +1083,9 @@ def _get_default_aggregator_ray_remote_args(
# nodes to prevent any single node being overloaded with a "thundering
# herd"
"scheduling_strategy": "SPREAD",
# Allow actor tasks to execute out of order by default to prevent
# head-of-line blocking.
"allow_out_of_order_execution": True,
}

return remote_args
Expand Down
75 changes: 75 additions & 0 deletions python/ray/data/tests/test_hash_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class JoinTestCase:
"num_cpus": 0.25, # 4 CPUs * 25% / 4 aggregators
"memory": 1771674012,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 2: Single partition (much higher memory overhead)
Expand All @@ -72,6 +73,7 @@ class JoinTestCase:
"num_cpus": 1.0, # 4 CPUs * 25% / 1 aggregator
"memory": 8589934592,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 3: Limited CPU resources affecting num_cpus calculation
Expand All @@ -89,6 +91,7 @@ class JoinTestCase:
"num_cpus": 0.25, # 2 CPUs * 25% / 2 aggregators
"memory": 2469606197,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 4: Testing with many CPUs and partitions
Expand All @@ -106,6 +109,7 @@ class JoinTestCase:
"num_cpus": 0.25, # 32 CPUs * 25% / 32 aggregators
"memory": 1315333735,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 5: Testing max aggregators cap (128 default)
Expand All @@ -123,6 +127,7 @@ class JoinTestCase:
"num_cpus": 0.5, # 256 CPUs * 25% / 128 aggregators
"memory": 2449473536,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 6: Testing num_cpus derived from memory allocation
Expand All @@ -140,6 +145,7 @@ class JoinTestCase:
"num_cpus": 0.57, # ~2.5Gb / 4Gb = ~0.57
"memory": 2449473536,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 7: No dataset size estimates available (fallback to default memory request)
Expand All @@ -158,6 +164,7 @@ class JoinTestCase:
# Default fallback memory request (NOTE: the value below is 1073741824 bytes = 1 GiB, not 2Gb)
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
],
Expand Down Expand Up @@ -254,6 +261,7 @@ class HashOperatorTestCase:
"num_cpus": 0.16,
"memory": 671088640,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 2: Single partition produced
Expand All @@ -269,6 +277,7 @@ class HashOperatorTestCase:
"num_cpus": 0.25,
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 3: Many CPUs
Expand All @@ -284,6 +293,7 @@ class HashOperatorTestCase:
"num_cpus": 0.25,
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 4: Testing num_cpus derived from memory allocation
Expand All @@ -299,6 +309,7 @@ class HashOperatorTestCase:
"num_cpus": 0.16, # ~0.6Gb / 4Gb = ~0.16
"memory": 687865856,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 6: No dataset size estimate inferred (fallback to default memory request)
Expand All @@ -314,6 +325,7 @@ class HashOperatorTestCase:
"num_cpus": 0.25,
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
],
Expand Down Expand Up @@ -380,6 +392,7 @@ def test_hash_aggregate_operator_remote_args(
"num_cpus": 0.16,
"memory": 671088640,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 2: Single partition produced
Expand All @@ -395,6 +408,7 @@ def test_hash_aggregate_operator_remote_args(
"num_cpus": 0.25,
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 3: Many CPUs
Expand All @@ -410,6 +424,7 @@ def test_hash_aggregate_operator_remote_args(
"num_cpus": 0.25,
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 4: Testing num_cpus derived from memory allocation
Expand All @@ -425,6 +440,7 @@ def test_hash_aggregate_operator_remote_args(
"num_cpus": 0.16, # ~0.6Gb / 4Gb = ~0.16
"memory": 687865856,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
# Case 5: No dataset size estimate inferred (fallback to default memory request)
Expand All @@ -440,6 +456,7 @@ def test_hash_aggregate_operator_remote_args(
"num_cpus": 0.25,
"memory": 1073741824,
"scheduling_strategy": "SPREAD",
"allow_out_of_order_execution": True,
},
),
],
Expand Down Expand Up @@ -491,3 +508,61 @@ def test_hash_shuffle_operator_remote_args(
op._aggregator_pool._aggregator_ray_remote_args
== tc.expected_ray_remote_args
)


def test_aggregator_ray_remote_args_partial_override(ray_start_regular):
    """Check that a partial remote-args override is merged with the defaults.

    The override handed to the operator contains only ``num_cpus``. The test
    asserts that this value is used, while ``scheduling_strategy``,
    ``allow_out_of_order_execution``, ``max_concurrency`` and ``memory`` are
    still present in the remote args stored on the aggregator pool.
    """
    # Logical operator stub reporting a 2 GiB dataset and 16 estimated outputs.
    stub_metadata = BlockMetadata(
        num_rows=None,
        size_bytes=2 * GiB,
        exec_stats=None,
        input_files=None,
    )
    logical_stub = MagicMock(LogicalOperator)
    logical_stub.infer_metadata.return_value = stub_metadata
    logical_stub.estimated_num_outputs.return_value = 16

    # Upstream physical operator stub wired to the logical stub above.
    upstream_stub = MagicMock(PhysicalOperator)
    upstream_stub._output_dependencies = []
    upstream_stub._logical_operators = [logical_stub]

    # Only num_cpus is overridden; every other option must come from defaults.
    partial_override = {"num_cpus": 0.5}
    fake_cluster_resources = {"CPU": 4.0, "memory": 32 * GiB}

    # Pin the cluster resources seen by the operator while it is constructed.
    with patch(
        "ray.data._internal.execution.operators.hash_shuffle.ray.cluster_resources",
        return_value=fake_cluster_resources,
    ):
        op = HashAggregateOperator(
            input_op=upstream_stub,
            data_context=DataContext.get_current(),
            aggregation_fns=[Count()],
            key_columns=("id",),
            aggregator_ray_remote_args_override=partial_override,
        )

        effective_args = op._aggregator_pool._aggregator_ray_remote_args

        # The explicitly overridden value wins.
        assert effective_args["num_cpus"] == 0.5

        # Defaults not mentioned in the override are retained.
        assert effective_args["scheduling_strategy"] == "SPREAD"
        assert effective_args["allow_out_of_order_execution"] is True

        # Options that are only checked for presence, not for a specific value.
        for retained_key in ("max_concurrency", "memory"):
            assert retained_key in effective_args