Skip to content

Commit 95f86ab

Browse files
TimothySeah and dstrodtman
authored and committed
[data][train] Refactor call_with_retry into shared library and use it to retry checkpoint upload (#56608)
This PR moves `call_with_retry` from `ray/data/_internal` to `ray/_private` so that it can be used in other libraries like Ray Train. It also adds a new `retry` decorator that wraps around `call_with_retry`. Note that I had to remove `*` from `call_with_retry`'s arguments to get the decorator to work on Python object methods because Python passes `self` as one of the `*args`. --------- Signed-off-by: Timothy Seah <tseah@anyscale.com> Signed-off-by: Douglas Strodtman <douglas@anyscale.com>
1 parent 9f185b4 commit 95f86ab

File tree

10 files changed

+206
-58
lines changed

10 files changed

+206
-58
lines changed

ci/lint/pydoclint-baseline.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,9 +1241,6 @@ python/ray/data/_internal/util.py
12411241
DOC402: Function `make_async_gen` has "yield" statements, but the docstring does not have a "Yields" section
12421242
DOC404: Function `make_async_gen` yield type(s) in docstring not consistent with the return annotation. Return annotation exists, but docstring "yields" section does not exist or has 0 type(s).
12431243
DOC103: Method `RetryingPyFileSystemHandler.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [retryable_errors: List[str]]. Arguments in the docstring but not in the function signature: [context: ].
1244-
DOC104: Function `call_with_retry`: Arguments are the same in the docstring and the function signature, but are in a different order.
1245-
DOC105: Function `call_with_retry`: Argument names match, but type hints in these args do not match: f, description, match, max_attempts, max_backoff_s
1246-
DOC201: Function `call_with_retry` does not have a return section in docstring
12471244
DOC104: Function `iterate_with_retry`: Arguments are the same in the docstring and the function signature, but are in a different order.
12481245
DOC105: Function `iterate_with_retry`: Argument names match, but type hints in these args do not match: iterable_factory, description, match, max_attempts, max_backoff_s
12491246
DOC001: Method `__init__` Potential formatting errors in docstring. Error message: No specification for "Args": ""

python/ray/_common/retry.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import functools
2+
import logging
3+
import random
4+
import time
5+
from typing import Any, Callable, List, Optional
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def call_with_retry(
    f: Callable,
    description: str,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
    *args,
    **kwargs,
) -> Any:
    """Invoke ``f(*args, **kwargs)``, retrying on failure with exponential backoff.

    Args:
        f: The function to retry.
        description: An imperative description of the function being retried. For
            example, "open the file".
        match: A list of strings to match in the exception message. If ``None``, any
            error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.
        *args: Arguments to pass to the function.
        **kwargs: Keyword arguments to pass to the function.

    Returns:
        The result of the function.
    """
    # TODO: consider inverse match and matching exception type
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    attempt = 0
    while True:
        try:
            return f(*args, **kwargs)
        except Exception as e:
            attempt += 1
            message = str(e)
            # An error is retryable when no match list was given, or when any
            # configured token appears in the exception message.
            retryable = match is None or any(token in message for token in match)
            if retryable and attempt < max_attempts:
                # Binary exponential backoff (2^(attempt-1) seconds, capped at
                # max_backoff_s) with 20% random jitter.
                backoff = min(2 ** (attempt - 1), max_backoff_s) * random.uniform(
                    0.8, 1.2
                )
                logger.debug(
                    f"Retrying {attempt} attempts to {description} after {backoff} seconds."
                )
                time.sleep(backoff)
                continue
            # Out of attempts, or the error is not retryable: log why, then
            # re-raise without chaining (`from None`) to keep the traceback clean.
            if retryable:
                logger.debug(
                    f"Failed to {description} after {max_attempts} attempts. Raising."
                )
            else:
                logger.debug(
                    f"Did not find a match for {message}. Raising after {attempt} attempts."
                )
            raise e from None
63+
64+
65+
def retry(
    description: str,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Callable:
    """Decorator form of :func:`call_with_retry`.

    Args:
        description: An imperative description of the wrapped function, used
            in retry log messages.
        match: A list of strings to match in the exception message. If
            ``None``, any error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.

    Returns:
        A decorator that wraps its target in ``call_with_retry``.
    """

    def decorator(func: Callable) -> Callable:
        # Forward *args positionally so bound methods work: Python passes
        # `self` as the first positional argument.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return call_with_retry(
                func,
                description,
                match,
                max_attempts,
                max_backoff_s,
                *args,
                **kwargs,
            )

        return wrapper

    return decorator

python/ray/_common/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ py_test_module_list(
1919
"test_formatters.py",
2020
"test_network_utils.py",
2121
"test_ray_option_utils.py",
22+
"test_retry.py",
2223
"test_signal_semaphore_utils.py",
2324
"test_signature.py",
2425
"test_utils.py",
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import sys
2+
3+
import pytest
4+
5+
from ray._common.retry import (
6+
call_with_retry,
7+
retry,
8+
)
9+
10+
11+
def test_call_with_retry_immediate_success_with_args():
    """Positional args after the retry options are forwarded to the callee."""

    def func(a, b):
        return [a, b]

    result = call_with_retry(func, "func", [], 1, 0, "a", "b")
    assert result == ["a", "b"]
16+
17+
18+
def test_retry_immediate_success_with_object_args():
    """@retry works on instance methods: `self` flows through *args."""

    class MyClass:
        @retry("func", [], 1, 0)
        def func(self, a, b):
            return [a, b]

    instance = MyClass()
    assert instance.func("a", "b") == ["a", "b"]
25+
26+
27+
@pytest.mark.parametrize("use_decorator", [True, False])
def test_retry_last_attempt_successful_with_appropriate_wait_time(
    monkeypatch, use_decorator
):
    """Backoff doubles per attempt, capped at max_backoff_s, and the result of
    the final successful attempt is returned."""
    sleeps = []
    monkeypatch.setattr("time.sleep", sleeps.append)
    # Pin the jitter multiplier to 1 so the expected backoff is deterministic.
    monkeypatch.setattr("random.uniform", lambda a, b: 1)

    pattern = "have not reached 4th attempt"
    calls = []

    def func():
        calls.append(None)
        if len(calls) == 4:
            return "success"
        raise ValueError(pattern)

    args = ["func", [pattern], 4, 3]
    if use_decorator:
        assert retry(*args)(func)() == "success"
    else:
        assert call_with_retry(func, *args) == "success"
    # Backoffs were 1s, 2s, then capped at max_backoff_s=3 -> total 6.
    assert sum(sleeps) == 6
56+
57+
58+
@pytest.mark.parametrize("use_decorator", [True, False])
def test_retry_unretryable_error(use_decorator):
    """Errors matching none of the `match` tokens are raised immediately."""
    calls = []

    def func():
        calls.append(None)
        raise ValueError("unretryable error")

    args = ["func", ["only retryable error"], 10, 0]
    with pytest.raises(ValueError, match="unretryable error"):
        if use_decorator:
            retry(*args)(func)()
        else:
            call_with_retry(func, *args)
    # No retries happened: a single call raised straight through.
    assert len(calls) == 1
74+
75+
76+
@pytest.mark.parametrize("use_decorator", [True, False])
def test_retry_fail_all_attempts_retry_all_errors(use_decorator):
    """With match=None every error is retried; the last one propagates after
    exactly max_attempts calls."""
    calls = []

    def func():
        calls.append(None)
        raise ValueError(str(len(calls)))

    args = ["func", None, 3, 0]
    with pytest.raises(ValueError):
        if use_decorator:
            retry(*args)(func)()
        else:
            call_with_retry(func, *args)
    assert len(calls) == 3
92+
93+
94+
# Allow invoking this test module directly; propagate pytest's exit code.
if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", __file__]))

python/ray/data/_internal/datasource/lance_datasource.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33

44
import numpy as np
55

6-
from ray.data._internal.util import (
7-
_check_import,
8-
call_with_retry,
9-
)
6+
from ray._common.retry import call_with_retry
7+
from ray.data._internal.util import _check_import
108
from ray.data.block import BlockMetadata
119
from ray.data.context import DataContext
1210
from ray.data.datasource.datasource import Datasource, ReadTask

python/ray/data/_internal/datasource/parquet_datasink.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from pathlib import Path
33
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
44

5+
from ray._common.retry import call_with_retry
56
from ray.data._internal.execution.interfaces import TaskContext
67
from ray.data._internal.planner.plan_write_op import WRITE_UUID_KWARG_NAME
78
from ray.data._internal.savemode import SaveMode
8-
from ray.data._internal.util import call_with_retry
99
from ray.data.block import Block, BlockAccessor
1010
from ray.data.datasource.file_based_datasource import _resolve_kwargs
1111
from ray.data.datasource.file_datasink import _FileDatasink

python/ray/data/_internal/util.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from packaging.version import parse as parse_version
3434

3535
import ray
36+
from ray._common.retry import call_with_retry
3637
from ray._private.arrow_utils import get_pyarrow_version
3738
from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext
3839

@@ -1415,46 +1416,6 @@ def open_input_file(self, path: str) -> "pyarrow.NativeFile":
14151416
)
14161417

14171418

1418-
def call_with_retry(
1419-
f: Callable[[], Any],
1420-
description: str,
1421-
*,
1422-
match: Optional[List[str]] = None,
1423-
max_attempts: int = 10,
1424-
max_backoff_s: int = 32,
1425-
) -> Any:
1426-
"""Retry a function with exponential backoff.
1427-
1428-
Args:
1429-
f: The function to retry.
1430-
match: A list of strings to match in the exception message. If ``None``, any
1431-
error is retried.
1432-
description: An imperitive description of the function being retried. For
1433-
example, "open the file".
1434-
max_attempts: The maximum number of attempts to retry.
1435-
max_backoff_s: The maximum number of seconds to backoff.
1436-
"""
1437-
assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."
1438-
1439-
for i in range(max_attempts):
1440-
try:
1441-
return f()
1442-
except Exception as e:
1443-
is_retryable = match is None or any(pattern in str(e) for pattern in match)
1444-
if is_retryable and i + 1 < max_attempts:
1445-
# Retry with binary expoential backoff with random jitter.
1446-
backoff = min((2 ** (i + 1)), max_backoff_s) * (random.random())
1447-
logger.debug(
1448-
f"Retrying {i+1} attempts to {description} after {backoff} seconds."
1449-
)
1450-
time.sleep(backoff)
1451-
else:
1452-
logger.debug(
1453-
f"Did not find a match for {str(e)}. Raising after {i+1} attempts."
1454-
)
1455-
raise e from None
1456-
1457-
14581419
def iterate_with_retry(
14591420
iterable_factory: Callable[[], Iterable],
14601421
description: str,

python/ray/data/datasource/file_datasink.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
44
from urllib.parse import urlparse
55

6+
from ray._common.retry import call_with_retry
67
from ray._private.arrow_utils import add_creatable_buckets_param_if_s3_uri
78
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
89
from ray.data._internal.execution.interfaces import TaskContext
@@ -11,7 +12,6 @@
1112
from ray.data._internal.util import (
1213
RetryingPyFileSystem,
1314
_is_local_scheme,
14-
call_with_retry,
1515
)
1616
from ray.data.block import Block, BlockAccessor
1717
from ray.data.context import DataContext

python/ray/train/v2/_internal/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
# The name of the file that is used to store the checkpoint manager snapshot.
1313
CHECKPOINT_MANAGER_SNAPSHOT_FILENAME = "checkpoint_manager_snapshot.json"
1414

15+
# Substrings of transient AWS (S3) error messages. Checkpoint upload is
# retried when a raised exception's message contains one of these tokens —
# they are passed as `match=` to the @retry decorator on `_upload_checkpoint`.
AWS_RETRYABLE_TOKENS = (
    "AWS Error SLOW_DOWN",
    "AWS Error INTERNAL_FAILURE",
    "AWS Error SERVICE_UNAVAILABLE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error UNKNOWN",
)
1522

1623
# -----------------------------------------------------------------------
1724
# Environment variables used in the controller, workers, and state actor.

python/ray/train/v2/_internal/execution/context.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from typing import TYPE_CHECKING, Any, Dict, List, Optional
99

1010
import ray
11+
from ray._common.retry import retry
1112
from ray.actor import ActorHandle
1213
from ray.data import DataIterator, Dataset
1314
from ray.train._internal import session
1415
from ray.train._internal.session import _TrainingResult
16+
from ray.train.v2._internal.constants import AWS_RETRYABLE_TOKENS
1517
from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
1618
from ray.train.v2._internal.execution.storage import StorageContext, delete_fs_path
1719
from ray.train.v2._internal.util import (
@@ -215,6 +217,8 @@ def _sync_checkpoint_dir_name_across_ranks(
215217
)
216218
)
217219

220+
# TODO: make retry configurable
221+
@retry(description="upload checkpoint", max_attempts=3, match=AWS_RETRYABLE_TOKENS)
218222
def _upload_checkpoint(
219223
self,
220224
checkpoint_dir_name: str,
@@ -334,10 +338,10 @@ def report(
334338
# Upload checkpoint, wait for turn, and report.
335339
if checkpoint_upload_mode == CheckpointUploadMode.SYNC:
336340
training_result = self._upload_checkpoint(
337-
checkpoint_dir_name,
338-
metrics,
339-
checkpoint,
340-
delete_local_checkpoint_after_upload,
341+
checkpoint_dir_name=checkpoint_dir_name,
342+
metrics=metrics,
343+
checkpoint=checkpoint,
344+
delete_local_checkpoint_after_upload=delete_local_checkpoint_after_upload,
341345
)
342346
self._wait_then_report(training_result, report_call_index)
343347

@@ -357,15 +361,18 @@ def _upload_checkpoint_and_report(
357361
) -> None:
358362
try:
359363
training_result = self._upload_checkpoint(
360-
checkpoint_dir_name,
361-
metrics,
362-
checkpoint,
363-
delete_local_checkpoint_after_upload,
364+
checkpoint_dir_name=checkpoint_dir_name,
365+
metrics=metrics,
366+
checkpoint=checkpoint,
367+
delete_local_checkpoint_after_upload=delete_local_checkpoint_after_upload,
364368
)
365369
self._wait_then_report(training_result, report_call_index)
366370
except Exception as e:
371+
# TODO: env var to disable eager raising
367372
logger.exception(
368-
"Async checkpoint upload failed - shutting down workers"
373+
"Checkpoint upload failed in the background thread. Raising eagerly "
374+
"to avoid training in a corrupted state with more potential progress "
375+
"lost due to checkpointing failures."
369376
)
370377
self.execution_context.training_thread_runner.get_exception_queue().put(
371378
construct_user_exception_with_traceback(e)

0 commit comments

Comments
 (0)