Skip to content

Commit 09b3e2e

Browse files
authored
Python: Prevent pickle deserialization of untrusted HITL HTTP input (#4566)
* fix: prevent pickle deserialization of untrusted HITL input Add strip_pickle_markers() to sanitize HTTP input before it reaches pickle.loads() via the checkpoint decoding path. Applied as a 3-layer defence-in-depth: 1. _app.py: sanitize req.get_json() at the HTTP boundary 2. _workflow.py: sanitize in _deserialize_hitl_response() before decode 3. _serialization.py: sanitize in reconstruct_to_type() as final guard Any dict containing __pickled__ or __type__ markers from untrusted sources is replaced with None, blocking arbitrary code execution via crafted payloads to POST /workflow/respond/{instanceId}/{requestId}. Includes 12 new unit tests covering the sanitizer and end-to-end attack prevention. * refactor: address review concerns for pickle fix 1. Remove deserialize_value() fallback in _deserialize_hitl_response untrusted HITL data now returns as-is when no type hint is available, never flowing into pickle.loads(). 2. Move strip_pickle_markers() out of reconstruct_to_type() the function is general-purpose again; untrusted-data callers are responsible for sanitizing first (documented with NOTE comment). 3. Define _PICKLE_MARKER/_TYPE_MARKER as local constants with import-time assertions against core's values decouples from private names while failing loudly if core ever changes them. 4. Update tests to reflect new responsibility boundaries. * fix: simplify warning message and fix ruff RUF001 lint * fix: suppress pyright reportPrivateUsage on core marker imports * Lower marker-strip log from warning to debug to avoid log flooding * Replace assert with RuntimeError for marker sync checks (ruff S101) * Fix pyright and ruff CI errors in security fix - Use cast() for dict/list comprehensions in strip_pickle_markers (pyright) - type: ignore for narrowed dict return in _workflow.py (pyright) - Simplify marker imports: use core constants directly, remove local copies - Remove duplicate pyright ignore comment * Remove duplicate end-to-end test in TestStripPickleMarkers * Suppress mypy redundant-cast on list cast needed by pyright
1 parent 55fc882 commit 09b3e2e

File tree

4 files changed

+142
-11
lines changed

4 files changed

+142
-11
lines changed

python/packages/azurefunctions/agent_framework_azurefunctions/_app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from ._entities import create_agent_entity
4545
from ._errors import IncomingRequestError
4646
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
47-
from ._serialization import deserialize_value, serialize_value
47+
from ._serialization import deserialize_value, serialize_value, strip_pickle_markers
4848
from ._workflow import (
4949
SOURCE_HITL_RESPONSE,
5050
SOURCE_ORCHESTRATOR,
@@ -515,6 +515,10 @@ async def send_hitl_response(req: func.HttpRequest, client: df.DurableOrchestrat
515515
except ValueError:
516516
return self._build_error_response("Request body must be valid JSON.")
517517

518+
# Sanitize untrusted HTTP input before it reaches pickle.loads().
519+
# See strip_pickle_markers() docstring for details on the attack vector.
520+
response_data = strip_pickle_markers(response_data)
521+
518522
# Send the response as an external event
519523
# The request_id is used as the event name for correlation
520524
await client.raise_event(

python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,14 @@
2222
import logging
2323
from contextlib import suppress
2424
from dataclasses import is_dataclass
25-
from typing import Any
26-
27-
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
25+
from typing import Any, cast
26+
27+
from agent_framework._workflows._checkpoint_encoding import (
28+
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
29+
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
30+
decode_checkpoint_value,
31+
encode_checkpoint_value,
32+
)
2833
from pydantic import BaseModel
2934

3035
logger = logging.getLogger(__name__)
@@ -48,6 +53,41 @@ def resolve_type(type_key: str) -> type | None:
4853
return None
4954

5055

56+
# ============================================================================
57+
# Pickle marker sanitization (security)
58+
# ============================================================================
59+
60+
61+
def strip_pickle_markers(data: Any) -> Any:
62+
"""Recursively strip pickle/type markers from untrusted data.
63+
64+
The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to
65+
roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an
66+
HTTP payload that contains these markers, the data would flow into
67+
``pickle.loads()`` and enable **arbitrary code execution**.
68+
69+
This function walks the incoming data structure and replaces any ``dict``
70+
that contains either marker key with ``None``, neutralising the attack
71+
vector while leaving all other data untouched.
72+
73+
It **must** be called on every value that originates from an untrusted
74+
source (e.g. ``req.get_json()``) *before* the value is passed to
75+
``deserialize_value`` / ``decode_checkpoint_value``.
76+
"""
77+
if isinstance(data, dict):
78+
if _PICKLE_MARKER in data or _TYPE_MARKER in data:
79+
logger.debug("Stripped pickle/type markers from untrusted input.")
80+
return None
81+
typed_dict = cast(dict[str, Any], data)
82+
return {k: strip_pickle_markers(v) for k, v in typed_dict.items()}
83+
84+
if isinstance(data, list):
85+
typed_list = cast(list[Any], data) # type: ignore[redundant-cast]
86+
return [strip_pickle_markers(item) for item in typed_list]
87+
88+
return data
89+
90+
5191
# ============================================================================
5292
# Serialize / Deserialize
5393
# ============================================================================
@@ -117,7 +157,10 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any:
117157
if not isinstance(value, dict):
118158
return value
119159

120-
# Try decoding if data has pickle markers (from checkpoint encoding)
160+
# Try decoding if data has pickle markers (from checkpoint encoding).
161+
# NOTE: This function is general-purpose. Callers that handle untrusted
162+
# data (e.g. HITL responses) MUST call strip_pickle_markers() before
163+
# passing data here. See _deserialize_hitl_response in _workflow.py.
121164
decoded = deserialize_value(value)
122165
if not isinstance(decoded, dict):
123166
return decoded

python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
from ._context import CapturingRunnerContext
5252
from ._orchestration import AzureFunctionsAgentExecutor
53-
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value
53+
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers
5454

5555
logger = logging.getLogger(__name__)
5656

@@ -961,6 +961,13 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
961961
type(response_data).__name__,
962962
)
963963

964+
if response_data is None:
965+
return None
966+
967+
# Sanitize untrusted external input before deserialization.
968+
# HITL response data originates from an HTTP POST and must not contain
969+
# pickle/type markers that would reach pickle.loads().
970+
response_data = strip_pickle_markers(response_data)
964971
if response_data is None:
965972
return None
966973

@@ -969,7 +976,7 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
969976
logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__)
970977
return response_data
971978

972-
# Try to deserialize using the type hint
979+
# Try to reconstruct using the type hint (Pydantic / dataclass)
973980
if response_type_str:
974981
response_type = resolve_type(response_type_str)
975982
if response_type:
@@ -979,6 +986,8 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
979986
return result
980987
logger.warning("Could not resolve response type: %s", response_type_str)
981988

982-
# Fall back to generic deserialization
983-
logger.debug("Falling back to generic deserialization")
984-
return deserialize_value(response_data)
989+
# No type hint available - return the sanitized dict as-is.
990+
# We intentionally do NOT call deserialize_value() here because HITL
991+
# response data is untrusted and must never flow into pickle.loads().
992+
logger.debug("No type hint; returning sanitized data as-is")
993+
return response_data # type: ignore[reportUnknownVariableType]

python/packages/azurefunctions/tests/test_func_utils.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
deserialize_value,
2222
reconstruct_to_type,
2323
serialize_value,
24+
strip_pickle_markers,
2425
)
2526

2627

@@ -353,7 +354,11 @@ class Feedback:
353354
assert result.comment == "Great"
354355

355356
def test_reconstruct_from_checkpoint_markers(self) -> None:
356-
"""Test that data with checkpoint markers is decoded via deserialize_value."""
357+
"""Test that data with checkpoint markers is decoded via deserialize_value.
358+
359+
reconstruct_to_type is general-purpose and handles trusted checkpoint
360+
data. Untrusted HITL callers must call strip_pickle_markers() first.
361+
"""
357362
original = SampleData(value=99, name="marker-test")
358363
encoded = serialize_value(original)
359364

@@ -372,3 +377,73 @@ class Unrelated:
372377
result = reconstruct_to_type(data, Unrelated)
373378

374379
assert result == data
380+
381+
def test_reconstruct_strips_injected_pickle_markers(self) -> None:
382+
"""End-to-end: strip_pickle_markers + reconstruct_to_type blocks attack.
383+
384+
This mirrors the real HITL flow where callers sanitize before reconstruction.
385+
"""
386+
malicious = {"__pickled__": "gASVDgAAAAAAAACMBHRlc3SULg==", "__type__": "builtins:str"}
387+
sanitized = strip_pickle_markers(malicious)
388+
result = reconstruct_to_type(sanitized, str)
389+
assert result is None
390+
391+
392+
class TestStripPickleMarkers:
393+
"""Security tests for strip_pickle_markers — the defence-in-depth layer
394+
that prevents untrusted HTTP input from reaching pickle.loads()."""
395+
396+
def test_strips_top_level_pickle_marker(self) -> None:
397+
"""A dict containing __pickled__ must be replaced with None."""
398+
data = {"__pickled__": "PAYLOAD", "__type__": "os:system"}
399+
assert strip_pickle_markers(data) is None
400+
401+
def test_strips_top_level_type_marker_only(self) -> None:
402+
"""Even __type__ alone (without __pickled__) must be neutralised."""
403+
data = {"__type__": "os:system", "other": "value"}
404+
assert strip_pickle_markers(data) is None
405+
406+
def test_strips_nested_pickle_marker(self) -> None:
407+
"""Pickle markers nested inside a dict must be neutralised."""
408+
data = {"safe": "value", "nested": {"__pickled__": "PAYLOAD", "__type__": "os:system"}}
409+
result = strip_pickle_markers(data)
410+
assert result == {"safe": "value", "nested": None}
411+
412+
def test_strips_pickle_marker_in_list(self) -> None:
413+
"""Pickle markers inside a list element must be neutralised."""
414+
data = [{"__pickled__": "PAYLOAD"}, "safe"]
415+
result = strip_pickle_markers(data)
416+
assert result == [None, "safe"]
417+
418+
def test_strips_deeply_nested_marker(self) -> None:
419+
"""Deeply nested pickle markers must be neutralised."""
420+
data = {"a": {"b": {"c": {"__pickled__": "deep"}}}}
421+
result = strip_pickle_markers(data)
422+
assert result == {"a": {"b": {"c": None}}}
423+
424+
def test_preserves_safe_dict(self) -> None:
425+
"""Dicts without pickle markers must be left untouched."""
426+
data = {"approved": True, "reason": "Looks good"}
427+
assert strip_pickle_markers(data) == data
428+
429+
def test_preserves_primitives(self) -> None:
430+
"""Primitive values must pass through unchanged."""
431+
assert strip_pickle_markers("hello") == "hello"
432+
assert strip_pickle_markers(42) == 42
433+
assert strip_pickle_markers(None) is None
434+
assert strip_pickle_markers(True) is True
435+
436+
def test_preserves_safe_list(self) -> None:
437+
"""Lists without pickle markers must be left untouched."""
438+
data = [1, "two", {"key": "value"}]
439+
assert strip_pickle_markers(data) == data
440+
441+
def test_mixed_safe_and_malicious(self) -> None:
442+
"""Only the malicious entries should be stripped; safe entries remain."""
443+
data = {
444+
"user_input": "hello",
445+
"evil": {"__pickled__": "PAYLOAD", "__type__": "os:system"},
446+
"count": 42,
447+
}
448+
result = strip_pickle_markers(data)
449+
assert result == {"user_input": "hello", "evil": None, "count": 42}

0 commit comments

Comments
 (0)