Skip to content

Commit 5dfada7

Browse files
JasonLi1909 and peterxcli
authored and committed
[train] Rename DatasetsSetupCallback to DatasetsCallback (ray-project#59423)
Since PR ray-project#58325 added shutdown and abort hooks to enhance resource-cleanup logic in DatasetsSetupCallback, the callback’s responsibilities have expanded beyond initial setup. Accordingly, this PR renames it to DatasetsCallback for better alignment with its behavior. Signed-off-by: JasonLi1909 <jasli1909@gmail.com> Signed-off-by: peterxcli <peterxcli@gmail.com>
1 parent 6517392 commit 5dfada7

File tree

5 files changed

+13
-15
lines changed

5 files changed

+13
-15
lines changed

python/ray/train/v2/_internal/callbacks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from .accelerators import AcceleratorSetupCallback
22
from .backend_setup import BackendSetupCallback
3-
from .datasets import DatasetsSetupCallback
3+
from .datasets import DatasetsCallback
44
from .state_manager import StateManagerCallback
55
from .tpu_reservation_callback import TPUReservationCallback
66
from .working_dir_setup import WorkingDirectorySetupCallback
77

88
__all__ = [
99
"AcceleratorSetupCallback",
1010
"BackendSetupCallback",
11-
"DatasetsSetupCallback",
11+
"DatasetsCallback",
1212
"StateManagerCallback",
1313
"TPUReservationCallback",
1414
"WorkingDirectorySetupCallback",

python/ray/train/v2/_internal/callbacks/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
4646
return self._dataset_iterators[dataset_info.dataset_name]
4747

4848

49-
class DatasetsSetupCallback(WorkerGroupCallback, ControllerCallback):
49+
class DatasetsCallback(WorkerGroupCallback, ControllerCallback):
5050
"""A callback for managing Ray Datasets for the worker group."""
5151

5252
def __init__(self, train_run_context: TrainRunContext):

python/ray/train/v2/api/data_parallel_trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ray.train.v2._internal.callbacks import (
2828
AcceleratorSetupCallback,
2929
BackendSetupCallback,
30-
DatasetsSetupCallback,
30+
DatasetsCallback,
3131
TPUReservationCallback,
3232
WorkingDirectorySetupCallback,
3333
)
@@ -203,9 +203,7 @@ def _create_default_callbacks(self) -> List[RayTrainCallback]:
203203
self.backend_config, self.scaling_config
204204
)
205205
backend_setup_callback = BackendSetupCallback(self.backend_config)
206-
datasets_callback = DatasetsSetupCallback(
207-
train_run_context=self.train_run_context
208-
)
206+
datasets_callback = DatasetsCallback(train_run_context=self.train_run_context)
209207
tpu_reservation_setup_callback = TPUReservationCallback()
210208
placement_group_cleaner_callback = PlacementGroupCleanerCallback()
211209
callbacks.extend(

python/ray/train/v2/tests/test_data_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
1818
from ray.data.tests.conftest import restore_data_context # noqa: F401
19-
from ray.train.v2._internal.callbacks.datasets import DatasetsSetupCallback
19+
from ray.train.v2._internal.callbacks.datasets import DatasetsCallback
2020
from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
2121
from ray.train.v2._internal.execution.worker_group.worker_group import (
2222
WorkerGroupContext,
@@ -92,7 +92,7 @@ def test_data_config_validation():
9292

9393

9494
def test_datasets_callback(ray_start_4_cpus):
95-
"""Check that the `DatasetsSetupCallback` correctly configures the
95+
"""Check that the `DatasetsCallback` correctly configures the
9696
dataset shards and execution options."""
9797
NUM_WORKERS = 2
9898

@@ -121,7 +121,7 @@ def test_datasets_callback(ray_start_4_cpus):
121121
)
122122
worker_group._start()
123123

124-
callback = DatasetsSetupCallback(train_run_context)
124+
callback = DatasetsCallback(train_run_context)
125125
dataset_manager_for_each_worker = callback.before_init_train_context(
126126
worker_group.get_workers()
127127
)["dataset_shard_provider"]

python/ray/train/v2/tests/test_data_resource_cleanup.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
SplitCoordinator,
1313
_DatasetWrapper,
1414
)
15-
from ray.train.v2._internal.callbacks.datasets import DatasetsSetupCallback
15+
from ray.train.v2._internal.callbacks.datasets import DatasetsCallback
1616
from ray.train.v2._internal.execution.worker_group import (
1717
WorkerGroup,
1818
WorkerGroupContext,
@@ -23,7 +23,7 @@
2323

2424

2525
def test_datasets_callback_multiple_datasets(ray_start_4_cpus):
26-
"""Test that the DatasetsSetupCallback properly collects the coordinator actors for multiple datasets"""
26+
"""Test that the DatasetsCallback properly collects the coordinator actors for multiple datasets"""
2727
# Start worker group
2828
worker_group_context = WorkerGroupContext(
2929
run_attempt_id="test",
@@ -49,7 +49,7 @@ def test_datasets_callback_multiple_datasets(ray_start_4_cpus):
4949
datasets=datasets, dataset_config=dataset_config
5050
)
5151

52-
callback = DatasetsSetupCallback(train_run_context)
52+
callback = DatasetsCallback(train_run_context)
5353
callback.before_init_train_context(wg.get_workers())
5454

5555
# Two coordinator actors, one for each sharded dataset
@@ -58,7 +58,7 @@ def test_datasets_callback_multiple_datasets(ray_start_4_cpus):
5858

5959

6060
def test_after_worker_group_abort():
61-
callback = DatasetsSetupCallback(create_dummy_run_context())
61+
callback = DatasetsCallback(create_dummy_run_context())
6262

6363
# Mock SplitCoordinator shutdown_executor method
6464
coord_mock = create_autospec(SplitCoordinator)
@@ -79,7 +79,7 @@ def test_after_worker_group_abort():
7979

8080

8181
def test_after_worker_group_shutdown():
82-
callback = DatasetsSetupCallback(create_dummy_run_context())
82+
callback = DatasetsCallback(create_dummy_run_context())
8383

8484
# Mock SplitCoordinator shutdown_executor method
8585
coord_mock = create_autospec(SplitCoordinator)

0 commit comments

Comments (0)