Skip to content

Commit 71ef8c7

Browse files
authored
Merge branch 'master' into srinathk10/format_batches
2 parents 35dba95 + eede46f commit 71ef8c7

File tree

10 files changed

+140
-69
lines changed

10 files changed

+140
-69
lines changed

ci/docker/forge.Dockerfile

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ wget -qO- https://astral.sh/uv/install.sh | sudo env UV_UNMANAGED_INSTALL="/usr/
6161

6262
mkdir -p /usr/local/python
6363
# Install Python 3.9 using uv
64-
uv python install --install-dir /usr/local/python 3.9
65-
uv python pin 3.9
64+
UV_PYTHON_VERSION=3.9
65+
uv python install --install-dir /usr/local/python "$UV_PYTHON_VERSION"
6666

6767
export UV_PYTHON_INSTALL_DIR=/usr/local/python
68-
# Make Python 3.9 from uv the default by creating symlinks
69-
PYTHON39_PATH=$(uv python find 3.9)
70-
echo $PYTHON39_PATH
71-
ln -s $PYTHON39_PATH /usr/local/bin/python3.9
72-
ln -s $PYTHON39_PATH /usr/local/bin/python3
73-
ln -s $PYTHON39_PATH /usr/local/bin/python
68+
# Make Python from uv the default by creating symlinks
69+
UV_PYTHON_BIN="$(uv python find --no-project "$UV_PYTHON_VERSION")"
70+
echo "uv python binary location: $UV_PYTHON_BIN"
71+
ln -s "$UV_PYTHON_BIN" "/usr/local/bin/python${UV_PYTHON_VERSION}"
72+
ln -s "$UV_PYTHON_BIN" /usr/local/bin/python3
73+
ln -s "$UV_PYTHON_BIN" /usr/local/bin/python
7474

7575
# As a convention, we pin all python packages to a specific version. This
7676
# is to make sure we can control version upgrades through code changes.

