Skip to content

Commit 9678094

Browse files
committed
fix(plugins): forward **kwargs to _log_event in all BigQuery plugin callbacks
Only model callbacks (before_model_callback, after_model_callback) were passing **kwargs through to _log_event(), while all other callbacks silently dropped them. This meant custom attributes like `customer_id` passed via **kwargs were only logged for model events, not for agent, tool, run, or user message events. Updated the following callbacks to forward **kwargs to _log_event(): - on_user_message_callback - before_agent_callback / after_agent_callback - on_model_error_callback - before_tool_callback / after_tool_callback - on_tool_error_callback - before_run_callback / after_run_callback This ensures custom metadata is consistently serialized into the BigQuery `attributes` JSON field for all event types. Fixes #4330
1 parent ec660ed commit 9678094

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,11 +2162,13 @@ async def on_user_message_callback(
21622162
Args:
21632163
invocation_context: The context of the current invocation.
21642164
user_message: The message content received from the user.
2165+
**kwargs: Additional keyword arguments (e.g., custom attributes).
21652166
"""
21662167
await self._log_event(
21672168
"USER_MESSAGE_RECEIVED",
21682169
CallbackContext(invocation_context),
21692170
raw_content=user_message,
2171+
**kwargs,
21702172
)
21712173

21722174
async def on_state_change_callback(
@@ -2197,10 +2199,13 @@ async def before_run_callback(
21972199
21982200
Args:
21992201
invocation_context: The context of the current invocation.
2202+
**kwargs: Additional keyword arguments (e.g., custom attributes).
22002203
"""
22012204
await self._ensure_started()
22022205
await self._log_event(
2203-
"INVOCATION_STARTING", CallbackContext(invocation_context)
2206+
"INVOCATION_STARTING",
2207+
CallbackContext(invocation_context),
2208+
**kwargs,
22042209
)
22052210

22062211
async def after_run_callback(
@@ -2210,9 +2215,12 @@ async def after_run_callback(
22102215
22112216
Args:
22122217
invocation_context: The context of the current invocation.
2218+
**kwargs: Additional keyword arguments (e.g., custom attributes).
22132219
"""
22142220
await self._log_event(
2215-
"INVOCATION_COMPLETED", CallbackContext(invocation_context)
2221+
"INVOCATION_COMPLETED",
2222+
CallbackContext(invocation_context),
2223+
**kwargs,
22162224
)
22172225
# Ensure all logs are flushed before the agent returns
22182226
await self.flush()
@@ -2225,13 +2233,15 @@ async def before_agent_callback(
22252233
Args:
22262234
agent: The agent instance.
22272235
callback_context: The callback context.
2236+
**kwargs: Additional keyword arguments (e.g., custom attributes).
22282237
"""
22292238
TraceManager.init_trace(callback_context)
22302239
TraceManager.push_span(callback_context, "agent")
22312240
await self._log_event(
22322241
"AGENT_STARTING",
22332242
callback_context,
22342243
raw_content=getattr(agent, "instruction", ""),
2244+
**kwargs,
22352245
)
22362246

22372247
async def after_agent_callback(
@@ -2242,6 +2252,7 @@ async def after_agent_callback(
22422252
Args:
22432253
agent: The agent instance.
22442254
callback_context: The callback context.
2255+
**kwargs: Additional keyword arguments (e.g., custom attributes).
22452256
"""
22462257
span_id, duration = TraceManager.pop_span()
22472258
# When popping, the current stack now points to parent.
@@ -2255,6 +2266,7 @@ async def after_agent_callback(
22552266
latency_ms=duration,
22562267
span_id_override=span_id,
22572268
parent_span_id_override=parent_span_id,
2269+
**kwargs,
22582270
)
22592271

22602272
async def before_model_callback(
@@ -2436,7 +2448,7 @@ async def on_model_error_callback(
24362448
Args:
24372449
callback_context: The callback context.
24382450
error: The exception that occurred.
2439-
**kwargs: Additional arguments.
2451+
**kwargs: Additional keyword arguments (e.g., custom attributes).
24402452
"""
24412453
span_id, duration = TraceManager.pop_span()
24422454
parent_span_id, _ = TraceManager.get_current_span_and_parent()
@@ -2447,6 +2459,7 @@ async def on_model_error_callback(
24472459
latency_ms=duration,
24482460
span_id_override=span_id,
24492461
parent_span_id_override=parent_span_id,
2462+
**kwargs,
24502463
)
24512464

24522465
async def before_tool_callback(
@@ -2463,6 +2476,7 @@ async def before_tool_callback(
24632476
tool: The tool being executed.
24642477
tool_args: The arguments passed to the tool.
24652478
tool_context: The tool context.
2479+
**kwargs: Additional keyword arguments (e.g., custom attributes).
24662480
"""
24672481
args_truncated, is_truncated = _recursive_smart_truncate(
24682482
tool_args, self.config.max_content_length
@@ -2474,6 +2488,7 @@ async def before_tool_callback(
24742488
tool_context,
24752489
raw_content=content_dict,
24762490
is_truncated=is_truncated,
2491+
**kwargs,
24772492
)
24782493

24792494
async def after_tool_callback(
@@ -2492,6 +2507,7 @@ async def after_tool_callback(
24922507
tool_args: The arguments passed to the tool.
24932508
tool_context: The tool context.
24942509
result: The response from the tool.
2510+
**kwargs: Additional keyword arguments (e.g., custom attributes).
24952511
"""
24962512
resp_truncated, is_truncated = _recursive_smart_truncate(
24972513
result, self.config.max_content_length
@@ -2508,6 +2524,7 @@ async def after_tool_callback(
25082524
latency_ms=duration,
25092525
span_id_override=span_id,
25102526
parent_span_id_override=parent_span_id,
2527+
**kwargs,
25112528
)
25122529

25132530
if tool_context.actions.state_delta:
@@ -2533,7 +2550,7 @@ async def on_tool_error_callback(
25332550
tool_args: The arguments passed to the tool.
25342551
tool_context: The tool context.
25352552
error: The exception that occurred.
2536-
**kwargs: Additional arguments.
2553+
**kwargs: Additional keyword arguments (e.g., custom attributes).
25372554
"""
25382555
args_truncated, is_truncated = _recursive_smart_truncate(
25392556
tool_args, self.config.max_content_length
@@ -2547,4 +2564,5 @@ async def on_tool_error_callback(
25472564
error_message=str(error),
25482565
is_truncated=is_truncated,
25492566
latency_ms=duration,
2567+
**kwargs,
25502568
)

0 commit comments

Comments
 (0)