Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions python/ray/train/v2/api/train_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ def train_func(config):
validate_config: Configuration passed to the validate_fn. Can contain info
like the validation dataset.
"""
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
raise DeprecationWarning(
"`ray.train.report` is deprecated when running in a function "
"passed to Ray Tune. Please use `ray.tune.report` instead. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)

if delete_local_checkpoint_after_upload is None:
delete_local_checkpoint_after_upload = (
checkpoint_upload_mode._default_delete_local_checkpoint_after_upload()
Expand Down Expand Up @@ -130,8 +140,16 @@ def get_context() -> TrainContext:

See the :class:`~ray.train.TrainContext` API reference to see available methods.
"""
# TODO: Return a dummy train context on the controller and driver process
# instead of raising an exception if the train context does not exist.
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
raise DeprecationWarning(
"`ray.train.get_context` is deprecated when running in a function "
"passed to Ray Tune. Please use `ray.tune.get_context` instead. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)

return get_train_fn_utils().get_context()


Expand Down Expand Up @@ -179,6 +197,16 @@ def train_func(config):
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
"""
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
raise DeprecationWarning(
"`ray.train.get_checkpoint` is deprecated when running in a function "
"passed to Ray Tune. Please use `ray.tune.get_checkpoint` instead. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)

return get_train_fn_utils().get_checkpoint()


Expand Down
27 changes: 23 additions & 4 deletions python/ray/tune/tests/test_api_migrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools
import importlib
import sys
import warnings

Expand All @@ -9,31 +11,48 @@
from ray.util.annotations import RayDeprecationWarning


@pytest.fixture(autouse=True)
def enable_v2(monkeypatch):
monkeypatch.setenv("RAY_TRAIN_V2_ENABLED", "1")
importlib.reload(ray.train)
yield


@pytest.fixture(autouse=True)
def enable_v2_migration_deprecation_messages(monkeypatch):
monkeypatch.setenv(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, "1")
yield
monkeypatch.delenv(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR)


def test_trainable_fn_utils(tmp_path):
@pytest.mark.parametrize("v2_enabled", [False, True])
def test_trainable_fn_utils(tmp_path, monkeypatch, v2_enabled):
monkeypatch.setenv("RAY_TRAIN_V2_ENABLED", str(int(v2_enabled)))
importlib.reload(ray.train)

dummy_checkpoint_dir = tmp_path.joinpath("dummy")
dummy_checkpoint_dir.mkdir()

asserting_context = (
functools.partial(pytest.raises, DeprecationWarning)
if v2_enabled
else functools.partial(pytest.warns, RayDeprecationWarning)
)

def tune_fn(config):
with pytest.warns(RayDeprecationWarning, match="ray.tune.get_checkpoint"):
with asserting_context(match="ray.tune.get_checkpoint"):
ray.train.get_checkpoint()

with warnings.catch_warnings():
ray.tune.get_checkpoint()

with pytest.warns(RayDeprecationWarning, match="ray.tune.get_context"):
with asserting_context(match="ray.tune.get_context"):
ray.train.get_context()

with warnings.catch_warnings():
ray.tune.get_context()

with pytest.warns(RayDeprecationWarning, match="ray.tune.report"):
with asserting_context(match="ray.tune.report"):
ray.train.report({"a": 1})

with warnings.catch_warnings():
Expand Down