python/ray/_private/test_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,20 +1065,6 @@ def fetch_prometheus_metric_timeseries(
10651065
return samples_by_name
10661066

10671067

1068-
def raw_metrics(info: RayContext) -> Dict[str, List[Any]]:
1069-
"""Return prometheus metrics from a RayContext
1070-
1071-
Args:
1072-
info: Ray context returned from ray.init()
1073-
1074-
Returns:
1075-
Dict from metric name to a list of samples for the metrics
1076-
"""
1077-
metrics_page = "localhost:{}".format(info.address_info["metrics_export_port"])
1078-
print("Fetch metrics from", metrics_page)
1079-
return fetch_prometheus_metrics([metrics_page])
1080-
1081-
10821068
def raw_metric_timeseries(
10831069
info: RayContext, result: PrometheusTimeseries
10841070
) -> Dict[str, List[Any]]:

python/ray/data/grouped_data.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterator as IteratorABC
12
from functools import partial
23
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
34

@@ -538,7 +539,10 @@ def std(
538539

539540

540541
def _apply_udf_to_groups(
541-
udf: Callable[[DataBatch, ...], DataBatch],
542+
udf: Union[
543+
Callable[[DataBatch, ...], DataBatch],
544+
Callable[[DataBatch, ...], Iterator[DataBatch]],
545+
],
542546
block: Block,
543547
keys: List[str],
544548
batch_format: Optional[str],
@@ -548,7 +552,8 @@ def _apply_udf_to_groups(
548552
"""Apply UDF to groups of rows having the same set of values of the specified
549553
columns (keys).
550554
551-
NOTE: This function is defined at module level to avoid capturing closures and make it serializable."""
555+
NOTE: This function is defined at module level to avoid capturing closures and make it serializable.
556+
"""
552557
block_accessor = BlockAccessor.for_block(block)
553558

554559
boundaries = block_accessor._get_group_boundaries_sorted(keys)
@@ -560,7 +565,17 @@ def _apply_udf_to_groups(
560565
# Convert corresponding block of each group to batch format here,
561566
# because the block format here can be different from batch format
562567
# (e.g. block is Arrow format, and batch is NumPy format).
563-
yield udf(group_block_accessor.to_batch_format(batch_format), *args, **kwargs)
568+
result = udf(
569+
group_block_accessor.to_batch_format(batch_format), *args, **kwargs
570+
)
571+
572+
# Check if the UDF returned an iterator/generator.
573+
if isinstance(result, IteratorABC):
574+
# If so, yield each item from the iterator.
575+
yield from result
576+
else:
577+
# Otherwise, yield the single result.
578+
yield result
564579

565580

566581
# Backwards compatibility alias.

python/ray/data/tests/test_groupby_e2e.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import itertools
22
import random
33
import time
4-
from typing import Optional
4+
from typing import Iterator, Optional
55

66
import numpy as np
77
import pandas as pd
@@ -1142,6 +1142,40 @@ def func(x, y):
11421142
assert "MapBatches(func)" in ds.__repr__()
11431143

11441144

1145+
def test_map_groups_generator_udf(ray_start_regular_shared_2_cpus):
1146+
"""
1147+
Tests that map_groups supports UDFs that return generators (iterators).
1148+
"""
1149+
ds = ray.data.from_items(
1150+
[
1151+
{"group": 1, "data": 10},
1152+
{"group": 1, "data": 20},
1153+
{"group": 2, "data": 30},
1154+
]
1155+
)
1156+
1157+
def generator_udf(df: pd.DataFrame) -> Iterator[pd.DataFrame]:
1158+
# For each group, yield two DataFrames.
1159+
# 1. A DataFrame where 'data' is multiplied by 2.
1160+
yield df.assign(data=df["data"] * 2)
1161+
# 2. A DataFrame where 'data' is multiplied by 3.
1162+
yield df.assign(data=df["data"] * 3)
1163+
1164+
# Apply the generator UDF to the grouped data.
1165+
result_ds = ds.groupby("group").map_groups(generator_udf)
1166+
1167+
# The final dataset should contain all results from all yields.
1168+
# Group 1 -> data: [20, 40] and [30, 60]
1169+
# Group 2 -> data: [60] and [90]
1170+
expected_data = sorted([20, 40, 30, 60, 60, 90])
1171+
1172+
# Collect and sort the actual data to ensure correctness regardless of order.
1173+
actual_data = sorted([row["data"] for row in result_ds.take_all()])
1174+
1175+
assert actual_data == expected_data
1176+
assert result_ds.count() == 6
1177+
1178+
11451179
if __name__ == "__main__":
11461180
import sys
11471181

python/ray/data/tests/test_join.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import ray
99
from ray._private.arrow_utils import get_pyarrow_version
1010
from ray.data._internal.logical.operators.join_operator import JoinType
11-
from ray.data._internal.util import MiB
11+
from ray.data._internal.util import MiB, rows_same
1212
from ray.data.context import DataContext
1313
from ray.data.dataset import Dataset
1414
from ray.exceptions import RayTaskError
@@ -263,11 +263,7 @@ def test_simple_self_join(ray_start_regular_shared_2_cpus, left_suffix, right_su
263263

264264
assert 'Field "double" exists 2 times' in str(exc_info.value.cause)
265265
else:
266-
267-
joined_pd = pd.DataFrame(joined.take_all())
268-
269-
# Sort resulting frame and reset index (to be able to compare with expected one)
270-
joined_pd_sorted = joined_pd.sort_values(by=["id"]).reset_index(drop=True)
266+
joined_pd = joined.to_pandas()
271267

272268
# Join using Pandas (to assert against)
273269
expected_pd = doubles_pd.join(
@@ -278,7 +274,7 @@ def test_simple_self_join(ray_start_regular_shared_2_cpus, left_suffix, right_su
278274
rsuffix=right_suffix,
279275
).reset_index(drop=True)
280276

281-
pd.testing.assert_frame_equal(expected_pd, joined_pd_sorted)
277+
assert rows_same(expected_pd, joined_pd), "Expected contents to be same"
282278

283279

284280
def test_invalid_join_config(ray_start_regular_shared_2_cpus):

python/ray/tests/test_memory_pressure.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from ray._private.grpc_utils import init_grpc_channel
1515
from ray._private.state_api_test_utils import verify_failed_task
16-
from ray._private.test_utils import raw_metrics
16+
from ray._private.test_utils import PrometheusTimeseries, raw_metric_timeseries
1717
from ray._private.utils import get_used_memory
1818
from ray.util.state.state_manager import StateDataSourceClient
1919

@@ -118,8 +118,10 @@ def get_additional_bytes_to_reach_memory_usage_pct(pct: float) -> int:
118118
return bytes_needed
119119

120120

121-
def has_metric_tagged_with_value(addr, tag, value) -> bool:
122-
metrics = raw_metrics(addr)
121+
def has_metric_tagged_with_value(
122+
addr, tag, value, timeseries: PrometheusTimeseries
123+
) -> bool:
124+
metrics = raw_metric_timeseries(addr, timeseries)
123125
for name, samples in metrics.items():
124126
for sample in samples:
125127
if tag in set(sample.labels.values()) and sample.value == value:
@@ -145,13 +147,15 @@ def test_restartable_actor_throws_oom_error(ray_with_memory_monitor, restartable
145147
with pytest.raises(ray.exceptions.OutOfMemoryError):
146148
ray.get(leaker.allocate.remote(bytes_to_alloc, memory_monitor_refresh_ms * 3))
147149

150+
timeseries = PrometheusTimeseries()
148151
wait_for_condition(
149152
has_metric_tagged_with_value,
150153
timeout=10,
151154
retry_interval_ms=100,
152155
addr=addr,
153156
tag="MemoryManager.ActorEviction.Total",
154157
value=2.0 if restartable else 1.0,
158+
timeseries=timeseries,
155159
)
156160

157161
wait_for_condition(
@@ -161,6 +165,7 @@ def test_restartable_actor_throws_oom_error(ray_with_memory_monitor, restartable
161165
addr=addr,
162166
tag="Leaker.__init__",
163167
value=2.0 if restartable else 1.0,
168+
timeseries=timeseries,
164169
)
165170

166171

@@ -180,13 +185,15 @@ def test_restartable_actor_oom_retry_off_throws_oom_error(
180185
with pytest.raises(ray.exceptions.OutOfMemoryError) as _:
181186
ray.get(leaker.allocate.remote(bytes_to_alloc, memory_monitor_refresh_ms * 3))
182187

188+
timeseries = PrometheusTimeseries()
183189
wait_for_condition(
184190
has_metric_tagged_with_value,
185191
timeout=10,
186192
retry_interval_ms=100,
187193
addr=addr,
188194
tag="MemoryManager.ActorEviction.Total",
189195
value=2.0,
196+
timeseries=timeseries,
190197
)
191198
wait_for_condition(
192199
has_metric_tagged_with_value,
@@ -195,6 +202,7 @@ def test_restartable_actor_oom_retry_off_throws_oom_error(
195202
addr=addr,
196203
tag="Leaker.__init__",
197204
value=2.0,
205+
timeseries=timeseries,
198206
)
199207

200208

@@ -210,13 +218,15 @@ def test_non_retryable_task_killed_by_memory_monitor_with_oom_error(
210218
with pytest.raises(ray.exceptions.OutOfMemoryError) as _:
211219
ray.get(allocate_memory.options(max_retries=0).remote(bytes_to_alloc))
212220

221+
timeseries = PrometheusTimeseries()
213222
wait_for_condition(
214223
has_metric_tagged_with_value,
215224
timeout=10,
216225
retry_interval_ms=100,
217226
addr=addr,
218227
tag="MemoryManager.TaskEviction.Total",
219228
value=1.0,
229+
timeseries=timeseries,
220230
)
221231
wait_for_condition(
222232
has_metric_tagged_with_value,
@@ -225,6 +235,7 @@ def test_non_retryable_task_killed_by_memory_monitor_with_oom_error(
225235
addr=addr,
226236
tag="allocate_memory",
227237
value=1.0,
238+
timeseries=timeseries,
228239
)
229240

230241

@@ -372,13 +383,15 @@ def test_task_oom_no_oom_retry_fails_immediately(
372383
)
373384
)
374385

386+
timeseries = PrometheusTimeseries()
375387
wait_for_condition(
376388
has_metric_tagged_with_value,
377389
timeout=10,
378390
retry_interval_ms=100,
379391
addr=addr,
380392
tag="MemoryManager.TaskEviction.Total",
381393
value=1.0,
394+
timeseries=timeseries,
382395
)
383396
wait_for_condition(
384397
has_metric_tagged_with_value,
@@ -387,6 +400,7 @@ def test_task_oom_no_oom_retry_fails_immediately(
387400
addr=addr,
388401
tag="allocate_memory",
389402
value=1.0,
403+
timeseries=timeseries,
390404
)
391405

392406

@@ -411,13 +425,15 @@ def test_task_oom_only_uses_oom_retry(
411425
)
412426
)
413427

428+
timeseries = PrometheusTimeseries()
414429
wait_for_condition(
415430
has_metric_tagged_with_value,
416431
timeout=10,
417432
retry_interval_ms=100,
418433
addr=addr,
419434
tag="MemoryManager.TaskEviction.Total",
420435
value=task_oom_retries + 1,
436+
timeseries=timeseries,
421437
)
422438
wait_for_condition(
423439
has_metric_tagged_with_value,
@@ -426,6 +442,7 @@ def test_task_oom_only_uses_oom_retry(
426442
addr=addr,
427443
tag="allocate_memory",
428444
value=task_oom_retries + 1,
445+
timeseries=timeseries,
429446
)
430447

431448

@@ -502,6 +519,7 @@ def infinite_retry_task():
502519
time.sleep(5)
503520

504521
with ray.init() as addr:
522+
timeseries = PrometheusTimeseries()
505523
with pytest.raises(ray.exceptions.OutOfMemoryError) as _:
506524
ray.get(infinite_retry_task.remote())
507525

@@ -512,6 +530,7 @@ def infinite_retry_task():
512530
addr=addr,
513531
tag="MemoryManager.TaskEviction.Total",
514532
value=1.0,
533+
timeseries=timeseries,
515534
)
516535

517536

0 commit comments

Comments
 (0)