Skip to content

Commit 32b05cc

Browse files
fix(telemetry): drop use_span from _TracedStream.__anext__; eager-prime SSE under captured ctx (#277)
* fix(telemetry): drop use_span from _TracedStream.__anext__; eager-prime SSE under captured ctx `_TracedStream.__anext__` wrapped `await self._inner.__anext__()` in `with use_span(self._span, end_on_exit=False):` to make the GenAI span the parent of HTTPX child spans. Under FastAPI StreamingResponse the attach/detach pair fired across asyncio tasks, raising `ValueError: Token was created in a different Context` ~19x per chat session. Closes #276. Capture an OTel context snapshot with the GenAI span active in `_stream` (sync `with use_span`, no await between attach and detach). Schedule the SSE iterator's first __anext__ as `asyncio.create_task(coro, context=ctx)` so the HTTP request fires under the snapshot; HTTPX auto-instrumentation captures the GenAI span as parent at request emission. Subsequent reads run in the consumer's own context and pull bytes from the open response. Drop the `with use_span` block from `_TracedStream.__anext__`. * fix(telemetry): cancel prime task when consumer is cancelled before first chunk Without this, a cancellation interrupting `await task` in `_prime_with_context` left the underlying create_task running in the background, holding the open SSE/HTTP connection. Add a `BaseException` catch that calls `task.cancel()` then re-raises. Add regression test covering cancel-mid-first-pull. * refactor(telemetry): move bind_first_pull_to_span helper from client.py to telemetry.py The first pass parked an async-task + contextvars helper at the top of client.py and added asyncio + contextvars imports there. That leaked telemetry mechanism into the module that defines the SDK base layering. Move the helper to telemetry.py (where it joins trace_stream / record_* / use_span and where contextvars belongs); rename to bind_first_pull_to_span; revert client.py imports and drop the inline ctx-capture block from `_stream`. `_stream` regains its single-verb orchestration rhythm with one new line: sse_iterator = telemetry.bind_first_pull_to_span(sse_iterator, span) Same behavior, same tests (signatures updated to pass a span instead of a Context; new test asserts the first pull sees the span as current OTel context and subsequent pulls don't).
1 parent a1bd78b commit 32b05cc

3 files changed

Lines changed: 146 additions & 13 deletions

File tree

src/celeste/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def _stream(
313313
**parameters,
314314
)
315315
sse_iterator = enrich_stream_errors(sse_iterator, self._handle_error_response)
316+
sse_iterator = telemetry.bind_first_pull_to_span(sse_iterator, span)
316317
stream = stream_class(
317318
sse_iterator,
318319
transform_output=self._transform_output,

src/celeste/telemetry.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""OpenTelemetry GenAI telemetry for Celeste."""
22

33
import asyncio
4+
import contextvars
45
import json
56
import os
67
import time
7-
from collections.abc import Iterator
8+
from collections.abc import AsyncIterator, Iterator
89
from contextlib import contextmanager, suppress
910
from types import TracebackType
1011
from typing import Any
@@ -435,18 +436,17 @@ def __aiter__(self) -> "_TracedStream":
435436

436437
async def __anext__(self) -> Any:
437438
"""Yield next chunk; emit TTFC on first, finalize span on terminal events."""
438-
with use_span(self._span, end_on_exit=False):
439-
try:
440-
chunk = await self._inner.__anext__()
441-
except (StopAsyncIteration, asyncio.CancelledError):
442-
self._finalize()
443-
raise
444-
except Exception as exc:
445-
self._error = exc
446-
self._span.record_exception(exc)
447-
self._span.set_status(Status(StatusCode.ERROR, str(exc)))
448-
self._finalize()
449-
raise
439+
try:
440+
chunk = await self._inner.__anext__()
441+
except (StopAsyncIteration, asyncio.CancelledError):
442+
self._finalize()
443+
raise
444+
except Exception as exc:
445+
self._error = exc
446+
self._span.record_exception(exc)
447+
self._span.set_status(Status(StatusCode.ERROR, str(exc)))
448+
self._finalize()
449+
raise
450450
if not self._seen_first:
451451
self._seen_first = True
452452
self._span.set_attribute(
@@ -541,10 +541,35 @@ def trace_stream(
541541
return _TracedStream(stream, span, metric_attributes)
542542

543543

544+
async def bind_first_pull_to_span(
545+
inner: AsyncIterator[dict[str, Any]],
546+
span: Any,
547+
) -> AsyncIterator[dict[str, Any]]:
548+
"""Run inner's first pull under a context where span is active; delegate the rest."""
549+
with use_span(span):
550+
ctx = contextvars.copy_context()
551+
552+
async def _first() -> dict[str, Any]:
553+
return await inner.__anext__()
554+
555+
task = asyncio.create_task(_first(), context=ctx)
556+
try:
557+
first = await task
558+
except StopAsyncIteration:
559+
return
560+
except BaseException:
561+
task.cancel()
562+
raise
563+
yield first
564+
async for event in inner:
565+
yield event
566+
567+
544568
__all__ = [
545569
"Status",
546570
"StatusCode",
547571
"add_input_event",
572+
"bind_first_pull_to_span",
548573
"gen_ai_span",
549574
"meter",
550575
"output_attributes",

tests/unit_tests/test_telemetry_streaming.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tests for `_TracedStream` — span lifecycle and GenAI attribute emission."""
22

3+
import asyncio
34
from collections.abc import AsyncIterator
45
from typing import Any
6+
from unittest.mock import patch
57

68
import pytest
9+
from opentelemetry import trace as otel_trace
710
from opentelemetry.sdk.trace import TracerProvider
811
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
912
InMemorySpanExporter,
@@ -235,3 +238,107 @@ def test_response_model_emitted_from_metadata(self) -> None:
235238
attrs = telemetry.output_attributes(output)
236239

237240
assert attrs["gen_ai.response.model"] == "claude-opus-4-1-20250805"
241+
242+
243+
class TestNoSpanActivationInAnext:
244+
"""Regression: `_TracedStream.__anext__` must not activate the span across `await`."""
245+
246+
async def test_anext_does_not_call_use_span(
247+
self, exporter: tuple[InMemorySpanExporter, TracerProvider]
248+
) -> None:
249+
"""Iterating a `_TracedStream` does NOT invoke `telemetry.use_span`."""
250+
_, provider = exporter
251+
events = [{"delta": "a"}, {"delta": "b"}]
252+
wrapped = telemetry.trace_stream(
253+
TelemetryStream(async_iter(events)), start_test_span(provider)
254+
)
255+
256+
with patch("celeste.telemetry.use_span") as mock_use_span:
257+
async for _ in wrapped:
258+
pass
259+
260+
mock_use_span.assert_not_called()
261+
262+
263+
class TestBindFirstPullToSpan:
264+
"""`bind_first_pull_to_span` runs the first pull with span active; delegates the rest."""
265+
266+
async def test_preserves_event_order(
267+
self, exporter: tuple[InMemorySpanExporter, TracerProvider]
268+
) -> None:
269+
"""All events from inner are yielded in original order."""
270+
_, provider = exporter
271+
events = [{"i": 0}, {"i": 1}, {"i": 2}]
272+
bound = telemetry.bind_first_pull_to_span(
273+
async_iter(events), start_test_span(provider)
274+
)
275+
276+
collected = [event async for event in bound]
277+
278+
assert collected == events
279+
280+
async def test_empty_stream_yields_nothing(
281+
self, exporter: tuple[InMemorySpanExporter, TracerProvider]
282+
) -> None:
283+
"""Inner that immediately raises StopAsyncIteration yields no events."""
284+
_, provider = exporter
285+
bound = telemetry.bind_first_pull_to_span(
286+
async_iter([]), start_test_span(provider)
287+
)
288+
289+
collected = [event async for event in bound]
290+
291+
assert collected == []
292+
293+
async def test_first_error_propagates(
294+
self, exporter: tuple[InMemorySpanExporter, TracerProvider]
295+
) -> None:
296+
"""Exception raised by inner's first pull propagates to the consumer."""
297+
_, provider = exporter
298+
bound = telemetry.bind_first_pull_to_span(
299+
_failing_iter(), start_test_span(provider)
300+
)
301+
302+
with pytest.raises(RuntimeError, match="boom"):
303+
async for _ in bound:
304+
pass
305+
306+
async def test_first_pull_runs_with_span_active(
307+
self, exporter: tuple[InMemorySpanExporter, TracerProvider]
308+
) -> None:
309+
"""First inner pull sees span as current OTel context; subsequent pulls don't."""
310+
_, provider = exporter
311+
span = start_test_span(provider)
312+
target_id = span.get_span_context().span_id
313+
314+
async def inner() -> AsyncIterator[dict[str, int]]:
315+
yield {"span_id": otel_trace.get_current_span().get_span_context().span_id}
316+
yield {"span_id": otel_trace.get_current_span().get_span_context().span_id}
317+
318+
bound = telemetry.bind_first_pull_to_span(inner(), span)
319+
collected = [event async for event in bound]
320+
321+
assert collected[0]["span_id"] == target_id
322+
assert collected[1]["span_id"] != target_id
323+
324+
async def test_cancel_during_first_pull_cancels_inner_task(
325+
self, exporter: tuple[InMemorySpanExporter, TracerProvider]
326+
) -> None:
327+
"""Cancellation while awaiting the first pull propagates to the inner task — no leak."""
328+
_, provider = exporter
329+
started = asyncio.Event()
330+
331+
async def slow_inner() -> AsyncIterator[dict[str, str]]:
332+
started.set()
333+
await asyncio.sleep(60)
334+
yield {"v": "never"}
335+
336+
bound = telemetry.bind_first_pull_to_span(
337+
slow_inner(), start_test_span(provider)
338+
)
339+
consumer: asyncio.Task[dict[str, str]] = asyncio.create_task(bound.__anext__())
340+
await started.wait()
341+
consumer.cancel()
342+
with pytest.raises(asyncio.CancelledError):
343+
await consumer
344+
await asyncio.sleep(0)

0 commit comments

Comments
 (0)