Skip to content

Commit 480a4de

Browse files
[train] Cleanup Zombie RayTrainWorker Actors (#59872)
Leaking Ray Train actors have been observed occupying GPU memory following Train run termination, causing training failures/OOMs in subsequent train runs. Despite the train actors being marked DEAD by Ray Core, we find upon ssh-ing into the nodes that the actor processes are still alive and occupying valuable GPU memory. This PR replaces `__ray_terminate__` with `ray.kill` in the Train run shutdown and abort paths to guarantee the termination of train actors. --------- Signed-off-by: JasonLi1909 <jasli1909@gmail.com> Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 39bc94b commit 480a4de

File tree

2 files changed

+82
-18
lines changed

2 files changed

+82
-18
lines changed

python/ray/train/v2/_internal/execution/worker_group/state.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,18 @@ def shutdown(self):
126126

127127

128128
def _shutdown_workers(workers: List[Worker], patience_s: float = 5):
129-
# Run the worker shutdown logic on each of the workers. This should
130-
# be a non-blocking call to realize forceful shutdown after patience_s.
131-
_ = [w.actor.shutdown.remote() for w in workers]
129+
"""Shuts down workers after allowing a maximum of patience_s seconds for shutdown hooks to run."""
130+
if patience_s < 0:
131+
raise ValueError("Invalid patience_s: must be non-negative")
132+
133+
done_refs = [w.actor.shutdown.remote() for w in workers]
132134

133135
logger.debug(f"Shutting down {len(workers)} workers.")
134-
if patience_s <= 0:
135-
for worker in workers:
136-
ray.kill(worker.actor)
137-
else:
138-
done_refs = [w.actor.__ray_terminate__.remote() for w in workers]
139-
# Wait for actors to die gracefully.
140-
_, not_done = ray.wait(
141-
done_refs, num_returns=len(done_refs), timeout=patience_s
142-
)
143-
if not_done:
144-
logger.debug("Graceful termination failed. Falling back to force kill.")
145-
# If all actors are not able to die gracefully, then kill them.
146-
for worker in workers:
147-
ray.kill(worker.actor)
136+
137+
ray.wait(done_refs, num_returns=len(done_refs), timeout=patience_s)
138+
139+
for worker in workers:
140+
ray.kill(worker.actor)
148141

149142

150143
def _shutdown_sync_actor(sync_actor: SynchronizationActor):

python/ray/train/v2/tests/test_worker_group.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
WorkerGroupContext,
3131
WorkerGroupState,
3232
)
33-
from ray.train.v2.api.config import RunConfig
33+
from ray.train.v2._internal.util import ObjectRefWrapper
34+
from ray.train.v2.api.config import RunConfig, ScalingConfig
3435
from ray.train.v2.tests.util import DummyObjectRefWrapper, create_dummy_run_context
36+
from ray.util.state import list_actors
3537

3638
pytestmark = pytest.mark.usefixtures("mock_runtime_context")
3739

@@ -161,6 +163,75 @@ def hanging_task(*args, **kwargs):
161163
wg._start()
162164

163165

166+
def test_zombie_actor_termination(ray_start_4_cpus):
167+
"""This test checks that RayTrainWorker actors are terminated correctly even if python garbage collection hangs on actor shutdown."""
168+
NUM_WORKERS = 4
169+
170+
def is_process_alive(pid: int) -> bool:
171+
try:
172+
os.kill(pid, 0)
173+
except ProcessLookupError:
174+
return False
175+
except PermissionError:
176+
return True
177+
else:
178+
return True
179+
180+
class Node:
181+
def __init__(self, name):
182+
self.name = name
183+
self.other = None
184+
185+
def __del__(self):
186+
# Simulate hang during garbage collection
187+
while True:
188+
time.sleep(1)
189+
190+
def train_fn():
191+
# Create a circular reference to delay garbage collection
192+
a, b = Node("a"), Node("b")
193+
a.other = b
194+
b.other = a
195+
196+
train_fn_ref = ObjectRefWrapper(train_fn)
197+
198+
train_run_context = create_dummy_run_context(
199+
scaling_config=ScalingConfig(num_workers=NUM_WORKERS)
200+
)
201+
worker_group_context = _default_worker_group_context(
202+
train_fn_ref=train_fn_ref,
203+
num_workers=NUM_WORKERS,
204+
)
205+
206+
# Starts the worker group and runs the train function
207+
worker_group = WorkerGroup.create(
208+
train_run_context=train_run_context,
209+
worker_group_context=worker_group_context,
210+
callbacks=[],
211+
)
212+
213+
train_worker_pids = [
214+
actor.pid
215+
for actor in list_actors()
216+
if actor.class_name == RayTrainWorker.__name__ and actor.state == "ALIVE"
217+
]
218+
219+
assert len(train_worker_pids) == NUM_WORKERS
220+
221+
worker_group.shutdown()
222+
223+
# ray.kill is async, allow some time for the processes to terminate
224+
TIMEOUT_S = 5
225+
deadline = time.monotonic() + TIMEOUT_S
226+
remaining = set(train_worker_pids)
227+
while remaining and time.monotonic() < deadline:
228+
remaining = {pid for pid in remaining if is_process_alive(pid)}
229+
if remaining:
230+
time.sleep(0.1)
231+
232+
assert not remaining
233+
234+
164235
def test_insufficient_cluster_resources_startup_failure(monkeypatch):
165236
"""Test that WorkerGroup startup fails when cluster has insufficient resources.
166237

0 commit comments

Comments
 (0)