Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/claude_agent_sdk/_internal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ async def process_query(
}
await chosen_transport.write(json.dumps(user_message) + "\n")
await query.wait_for_result_and_end_input()
elif isinstance(prompt, AsyncIterable) and query._tg:
elif isinstance(prompt, AsyncIterable):
# Stream input in background for async iterables
query._tg.start_soon(query.stream_input, prompt)
query.spawn_task(query.stream_input(prompt))

# Yield parsed messages, skipping unknown message types
async for data in query.receive_messages():
Expand Down
34 changes: 22 additions & 12 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Query class for handling bidirectional control protocol."""

import asyncio
import json
import logging
import os
Expand Down Expand Up @@ -106,7 +107,8 @@ def __init__(
self._message_send, self._message_receive = anyio.create_memory_object_stream[
dict[str, Any]
](max_buffer_size=100)
self._tg: anyio.abc.TaskGroup | None = None
self._read_task: asyncio.Task[None] | None = None
self._child_tasks: set[asyncio.Task[Any]] = set()
self._initialized = False
self._closed = False
self._initialization_result: dict[str, Any] | None = None
Expand Down Expand Up @@ -162,10 +164,16 @@ async def initialize(self) -> dict[str, Any] | None:

async def start(self) -> None:
"""Start reading messages from transport."""
if self._tg is None:
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
if self._read_task is None:
loop = asyncio.get_running_loop()
self._read_task = loop.create_task(self._read_messages())

def spawn_task(self, coro: Any) -> None:
"""Spawn a child task that will be cancelled on close()."""
loop = asyncio.get_running_loop()
task = loop.create_task(coro)
self._child_tasks.add(task)
task.add_done_callback(self._child_tasks.discard)

async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
Expand Down Expand Up @@ -195,8 +203,8 @@ async def _read_messages(self) -> None:
# Handle incoming control requests from CLI
# Cast message to SDKControlRequest for type safety
request: SDKControlRequest = message # type: ignore[assignment]
if self._tg:
self._tg.start_soon(self._handle_control_request, request)
if not self._closed:
self.spawn_task(self._handle_control_request(request))
continue

elif msg_type == "control_cancel_request":
Expand Down Expand Up @@ -694,11 +702,13 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._tg:
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
for task in list(self._child_tasks):
task.cancel()
if self._read_task is not None and not self._read_task.done():
self._read_task.cancel()
with suppress(asyncio.CancelledError):
await self._read_task
self._read_task = None
await self.transport.close()

# Make Query an async iterator
Expand Down
4 changes: 2 additions & 2 deletions src/claude_agent_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ async def _empty_stream() -> AsyncIterator[dict[str, Any]]:
await self._query.initialize()

# If we have an initial prompt stream, start streaming it
if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg:
self._query._tg.start_soon(self._query.stream_input, prompt)
if prompt is not None and isinstance(prompt, AsyncIterable):
self._query.spawn_task(self._query.stream_input(prompt))

async def receive_messages(self) -> AsyncIterator[Message]:
"""Receive all messages from Claude."""
Expand Down
53 changes: 53 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,56 @@ async def _test():
mock_transport.end_input.assert_called_once()

anyio.run(_test)


class TestQueryCrossTaskCleanup:
"""Tests for cross-task cleanup of Query task groups (issue #454).

When a user breaks out of an async for loop over process_query(), Python
finalizes the async generator in a different task than the one that called
start(). This triggers close() from a different task context, which causes
anyio to raise RuntimeError because cancel scopes must be exited by the
same task that entered them. These tests verify that close() handles this
gracefully.
"""

def test_close_from_different_task_does_not_raise(self):
"""close() called from a different task than start() must not raise."""
import asyncio

async def _test():
mock_transport = _make_mock_transport(messages=[])
q = Query(transport=mock_transport, is_streaming_mode=True)

await q.start()

close_error = None

async def close_in_other_task():
nonlocal close_error
try:
await q.close()
except Exception as e:
close_error = e

task = asyncio.create_task(close_in_other_task())
await task

assert close_error is None, f"close() raised: {close_error}"

asyncio.run(_test())

def test_close_from_same_task_still_works(self):
"""close() from the same task as start() should still work normally."""

async def _test():
mock_transport = _make_mock_transport(messages=[])
q = Query(transport=mock_transport, is_streaming_mode=True)

await q.start()
await q.close()

assert q._read_task is None
mock_transport.close.assert_called_once()

anyio.run(_test)
Loading