From b7dddceeee8e577dee3ec8c6eee5f6ceb326b589 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 26 Mar 2026 17:54:44 +0000 Subject: [PATCH] fix: replace anyio task group with asyncio tasks to fix cross-task cancel scope error (#454) Replace manual anyio TaskGroup.__aenter__/__aexit__ calls with asyncio.create_task() for background task management in Query. The anyio TaskGroup pattern required cancel scopes to be entered and exited in the same async task. When users break from the async generator returned by query(), Python may finalize the generator in a different task, causing close() to call __aexit__ from a different task than start() called __aenter__. This produced a RuntimeError: 'Attempted to exit cancel scope in a different task than it was entered in' The fix uses asyncio.create_task() which has no cancel scope affinity, allowing close() to cancel the read task from any task context. A new spawn_task() method replaces _tg.start_soon() for child tasks. :house: Remote-Dev: homespace --- src/claude_agent_sdk/_internal/client.py | 4 +- src/claude_agent_sdk/_internal/query.py | 34 +++++++++------ src/claude_agent_sdk/client.py | 4 +- tests/test_query.py | 53 ++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 07b5e9b3..0a7107a8 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -139,9 +139,9 @@ async def process_query( } await chosen_transport.write(json.dumps(user_message) + "\n") await query.wait_for_result_and_end_input() - elif isinstance(prompt, AsyncIterable) and query._tg: + elif isinstance(prompt, AsyncIterable): # Stream input in background for async iterables - query._tg.start_soon(query.stream_input, prompt) + query.spawn_task(query.stream_input(prompt)) # Yield parsed messages, skipping unknown message types async for data in query.receive_messages(): diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 4a70d79c..5c01e36c 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -1,5 +1,6 @@ """Query class for handling bidirectional control protocol.""" +import asyncio import json import logging import os @@ -106,7 +107,8 @@ def __init__( self._message_send, self._message_receive = anyio.create_memory_object_stream[ dict[str, Any] ](max_buffer_size=100) - self._tg: anyio.abc.TaskGroup | None = None + self._read_task: asyncio.Task[None] | None = None + self._child_tasks: set[asyncio.Task[Any]] = set() self._initialized = False self._closed = False self._initialization_result: dict[str, Any] | None = None @@ -162,10 +164,16 @@ async def initialize(self) -> dict[str, Any] | None: async def start(self) -> None: """Start reading messages from transport.""" - if self._tg is None: - self._tg = anyio.create_task_group() - await self._tg.__aenter__() - self._tg.start_soon(self._read_messages) + if self._read_task is None: + loop = asyncio.get_running_loop() + self._read_task = loop.create_task(self._read_messages()) + + def spawn_task(self, coro: Any) -> None: + """Spawn a child task that will be cancelled on close().""" + loop = asyncio.get_running_loop() + task = loop.create_task(coro) + self._child_tasks.add(task) + task.add_done_callback(self._child_tasks.discard) async def _read_messages(self) -> None: """Read messages from transport and route them.""" @@ -195,8 +203,8 @@ async def _read_messages(self) -> None: # Handle incoming control requests from CLI # Cast message to SDKControlRequest for type safety request: SDKControlRequest = message # type: ignore[assignment] - if self._tg: - self._tg.start_soon(self._handle_control_request, request) + if not self._closed: + self.spawn_task(self._handle_control_request(request)) continue elif msg_type == "control_cancel_request": @@ -694,11 +702,13 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: async def close(self) -> None: """Close the query and transport.""" self._closed = True - if self._tg: - self._tg.cancel_scope.cancel() - # Wait for task group to complete cancellation - with suppress(anyio.get_cancelled_exc_class()): - await self._tg.__aexit__(None, None, None) + for task in list(self._child_tasks): + task.cancel() + if self._read_task is not None and not self._read_task.done(): + self._read_task.cancel() + with suppress(asyncio.CancelledError): + await self._read_task + self._read_task = None await self.transport.close() # Make Query an async iterator diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index 3e66b4cc..1b2af70c 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -180,8 +180,8 @@ async def _empty_stream() -> AsyncIterator[dict[str, Any]]: await self._query.initialize() # If we have an initial prompt stream, start streaming it - if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg: - self._query._tg.start_soon(self._query.stream_input, prompt) + if prompt is not None and isinstance(prompt, AsyncIterable): + self._query.spawn_task(self._query.stream_input(prompt)) async def receive_messages(self) -> AsyncIterator[Message]: """Receive all messages from Claude.""" diff --git a/tests/test_query.py b/tests/test_query.py index f6b8a590..7df6cea0 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -492,3 +492,56 @@ async def _test(): mock_transport.end_input.assert_called_once() anyio.run(_test) + + +class TestQueryCrossTaskCleanup: + """Tests for cross-task cleanup of Query task groups (issue #454). + + When a user breaks out of an async for loop over process_query(), Python + finalizes the async generator in a different task than the one that called + start(). This triggers close() from a different task context, which causes + anyio to raise RuntimeError because cancel scopes must be exited by the + same task that entered them. These tests verify that close() handles this + gracefully. + """ + + def test_close_from_different_task_does_not_raise(self): + """close() called from a different task than start() must not raise.""" + import asyncio + + async def _test(): + mock_transport = _make_mock_transport(messages=[]) + q = Query(transport=mock_transport, is_streaming_mode=True) + + await q.start() + + close_error = None + + async def close_in_other_task(): + nonlocal close_error + try: + await q.close() + except Exception as e: + close_error = e + + task = asyncio.create_task(close_in_other_task()) + await task + + assert close_error is None, f"close() raised: {close_error}" + + asyncio.run(_test()) + + def test_close_from_same_task_still_works(self): + """close() from the same task as start() should still work normally.""" + + async def _test(): + mock_transport = _make_mock_transport(messages=[]) + q = Query(transport=mock_transport, is_streaming_mode=True) + + await q.start() + await q.close() + + assert q._read_task is None + mock_transport.close.assert_called_once() + + anyio.run(_test)