Skip to content

Commit 90ce0d6

Browse files
committed
lint
Signed-off-by: iamjustinhsu <jhsu@anyscale.com>
1 parent d0f3b70 commit 90ce0d6

File tree

7 files changed

+138
-239
lines changed

7 files changed

+138
-239
lines changed

python/ray/data/_internal/block_batching/iter_batches.py

Lines changed: 27 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import collections
2+
import time
23
from contextlib import contextmanager, nullcontext
34
from typing import Any, Callable, Dict, Iterator, Optional
45

@@ -15,7 +16,7 @@
1516
)
1617
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
1718
from ray.data._internal.memory_tracing import trace_deallocation
18-
from ray.data._internal.stats import DatasetStats, StatsManager
19+
from ray.data._internal.stats import DatasetStats, _StatsManager
1920
from ray.data._internal.util import make_async_gen
2021
from ray.data.block import Block, DataBatch
2122
from ray.data.context import DataContext
@@ -92,6 +93,8 @@ class BatchIterator:
9293
formatting to be overlapped with the UDF. Defaults to 1.
9394
"""
9495

96+
UPDATE_METRICS_INTERVAL_S: float = 5.0
97+
9598
def __init__(
9699
self,
97100
ref_bundles: Iterator[RefBundle],
@@ -136,6 +139,7 @@ def __init__(
136139
else WaitBlockPrefetcher()
137140
)
138141
self._yielded_first_batch = False
142+
self._metrics_last_updated: float = 0.0
139143

140144
def _prefetch_blocks(
141145
self, ref_bundles: Iterator[RefBundle]
@@ -219,34 +223,29 @@ def _iter_batches(self) -> Iterator[DataBatch]:
219223
preserve_ordering=False,
220224
)
221225

222-
with self._epoch_context():
223-
while True:
224-
with self.get_next_batch_context():
225-
try:
226-
batch = next(async_batch_iter)
227-
except StopIteration:
228-
break
229-
with self.yield_batch_context(batch):
230-
yield batch.data
226+
self.before_epoch_start()
227+
228+
while True:
229+
with self.get_next_batch_context():
230+
try:
231+
batch = next(async_batch_iter)
232+
except StopIteration:
233+
break
234+
with self.yield_batch_context(batch):
235+
yield batch.data
236+
237+
self.after_epoch_end()
231238

232239
def __iter__(self) -> Iterator[DataBatch]:
233240
return self._iter_batches()
234241

235-
@contextmanager
236-
def _epoch_context(self):
237-
"""Context manager for epoch lifecycle: setup and cleanup.
238-
239-
Ensures proper initialization before iteration and cleanup after,
240-
even if the iteration is interrupted or fails.
241-
"""
242-
# Setup: Initialize epoch state
242+
def before_epoch_start(self):
243243
self._yielded_first_batch = False
244-
StatsManager.register_dataset_tag(self._dataset_tag)
245-
try:
246-
yield
247-
finally:
248-
# Cleanup: Clear iteration metrics after epoch completes or fails
249-
StatsManager.clear_iteration_metrics(self._dataset_tag)
244+
245+
def after_epoch_end(self):
246+
_StatsManager.update_iteration_metrics(
247+
self._stats, self._dataset_tag, force_update=True
248+
)
250249

251250
@contextmanager
252251
def get_next_batch_context(self):
@@ -271,7 +270,10 @@ def get_next_batch_context(self):
271270
def yield_batch_context(self, batch: Batch):
272271
with self._stats.iter_user_s.timer() if self._stats else nullcontext():
273272
yield
274-
StatsManager.update_iteration_metrics(self._stats, self._dataset_tag)
273+
now = time.time()
274+
if (now - self._metrics_last_updated) > self.UPDATE_METRICS_INTERVAL_S:
275+
_StatsManager.update_iteration_metrics(self._stats, self._dataset_tag)
276+
self._metrics_last_updated = now
275277

276278

277279
def _format_in_threadpool(

python/ray/data/_internal/execution/streaming_executor.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@
4444
)
4545
from ray.data._internal.metadata_exporter import Topology as TopologyMetadata
4646
from ray.data._internal.progress_bar import ProgressBar
47-
from ray.data._internal.stats import DatasetStats, StatsManager, Timer
47+
from ray.data._internal.stats import DatasetStats, Timer, _StatsManager
4848
from ray.data.context import OK_PREFIX, WARN_PREFIX, DataContext
4949
from ray.util.debug import log_once
5050
from ray.util.metrics import Gauge
@@ -71,6 +71,8 @@ class StreamingExecutor(Executor, threading.Thread):
7171
a way that maximizes throughput under resource constraints.
7272
"""
7373

74+
UPDATE_METRICS_INTERVAL_S: float = 5.0
75+
7476
def __init__(
7577
self,
7678
data_context: DataContext,
@@ -115,6 +117,8 @@ def __init__(
115117
register_dataset_logger(self._dataset_id)
116118
)
117119

120+
self._metrics_last_updated: float = 0.0
121+
118122
self._sched_loop_duration_s = Gauge(
119123
"data_sched_loop_duration_s",
120124
description="Duration of the scheduling loop in seconds",
@@ -223,7 +227,7 @@ def execute(
223227
op_to_id = {
224228
op: self._get_operator_id(op, i) for i, op in enumerate(self._topology)
225229
}
226-
StatsManager.register_dataset_to_stats_actor(
230+
_StatsManager.register_dataset_to_stats_actor(
227231
self._dataset_id,
228232
self._get_operator_tags(),
229233
TopologyMetadata.create_topology_metadata(dag, op_to_id),
@@ -271,9 +275,6 @@ def shutdown(self, force: bool, exception: Optional[Exception] = None):
271275
else DatasetState.FAILED.name,
272276
force_update=True,
273277
)
274-
# Once Dataset execution completes, mark it as complete
275-
# and remove last cached execution stats.
276-
StatsManager.clear_last_execution_stats(self._dataset_id)
277278
# Freeze the stats and save it.
278279
self._final_stats = self._generate_stats()
279280
stats_summary_string = self._final_stats.to_summary().to_string(
@@ -660,13 +661,19 @@ def _get_state_dict(self, state):
660661
}
661662

662663
def _update_stats_metrics(self, state: str, force_update: bool = False):
663-
StatsManager.update_execution_metrics(
664-
self._dataset_id,
665-
[op.metrics for op in self._topology],
666-
self._get_operator_tags(),
667-
self._get_state_dict(state=state),
668-
force_update=force_update,
669-
)
664+
now = time.time()
665+
if (
666+
force_update
667+
or (now - self._metrics_last_updated) > self.UPDATE_METRICS_INTERVAL_S
668+
):
669+
_StatsManager.update_execution_metrics(
670+
self._dataset_id,
671+
[op.metrics for op in self._topology],
672+
self._get_operator_tags(),
673+
self._get_state_dict(state=state),
674+
force_update=force_update,
675+
)
676+
self._metrics_last_updated = now
670677

671678
def _use_rich_progress(self):
672679
return self._data_context.enable_rich_progress_bars

python/ray/data/_internal/iterator/stream_split_iterator.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,6 @@ def gen_blocks() -> Iterator[RefBundle]:
8181
future: ObjectRef[
8282
Optional[ObjectRef[Block]]
8383
] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx)
84-
8584
while True:
8685
block_ref_and_md: Optional[RefBundle] = ray.get(future)
8786
if not block_ref_and_md:

0 commit comments

Comments
 (0)