Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions python/ray/train/v2/_internal/execution/worker_group/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,25 +106,18 @@ def shutdown(self):


def _shutdown_workers(workers: List[Worker], patience_s: float = 5):
    """Shut down workers, force-killing any that do not exit in time.

    Invokes each worker actor's ``shutdown`` task so shutdown hooks can run,
    waits up to ``patience_s`` seconds for them to finish, then
    unconditionally force-kills every worker actor with ``ray.kill`` so that
    a hung shutdown (e.g. an actor blocked in ``__del__`` during garbage
    collection) cannot leave zombie actor processes behind.

    Args:
        workers: The workers whose actors should be terminated.
        patience_s: Maximum seconds to allow shutdown hooks to run before
            force-killing. Must be non-negative; ``0`` means kill without
            waiting.

    Raises:
        ValueError: If ``patience_s`` is negative.
    """
    if patience_s < 0:
        raise ValueError("Invalid patience_s: must be non-negative")

    # Kick off the (non-blocking) shutdown task on every worker actor.
    done_refs = [w.actor.shutdown.remote() for w in workers]

    logger.debug(f"Shutting down {len(workers)} workers.")
    # Give the shutdown hooks up to patience_s seconds to complete; a
    # timeout here is fine since we force-kill unconditionally below.
    ray.wait(done_refs, num_returns=len(done_refs), timeout=patience_s)

    # Always force-kill afterwards so that actors stuck in teardown are
    # guaranteed to terminate.
    for worker in workers:
        ray.kill(worker.actor)


def _shutdown_sync_actor(sync_actor: SynchronizationActor):
Expand Down
73 changes: 72 additions & 1 deletion python/ray/train/v2/tests/test_worker_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
WorkerGroupContext,
WorkerGroupState,
)
from ray.train.v2.api.config import RunConfig
from ray.train.v2._internal.util import ObjectRefWrapper
from ray.train.v2.api.config import RunConfig, ScalingConfig
from ray.train.v2.tests.util import DummyObjectRefWrapper, create_dummy_run_context
from ray.util.state import list_actors

pytestmark = pytest.mark.usefixtures("mock_runtime_context")

Expand Down Expand Up @@ -161,6 +163,75 @@ def hanging_task(*args, **kwargs):
wg._start()


def test_zombie_actor_termination(ray_start_4_cpus):
    """Verify that RayTrainWorker actor processes are torn down even when
    Python garbage collection hangs during actor shutdown."""
    NUM_WORKERS = 4

    def _pid_exists(pid: int) -> bool:
        # Signal 0 probes for existence without delivering a signal.
        try:
            os.kill(pid, 0)
        except ProcessLookupError:
            return False
        except PermissionError:
            # Process exists but belongs to another user.
            pass
        return True

    class Node:
        def __init__(self, name):
            self.name = name
            self.other = None

        def __del__(self):
            # Block forever to simulate a hang during garbage collection.
            while True:
                time.sleep(1)

    def train_fn():
        # Build a reference cycle so the Nodes are only collected late,
        # exercising the hang in __del__ at actor shutdown time.
        first, second = Node("a"), Node("b")
        first.other = second
        second.other = first

    train_fn_ref = ObjectRefWrapper(train_fn)

    train_run_context = create_dummy_run_context(
        scaling_config=ScalingConfig(num_workers=NUM_WORKERS)
    )
    worker_group_context = _default_worker_group_context(
        train_fn_ref=train_fn_ref,
        num_workers=NUM_WORKERS,
    )

    # Spin up the worker group, which launches the train function.
    worker_group = WorkerGroup.create(
        train_run_context=train_run_context,
        worker_group_context=worker_group_context,
        callbacks=[],
    )

    worker_pids = []
    for actor_state in list_actors():
        if (
            actor_state.class_name == RayTrainWorker.__name__
            and actor_state.state == "ALIVE"
        ):
            worker_pids.append(actor_state.pid)

    assert len(worker_pids) == NUM_WORKERS

    worker_group.shutdown()

    # ray.kill is asynchronous: poll until every worker process is gone
    # or the deadline passes.
    TIMEOUT_S = 5
    deadline = time.monotonic() + TIMEOUT_S
    pending = set(worker_pids)
    while pending and time.monotonic() < deadline:
        pending = {pid for pid in pending if _pid_exists(pid)}
        if pending:
            time.sleep(0.1)

    assert not pending


def test_insufficient_cluster_resources_startup_failure(monkeypatch):
"""Test that WorkerGroup startup fails when cluster has insufficient resources.

Expand Down