Skip to content

Commit b4cf340

Browse files
refactor(telemetry): extract gen_ai_span context manager — drop try/except + import time from client
User pushed back on `_predict` wrapping its entire body in try/except just to record `gen_ai.client.operation.duration` on the error path. The right home for that is a context manager inside `telemetry.py`. `gen_ai_span(model=, provider=, protocol=, modality=)` opens the span via `tracer.start_as_current_span`, captures the start time on enter, and in `finally` records the operation duration with `error.type` populated when the body raised. It yields `(span, request_attrs)` so `_predict` can call `add_input_event` and `record_output` against them.

Result:
- `_predict` is straight-line code: no try/except, no manual time math.
- `client.py` no longer imports `time`.
- `record_output` no longer takes `duration_seconds` / `error` params — duration is the span's lifecycle concern, not the output recorder's.
- Streaming side (`_TracedStream._finalize`) updates symmetrically: records duration directly, then calls the simplified `record_output`.

Real-call validated against gemini-3.1-flash-lite-preview text non-streaming + async streaming: same span attributes, same metric histograms, same content events as before.
1 parent fcc48af commit b4cf340

2 files changed

Lines changed: 76 additions & 58 deletions

File tree

src/celeste/client.py

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Base client for modality-specific AI operations."""
22

3-
import time
43
import warnings
54
from abc import ABC, abstractmethod
65
from collections.abc import AsyncIterator
@@ -214,55 +213,42 @@ async def _predict(
214213
Returns:
215214
Output of the parameterized type.
216215
"""
217-
request_attrs = telemetry.request_attributes(
216+
with telemetry.gen_ai_span(
218217
model=self.model,
219218
provider=self.provider,
220219
protocol=self.protocol,
221220
modality=self.modality,
222-
)
223-
started = time.monotonic()
224-
with telemetry.tracer.start_as_current_span(
225-
telemetry.span_name(self.modality, self.model),
226-
attributes=request_attrs,
227-
) as span:
228-
try:
229-
inputs, parameters = self._validate_artifacts(inputs, **parameters)
230-
telemetry.add_input_event(span, inputs)
231-
request_body = self._build_request(
232-
inputs, extra_body=extra_body, **parameters
233-
)
234-
response_data = await self._make_request(
235-
request_body,
236-
endpoint=endpoint,
237-
extra_headers=extra_headers,
238-
**parameters,
239-
)
240-
content = self._parse_content(response_data)
241-
content = self._transform_output(content, **parameters)
242-
tool_calls = self._parse_tool_calls(response_data)
243-
reasoning, signature = self._parse_reasoning(response_data)
244-
kwargs: dict[str, Any] = {}
245-
if reasoning is not None:
246-
kwargs["reasoning"] = reasoning
247-
if signature:
248-
kwargs["signature"] = signature
249-
output = self._output_class()(
250-
content=content,
251-
usage=self._get_usage(response_data),
252-
finish_reason=self._get_finish_reason(response_data),
253-
metadata=self._build_metadata(response_data),
254-
tool_calls=tool_calls,
255-
**kwargs,
256-
)
257-
telemetry.record_output(
258-
span, output, request_attrs, time.monotonic() - started
259-
)
260-
return output
261-
except BaseException as exc:
262-
telemetry.record_operation_duration(
263-
time.monotonic() - started, request_attrs, error=exc
264-
)
265-
raise
221+
) as (span, request_attrs):
222+
inputs, parameters = self._validate_artifacts(inputs, **parameters)
223+
telemetry.add_input_event(span, inputs)
224+
request_body = self._build_request(
225+
inputs, extra_body=extra_body, **parameters
226+
)
227+
response_data = await self._make_request(
228+
request_body,
229+
endpoint=endpoint,
230+
extra_headers=extra_headers,
231+
**parameters,
232+
)
233+
content = self._parse_content(response_data)
234+
content = self._transform_output(content, **parameters)
235+
tool_calls = self._parse_tool_calls(response_data)
236+
reasoning, signature = self._parse_reasoning(response_data)
237+
kwargs: dict[str, Any] = {}
238+
if reasoning is not None:
239+
kwargs["reasoning"] = reasoning
240+
if signature:
241+
kwargs["signature"] = signature
242+
output = self._output_class()(
243+
content=content,
244+
usage=self._get_usage(response_data),
245+
finish_reason=self._get_finish_reason(response_data),
246+
metadata=self._build_metadata(response_data),
247+
tool_calls=tool_calls,
248+
**kwargs,
249+
)
250+
telemetry.record_output(span, output, request_attrs)
251+
return output
266252

