diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 07b5e9b3..0a7107a8 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -139,9 +139,9 @@ async def process_query( } await chosen_transport.write(json.dumps(user_message) + "\n") await query.wait_for_result_and_end_input() - elif isinstance(prompt, AsyncIterable) and query._tg: + elif isinstance(prompt, AsyncIterable): # Stream input in background for async iterables - query._tg.start_soon(query.stream_input, prompt) + query.spawn_task(query.stream_input(prompt)) # Yield parsed messages, skipping unknown message types async for data in query.receive_messages(): diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 4a70d79c..5c01e36c 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -1,5 +1,6 @@ """Query class for handling bidirectional control protocol.""" +import asyncio import json import logging import os @@ -106,7 +107,8 @@ def __init__( self._message_send, self._message_receive = anyio.create_memory_object_stream[ dict[str, Any] ](max_buffer_size=100) - self._tg: anyio.abc.TaskGroup | None = None + self._read_task: asyncio.Task[None] | None = None + self._child_tasks: set[asyncio.Task[Any]] = set() self._initialized = False self._closed = False self._initialization_result: dict[str, Any] | None = None @@ -162,10 +164,16 @@ async def initialize(self) -> dict[str, Any] | None: async def start(self) -> None: """Start reading messages from transport.""" - if self._tg is None: - self._tg = anyio.create_task_group() - await self._tg.__aenter__() - self._tg.start_soon(self._read_messages) + if self._read_task is None: + loop = asyncio.get_running_loop() + self._read_task = loop.create_task(self._read_messages()) + + def spawn_task(self, coro: Any) -> None: + """Spawn a child task that will be cancelled on close().""" + loop = asyncio.get_running_loop() + task = loop.create_task(coro) + self._child_tasks.add(task) + task.add_done_callback(self._child_tasks.discard) async def _read_messages(self) -> None: """Read messages from transport and route them.""" @@ -195,8 +203,8 @@ async def _read_messages(self) -> None: # Handle incoming control requests from CLI # Cast message to SDKControlRequest for type safety request: SDKControlRequest = message # type: ignore[assignment] - if self._tg: - self._tg.start_soon(self._handle_control_request, request) + if not self._closed: + self.spawn_task(self._handle_control_request(request)) continue elif msg_type == "control_cancel_request": @@ -694,11 +702,13 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: async def close(self) -> None: """Close the query and transport.""" self._closed = True - if self._tg: - self._tg.cancel_scope.cancel() - # Wait for task group to complete cancellation - with suppress(anyio.get_cancelled_exc_class()): - await self._tg.__aexit__(None, None, None) + for task in list(self._child_tasks): + task.cancel() + if self._read_task is not None and not self._read_task.done(): + self._read_task.cancel() + with suppress(asyncio.CancelledError): + await self._read_task + self._read_task = None await self.transport.close() # Make Query an async iterator diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index 3e66b4cc..1b2af70c 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -180,8 +180,8 @@ async def _empty_stream() -> AsyncIterator[dict[str, Any]]: await self._query.initialize() # If we have an initial prompt stream, start streaming it - if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg: - self._query._tg.start_soon(self._query.stream_input, prompt) + if prompt is not None and isinstance(prompt, AsyncIterable): + self._query.spawn_task(self._query.stream_input(prompt)) async def receive_messages(self) -> AsyncIterator[Message]: """Receive all messages from Claude.""" diff --git a/tests/test_query.py b/tests/test_query.py index f6b8a590..7df6cea0 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -492,3 +492,56 @@ async def _test(): mock_transport.end_input.assert_called_once() anyio.run(_test) + + +class TestQueryCrossTaskCleanup: + """Tests for cross-task cleanup of Query task groups (issue #454). + + When a user breaks out of an async for loop over process_query(), Python + finalizes the async generator in a different task than the one that called + start(). This triggers close() from a different task context, which causes + anyio to raise RuntimeError because cancel scopes must be exited by the + same task that entered them. These tests verify that close() handles this + gracefully. + """ + + def test_close_from_different_task_does_not_raise(self): + """close() called from a different task than start() must not raise.""" + import asyncio + + async def _test(): + mock_transport = _make_mock_transport(messages=[]) + q = Query(transport=mock_transport, is_streaming_mode=True) + + await q.start() + + close_error = None + + async def close_in_other_task(): + nonlocal close_error + try: + await q.close() + except Exception as e: + close_error = e + + task = asyncio.create_task(close_in_other_task()) + await task + + assert close_error is None, f"close() raised: {close_error}" + + asyncio.run(_test()) + + def test_close_from_same_task_still_works(self): + """close() from the same task as start() should still work normally.""" + + async def _test(): + mock_transport = _make_mock_transport(messages=[]) + q = Query(transport=mock_transport, is_streaming_mode=True) + + await q.start() + await q.close() + + assert q._read_task is None + mock_transport.close.assert_called_once() + + anyio.run(_test)