Skip to content

Commit 03d5df8

Browse files
committed
Merge branch 'main' into giles/tool-rich-content-results
2 parents b086455 + b6a1315 commit 03d5df8

File tree

4 files changed

+373
-7
lines changed

4 files changed

+373
-7
lines changed

python/packages/azurefunctions/agent_framework_azurefunctions/_app.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import re
1515
import uuid
1616
from collections.abc import Callable, Mapping
17+
from copy import deepcopy
1718
from dataclasses import dataclass
1819
from datetime import datetime, timezone
1920
from typing import TYPE_CHECKING, Any, TypeVar, cast
@@ -58,6 +59,11 @@
5859
HandlerT = TypeVar("HandlerT", bound=Callable[..., Any])
5960

6061

62+
def _create_state_snapshot(state: dict[str, Any]) -> dict[str, Any]:
63+
"""Create a deep copy of the deserialized state for later diffing."""
64+
return deepcopy(state)
65+
66+
6167
@dataclass
6268
class AgentMetadata:
6369
"""Metadata for a registered agent.
@@ -306,7 +312,7 @@ async def run() -> dict[str, Any]:
306312
deserialized_state: dict[str, Any] = {
307313
str(k): deserialize_value(v) for k, v in shared_state_snapshot.items()
308314
}
309-
original_snapshot: dict[str, Any] = dict(deserialized_state)
315+
original_snapshot = _create_state_snapshot(deserialized_state)
310316
shared_state.import_state(deserialized_state)
311317

312318
if is_hitl_response:
@@ -339,9 +345,10 @@ async def run() -> dict[str, Any]:
339345
deletes: set[str] = original_keys - current_keys
340346

341347
# Updates = keys in current that are new or have different values
342-
updates = {
343-
k: v for k, v in current_state.items() if k not in original_snapshot or original_snapshot[k] != v
344-
}
348+
updates: dict[str, Any] = {}
349+
for key in current_keys:
350+
if key not in original_keys or current_state[key] != original_snapshot.get(key):
351+
updates[key] = current_state[key]
345352

346353
# Drain messages and events from runner context
347354
sent_messages = await runner_context.drain_messages()

python/packages/azurefunctions/tests/test_app.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from agent_framework_azurefunctions import AgentFunctionApp
2828
from agent_framework_azurefunctions._entities import create_agent_entity
29+
from agent_framework_azurefunctions._workflow import SOURCE_ORCHESTRATOR
2930

3031
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
3132

@@ -1441,5 +1442,286 @@ def test_build_status_url_handles_trailing_slash(self) -> None:
14411442
assert "instance-456" in url
14421443

14431444

1445+
def _compute_state_updates(original_snapshot: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]:
1446+
"""Compute state updates by comparing current state against the original snapshot.
1447+
1448+
This mirrors the inlined logic in ``_app.py``'s ``executor_activity.run()``.
1449+
"""
1450+
original_keys = set(original_snapshot.keys())
1451+
current_keys = set(current_state.keys())
1452+
updates: dict[str, Any] = {}
1453+
for key in current_keys:
1454+
if key not in original_keys or current_state[key] != original_snapshot.get(key):
1455+
updates[key] = current_state[key]
1456+
return updates
1457+
1458+
1459+
class TestStateSnapshotDiff:
1460+
"""Test suite for state snapshot diffing in activity execution.
1461+
1462+
The activity executor snapshots state before execution and diffs against the
1463+
post-execution state to determine which keys were updated. These tests exercise
1464+
the production snapshot helper and the state-update diffing logic to ensure that
1465+
in-place mutations to nested objects (dicts, lists) are correctly detected as changes.
1466+
"""
1467+
1468+
def test_nested_dict_mutation_detected_in_diff(self) -> None:
1469+
"""Test that mutating values inside a nested dict appears in the diff."""
1470+
from agent_framework._workflows._state import State
1471+
1472+
from agent_framework_azurefunctions._app import _create_state_snapshot
1473+
1474+
deserialized_state: dict[str, Any] = {
1475+
"Local.config": {"code": "", "enabled": False},
1476+
"simple_key": "simple_value",
1477+
}
1478+
1479+
original_snapshot = _create_state_snapshot(deserialized_state)
1480+
1481+
shared_state = State()
1482+
shared_state.import_state(deserialized_state)
1483+
1484+
config = shared_state.get("Local.config")
1485+
config["code"] = "SOMECODEXXX"
1486+
config["enabled"] = True
1487+
1488+
shared_state.commit()
1489+
current_state = shared_state.export_state()
1490+
1491+
updates = _compute_state_updates(original_snapshot, current_state)
1492+
1493+
assert "Local.config" in updates
1494+
assert updates["Local.config"]["code"] == "SOMECODEXXX"
1495+
assert updates["Local.config"]["enabled"] is True
1496+
1497+
def test_new_key_in_nested_dict_detected_in_diff(self) -> None:
1498+
"""Test that adding a key to a nested dict appears in the diff."""
1499+
from agent_framework._workflows._state import State
1500+
1501+
from agent_framework_azurefunctions._app import _create_state_snapshot
1502+
1503+
deserialized_state: dict[str, Any] = {
1504+
"Local.data": {"existing": "value"},
1505+
}
1506+
1507+
original_snapshot = _create_state_snapshot(deserialized_state)
1508+
1509+
shared_state = State()
1510+
shared_state.import_state(deserialized_state)
1511+
1512+
data = shared_state.get("Local.data")
1513+
data["code"] = "NEW_CODE"
1514+
1515+
shared_state.commit()
1516+
current_state = shared_state.export_state()
1517+
1518+
updates = _compute_state_updates(original_snapshot, current_state)
1519+
1520+
assert "Local.data" in updates
1521+
assert updates["Local.data"]["code"] == "NEW_CODE"
1522+
1523+
def test_nested_list_mutation_detected_in_diff(self) -> None:
1524+
"""Test that appending to a nested list appears in the diff."""
1525+
from agent_framework._workflows._state import State
1526+
1527+
from agent_framework_azurefunctions._app import _create_state_snapshot
1528+
1529+
deserialized_state: dict[str, Any] = {
1530+
"Local.items": [1, 2, 3],
1531+
}
1532+
1533+
original_snapshot = _create_state_snapshot(deserialized_state)
1534+
1535+
shared_state = State()
1536+
shared_state.import_state(deserialized_state)
1537+
1538+
items = shared_state.get("Local.items")
1539+
items.append(4)
1540+
1541+
shared_state.commit()
1542+
current_state = shared_state.export_state()
1543+
1544+
updates = _compute_state_updates(original_snapshot, current_state)
1545+
1546+
assert "Local.items" in updates
1547+
assert updates["Local.items"] == [1, 2, 3, 4]
1548+
1549+
def test_new_top_level_key_detected_in_diff(self) -> None:
1550+
"""Test that setting a new top-level key appears in the diff."""
1551+
from agent_framework._workflows._state import State
1552+
1553+
from agent_framework_azurefunctions._app import _create_state_snapshot
1554+
1555+
deserialized_state: dict[str, Any] = {
1556+
"existing": "value",
1557+
}
1558+
1559+
original_snapshot = _create_state_snapshot(deserialized_state)
1560+
1561+
shared_state = State()
1562+
shared_state.import_state(deserialized_state)
1563+
1564+
shared_state.set("Local.code", "SOMECODEXXX")
1565+
1566+
shared_state.commit()
1567+
current_state = shared_state.export_state()
1568+
1569+
updates = _compute_state_updates(original_snapshot, current_state)
1570+
1571+
assert "Local.code" in updates
1572+
assert updates["Local.code"] == "SOMECODEXXX"
1573+
1574+
def test_unchanged_nested_state_produces_empty_diff(self) -> None:
1575+
"""Test that unmodified nested state produces no updates."""
1576+
from agent_framework._workflows._state import State
1577+
1578+
from agent_framework_azurefunctions._app import _create_state_snapshot
1579+
1580+
deserialized_state: dict[str, Any] = {
1581+
"Local.config": {"code": "existing", "enabled": True},
1582+
"simple_key": "simple_value",
1583+
}
1584+
1585+
original_snapshot = _create_state_snapshot(deserialized_state)
1586+
1587+
shared_state = State()
1588+
shared_state.import_state(deserialized_state)
1589+
1590+
# No mutations performed
1591+
shared_state.commit()
1592+
current_state = shared_state.export_state()
1593+
1594+
updates = _compute_state_updates(original_snapshot, current_state)
1595+
1596+
assert updates == {}
1597+
1598+
def test_shallow_copy_would_miss_nested_mutations(self) -> None:
1599+
"""Regression test: a shallow copy (dict()) shares nested refs, hiding mutations.
1600+
1601+
This reproduces the original bug from #4500 where ``dict(deserialized_state)``
1602+
was used instead of ``copy.deepcopy()``. With a shallow copy the snapshot and
1603+
the live state share nested objects, so in-place mutations appear in both and
1604+
the diff produces an empty update set.
1605+
"""
1606+
from agent_framework._workflows._state import State
1607+
1608+
deserialized_state: dict[str, Any] = {
1609+
"Local.config": {"code": "", "enabled": False},
1610+
}
1611+
1612+
# Shallow copy (the OLD, buggy behaviour)
1613+
shallow_snapshot = dict(deserialized_state)
1614+
1615+
shared_state = State()
1616+
shared_state.import_state(deserialized_state)
1617+
1618+
config = shared_state.get("Local.config")
1619+
config["code"] = "SOMECODEXXX"
1620+
config["enabled"] = True
1621+
1622+
shared_state.commit()
1623+
current_state = shared_state.export_state()
1624+
1625+
# With a shallow copy the mutation leaks into the snapshot → empty diff
1626+
updates_shallow = _compute_state_updates(shallow_snapshot, current_state)
1627+
assert updates_shallow == {}, "shallow copy should miss nested mutations (demonstrating the bug)"
1628+
1629+
def test_create_state_snapshot_isolates_nested_objects(self) -> None:
1630+
"""Verify _create_state_snapshot produces a deep copy that is mutation-proof.
1631+
1632+
This ensures the production snapshot helper is not equivalent to ``dict()``
1633+
and will correctly isolate nested objects so that later mutations are detected.
1634+
"""
1635+
from agent_framework_azurefunctions._app import _create_state_snapshot
1636+
1637+
original: dict[str, Any] = {
1638+
"nested_dict": {"a": 1},
1639+
"nested_list": [1, 2, 3],
1640+
}
1641+
1642+
snapshot = _create_state_snapshot(original)
1643+
1644+
# Mutate the originals in place
1645+
original["nested_dict"]["a"] = 999
1646+
original["nested_list"].append(4)
1647+
1648+
# Snapshot must be unaffected
1649+
assert snapshot["nested_dict"]["a"] == 1
1650+
assert snapshot["nested_list"] == [1, 2, 3]
1651+
1652+
def test_executor_activity_detects_nested_state_mutations(self) -> None:
1653+
"""Integration test: the full activity wrapper detects nested mutations.
1654+
1655+
This exercises the actual executor_activity function registered by
1656+
_setup_executor_activity to verify the production code path uses
1657+
_create_state_snapshot (deep copy) rather than dict() (shallow copy).
1658+
If the implementation regressed to using a shallow copy such as
1659+
``dict(deserialized_state)``, this test would fail because in-place
1660+
mutations would leak into the snapshot and produce an empty diff.
1661+
"""
1662+
mock_executor = Mock()
1663+
mock_executor.id = "test-exec"
1664+
1665+
async def mutate_nested_state(
1666+
message: Any,
1667+
source_executor_ids: Any,
1668+
state: Any,
1669+
runner_context: Any,
1670+
) -> None:
1671+
config = state.get("Local.config")
1672+
config["code"] = "MUTATED"
1673+
config["enabled"] = True
1674+
state.commit()
1675+
1676+
mock_executor.execute = AsyncMock(side_effect=mutate_nested_state)
1677+
1678+
mock_workflow = Mock()
1679+
mock_workflow.executors = {"test-exec": mock_executor}
1680+
1681+
# Capture the activity function by making decorators pass-through
1682+
captured_activity: dict[str, Any] = {}
1683+
1684+
def passthrough_function_name(name: str) -> Callable[[FuncT], FuncT]:
1685+
def decorator(fn: FuncT) -> FuncT:
1686+
captured_activity["fn"] = fn
1687+
return fn
1688+
1689+
return decorator
1690+
1691+
def passthrough_activity_trigger(input_name: str) -> Callable[[FuncT], FuncT]:
1692+
def decorator(fn: FuncT) -> FuncT:
1693+
return fn
1694+
1695+
return decorator
1696+
1697+
with (
1698+
patch.object(AgentFunctionApp, "function_name", side_effect=passthrough_function_name),
1699+
patch.object(AgentFunctionApp, "activity_trigger", side_effect=passthrough_activity_trigger),
1700+
patch.object(AgentFunctionApp, "_setup_workflow_orchestration"),
1701+
):
1702+
AgentFunctionApp(workflow=mock_workflow)
1703+
1704+
assert "fn" in captured_activity, "activity function was not captured"
1705+
1706+
# Call the activity with nested state that the executor will mutate
1707+
input_data = json.dumps({
1708+
"message": "test",
1709+
"shared_state_snapshot": {
1710+
"Local.config": {"code": "", "enabled": False},
1711+
},
1712+
"source_executor_ids": [SOURCE_ORCHESTRATOR],
1713+
})
1714+
1715+
result = json.loads(captured_activity["fn"](input_data))
1716+
1717+
# The deep copy snapshot must detect the in-place nested mutations
1718+
assert "Local.config" in result["shared_state_updates"], (
1719+
"nested mutation not detected — snapshot may be using shallow copy"
1720+
)
1721+
updated_config = result["shared_state_updates"]["Local.config"]
1722+
assert updated_config["code"] == "MUTATED"
1723+
assert updated_config["enabled"] is True
1724+
1725+
14441726
if __name__ == "__main__":
14451727
pytest.main([__file__, "-v", "--tb=short"])

python/packages/bedrock/agent_framework_bedrock/_chat_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,11 +405,16 @@ def _prepare_options(
405405

406406
tool_config = self._prepare_tools(options.get("tools"))
407407
if tool_mode := validate_tool_mode(options.get("tool_choice")):
408-
tool_config = tool_config or {}
409408
match tool_mode.get("mode"):
410-
case "auto" | "none":
411-
tool_config["toolChoice"] = {tool_mode.get("mode"): {}}
409+
case "none":
410+
# Bedrock doesn't support toolChoice "none".
411+
# Omit toolConfig entirely so the model won't attempt tool calls.
412+
tool_config = None
413+
case "auto":
414+
tool_config = tool_config or {}
415+
tool_config["toolChoice"] = {"auto": {}}
412416
case "required":
417+
tool_config = tool_config or {}
413418
if required_name := tool_mode.get("required_function_name"):
414419
tool_config["toolChoice"] = {"tool": {"name": required_name}}
415420
else:

0 commit comments

Comments
 (0)