267253
def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
268254
"""Parse tool calls from response. Override in providers that support tools."""

src/celeste/telemetry.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,16 +365,49 @@ def record_output(
365365
span: Any,
366366
output: Output[Any],
367367
metric_attributes: dict[str, Any],
368-
duration_seconds: float,
369-
error: BaseException | None = None,
370368
) -> None:
371-
"""Emit span attrs, content event, and metrics for a successful Output."""
369+
"""Emit span attrs, content event, and token usage for a successful Output."""
372370
span.set_attributes(output_attributes(output))
373371
output_event = _output_messages_event(output)
374372
if output_event is not None:
375373
span.add_event("gen_ai.output.messages", attributes=output_event)
376374
record_token_usage(output.usage, metric_attributes)
377-
record_operation_duration(duration_seconds, metric_attributes, error=error)
375+
376+
377+
@contextmanager
def gen_ai_span(
    *,
    model: Model,
    provider: Provider | None,
    protocol: Protocol | None,
    modality: Modality,
    extra_attributes: dict[str, Any] | None = None,
) -> Iterator[tuple[Any, dict[str, Any]]]:
    """Open a GenAI span and record operation duration on exit.

    Yields ``(span, request_attrs)``. On any exception, the duration is recorded
    with ``error.type`` set; on success it's recorded plain.
    """
    request_attrs = request_attributes(
        model=model, provider=provider, protocol=protocol, modality=modality
    )
    # extra_attributes only widen what goes on the *span*; the duration metric
    # below is always recorded against the plain request_attrs.
    span_attrs = (
        {**request_attrs, **extra_attributes} if extra_attributes else request_attrs
    )
    started = time.monotonic()
    error: BaseException | None = None
    with tracer.start_as_current_span(
        span_name(modality, model), attributes=span_attrs
    ) as span:
        try:
            yield span, request_attrs
        except BaseException as exc:
            # Capture the exception so the duration metric carries error.type,
            # then re-raise so the span (and the caller) still see the failure.
            error = exc
            raise
        finally:
            # Runs on both paths; error stays None on the success path.
            record_operation_duration(
                time.monotonic() - started, request_attrs, error=error
            )
378411

379412

380413
class _TracedStream:
@@ -484,19 +517,17 @@ def _finalize(self) -> None:
484517
if self._ended:
485518
return
486519
self._ended = True
487-
duration = time.monotonic() - self._started
520+
record_operation_duration(
521+
time.monotonic() - self._started,
522+
self._metric_attributes,
523+
error=self._error,
524+
)
488525
try:
489526
output = self._inner.output
490527
except StreamNotExhaustedError:
491528
output = None
492529
if output is not None:
493-
record_output(
494-
self._span, output, self._metric_attributes, duration, error=self._error
495-
)
496-
else:
497-
record_operation_duration(
498-
duration, self._metric_attributes, error=self._error
499-
)
530+
record_output(self._span, output, self._metric_attributes)
500531
self._span.end()
501532

502533

@@ -513,6 +544,7 @@ def trace_stream(
513544
"Status",
514545
"StatusCode",
515546
"add_input_event",
547+
"gen_ai_span",
516548
"meter",
517549
"output_attributes",
518550
"record_operation_duration",

0 commit comments

Comments
 (0)