diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 534022cc133..1ef866486ec 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -41,6 +41,9 @@ create_websocket_passthrough_route, websocket_passthrough_request, ) +from litellm.types.passthrough_endpoints.pass_through_endpoints import ( + LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, +) from litellm.proxy.utils import is_known_model from litellm.proxy.vector_store_endpoints.utils import ( is_allowed_to_call_vector_store_endpoint, @@ -1086,11 +1089,11 @@ async def bedrock_proxy_route( is_streaming_request=is_streaming_request, _forward_headers=True, ) # dynamically construct pass-through endpoint based on incoming path + setattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, data) received_value = await endpoint_func( request, fastapi_response, user_api_key_dict, - custom_body=data, # type: ignore ) return received_value diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 4f68c92b9d9..d582240395f 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -61,6 +61,7 @@ from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.passthrough_endpoints.pass_through_endpoints import ( EndpointType, + LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, PassthroughStandardLoggingPayload, ) @@ -392,6 +393,7 @@ async def non_streaming_http_request_handler( headers: dict, requested_query_params: Optional[dict] = None, _parsed_body: Optional[dict] = None, + forward_multipart: bool = False, ) -> httpx.Response: """ Handle non-streaming HTTP requests @@ -407,10 +409,12 @@ async def non_streaming_http_request_handler( ) elif ( HttpPassThroughEndpointHelpers.is_multipart(request) is True - and not _parsed_body + and forward_multipart ): - # Only use multipart handler if we don't have a parsed body - # (parsed body means it was JSON despite multipart content-type header) + # Forward multipart via make_multipart_http_request even when _parsed_body is + # non-empty (pass_through_request always injects litellm_logging_obj, etc.). + # forward_multipart is False when custom_body was supplied (JSON body despite + # multipart content-type) — those requests use the generic json= path. return await HttpPassThroughEndpointHelpers.make_multipart_http_request( request=request, async_client=async_client, @@ -449,6 +453,7 @@ async def make_multipart_http_request( url: httpx.URL, headers: dict, requested_query_params: Optional[dict] = None, + stream: bool = False, ) -> httpx.Response: """Process multipart/form-data requests, handling both files and form fields""" form_data = await request.form() @@ -457,10 +462,10 @@ async def make_multipart_http_request( for field_name, field_value in form_data.items(): if isinstance(field_value, (StarletteUploadFile, UploadFile)): - files[ - field_name - ] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( - upload_file=field_value + files[field_name] = ( + await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file=field_value + ) ) else: form_data_dict[field_name] = field_value @@ -470,7 +475,19 @@ async def make_multipart_http_request( headers_copy = headers.copy() headers_copy.pop("content-type", None) - response = await async_client.request( + # httpx.AsyncClient.request() does not accept stream=; use send() for streaming. + if stream: + req = async_client.build_request( + request.method, + url, + headers=headers_copy, + params=requested_query_params, + files=files, + data=form_data_dict, + ) + return await async_client.send(req, stream=True) + + return await async_client.request( method=request.method, url=url, headers=headers_copy, @@ -478,7 +495,6 @@ async def make_multipart_http_request( files=files, data=form_data_dict, ) - return response @staticmethod def _init_kwargs_for_pass_through_endpoint( @@ -537,9 +553,9 @@ def _init_kwargs_for_pass_through_endpoint( "passthrough_logging_payload": passthrough_logging_payload, } - logging_obj.model_call_details[ - "passthrough_logging_payload" - ] = passthrough_logging_payload + logging_obj.model_call_details["passthrough_logging_payload"] = ( + passthrough_logging_payload + ) return kwargs @@ -803,15 +819,27 @@ async def pass_through_request( # noqa: PLR0915 ) if stream: - req = async_client.build_request( - "POST", - url, - json=_parsed_body, - params=requested_query_params, - headers=headers, - ) + if is_multipart: + response = ( + await HttpPassThroughEndpointHelpers.make_multipart_http_request( + request=request, + async_client=async_client, + url=url, + headers=headers, + requested_query_params=requested_query_params, + stream=True, + ) + ) + else: + req = async_client.build_request( + "POST", + url, + json=_parsed_body, + params=requested_query_params, + headers=headers, + ) - response = await async_client.send(req, stream=stream) + response = await async_client.send(req, stream=stream) try: response.raise_for_status() @@ -845,6 +873,7 @@ async def pass_through_request( # noqa: PLR0915 headers=headers, requested_query_params=requested_query_params, _parsed_body=_parsed_body, + forward_multipart=is_multipart, ) ) verbose_proxy_logger.debug("response.headers= %s", response.headers) @@ -1106,9 +1135,6 @@ async def endpoint_func( # type: ignore fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), subpath: str = "", # captures sub-paths when include_subpath=True - custom_body: Optional[ - dict - ] = None, # accepted for signature compatibility with URL-based path; not forwarded because chat_completion_pass_through_endpoint does not support it ): return await chat_completion_pass_through_endpoint( fastapi_response=fastapi_response, @@ -1125,9 +1151,6 @@ async def endpoint_func( # type: ignore fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), subpath: str = "", # captures sub-paths when include_subpath=True - custom_body: Optional[ - dict - ] = None, # caller-supplied body takes precedence over request-parsed body ): from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( InitPassThroughEndpointHelpers, @@ -1203,28 +1226,40 @@ async def endpoint_func( # type: ignore ) if query_params: final_query_params.update(query_params) - # Caller-supplied custom_body takes precedence over the request-parsed body + # Programmatic callers set LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY on + # request.state (see Bedrock proxy). Parsed JSON envelope otherwise. + state_custom_body: Optional[dict] = getattr( + request.state, + LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, + None, + ) final_custom_body: Optional[dict] = None - if custom_body is not None: - final_custom_body = custom_body + if isinstance(state_custom_body, dict): + final_custom_body = state_custom_body elif isinstance(custom_body_data, dict): final_custom_body = custom_body_data - return await pass_through_request( # type: ignore - request=request, - target=full_target, - custom_headers=headers_dict, - user_api_key_dict=user_api_key_dict, - forward_headers=cast(Optional[bool], param_forward_headers), - merge_query_params=cast(Optional[bool], param_merge_query_params), - query_params=final_query_params, - default_query_params=cast(Optional[dict], param_default_query_params), - stream=is_streaming_request or stream, - custom_body=final_custom_body, - cost_per_request=cast(Optional[float], param_cost_per_request), - custom_llm_provider=custom_llm_provider, - guardrails_config=cast(Optional[dict], param_guardrails), - ) + try: + return await pass_through_request( # type: ignore + request=request, + target=full_target, + custom_headers=headers_dict, + user_api_key_dict=user_api_key_dict, + forward_headers=cast(Optional[bool], param_forward_headers), + merge_query_params=cast(Optional[bool], param_merge_query_params), + query_params=final_query_params, + default_query_params=cast( + Optional[dict], param_default_query_params + ), + stream=is_streaming_request or stream, + custom_body=final_custom_body, + cost_per_request=cast(Optional[float], param_cost_per_request), + custom_llm_provider=custom_llm_provider, + guardrails_config=cast(Optional[dict], param_guardrails), + ) + finally: + if hasattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY): + delattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY) return endpoint_func @@ -1468,9 +1503,9 @@ async def forward_client_to_upstream() -> None: ) if extracted_model: kwargs["model"] = extracted_model - kwargs[ - "custom_llm_provider" - ] = "vertex_ai-language-models" + kwargs["custom_llm_provider"] = ( + "vertex_ai-language-models" + ) # Update logging object with correct model logging_obj.model = extracted_model logging_obj.model_call_details[ @@ -1536,9 +1571,9 @@ async def forward_upstream_to_client() -> None: # Update logging object with correct model logging_obj.model = extracted_model logging_obj.model_call_details["model"] = extracted_model - logging_obj.model_call_details[ - "custom_llm_provider" - ] = "vertex_ai_language_models" + logging_obj.model_call_details["custom_llm_provider"] = ( + "vertex_ai_language_models" + ) verbose_proxy_logger.debug( f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response" ) diff --git a/litellm/types/passthrough_endpoints/pass_through_endpoints.py b/litellm/types/passthrough_endpoints/pass_through_endpoints.py index c99775f2a6c..4a07fa5e849 100644 --- a/litellm/types/passthrough_endpoints/pass_through_endpoints.py +++ b/litellm/types/passthrough_endpoints/pass_through_endpoints.py @@ -3,6 +3,10 @@ from typing_extensions import TypedDict +# Request.state key for programmatic pass-through callers (e.g. Bedrock proxy) that attach +# JSON without a FastAPI `custom_body` parameter (which would consume the HTTP body). +LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY = "litellm_pass_through_custom_body" + class EndpointType(str, Enum): VERTEX_AI = "vertex-ai" diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py index ea68e8566a0..8c1ebe85d0a 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py @@ -2,6 +2,7 @@ import os import sys from io import BytesIO +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -16,6 +17,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( HttpPassThroughEndpointHelpers, + LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, pass_through_request, ) from litellm.proxy.pass_through_endpoints.success_handler import ( @@ -193,6 +195,47 @@ async def test_make_multipart_http_request_removes_content_type_header(): assert "content-type" in original_headers +@pytest.mark.asyncio +async def test_non_streaming_http_request_handler_multipart_with_non_empty_parsed_body(): + """ + Regression: pass_through_request injects litellm_logging_obj into _parsed_body before + forwarding. Multipart uploads must still use files=, not json=_parsed_body. + """ + request = MagicMock(spec=Request) + request.method = "POST" + request.headers = Headers( + {"content-type": "multipart/form-data; boundary=------------------------test"} + ) + + file_content = b"test file content" + file = BytesIO(file_content) + upload_headers = Headers({"content-type": "text/plain"}) + upload_file = UploadFile(file=file, filename="test.txt", headers=upload_headers) + upload_file.read = AsyncMock(return_value=file_content) + request.form = AsyncMock(return_value={"file": upload_file}) + + mock_response = MagicMock() + mock_response.status_code = 200 + async_client = MagicMock() + async_client.request = AsyncMock(return_value=mock_response) + + await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler( + request=request, + async_client=async_client, + url=httpx.URL("http://test.com"), + headers={}, + requested_query_params=None, + _parsed_body={"litellm_logging_obj": MagicMock()}, + forward_multipart=True, + ) + + async_client.request.assert_called_once() + call_args = async_client.request.call_args[1] + assert "files" in call_args + assert "json" not in call_args + assert call_args["files"]["file"][0] == "test.txt" + + @pytest.mark.asyncio async def test_pass_through_request_failure_handler(): """ @@ -1571,6 +1614,7 @@ async def test_pass_through_request_query_params_forwarding(): assert call_kwargs["requested_query_params"] == { "api-version": "2025-01-01-preview" } + assert call_kwargs.get("forward_multipart") is False # Verify the target URL is correct assert ( @@ -2090,13 +2134,12 @@ async def test_add_litellm_data_to_request_adds_headers_to_metadata(): @pytest.mark.asyncio async def test_create_pass_through_route_custom_body_url_target(): """ - Test that the URL-based endpoint_func created by create_pass_through_route - accepts a custom_body parameter and forwards it to pass_through_request, - taking precedence over the request-parsed body. + Test that programmatic callers (e.g. Bedrock proxy) can attach a JSON body via + request.state[LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY]; it is forwarded to + pass_through_request and takes precedence over the request-parsed body. - This verifies the fix for issue #16999 where bedrock_proxy_route passes - custom_body=data to the endpoint function, which previously crashed with: - TypeError: endpoint_func() got an unexpected keyword argument 'custom_body' + We cannot use a `custom_body: dict` route parameter: FastAPI would treat it as + the HTTP body and reject multipart/form-data before the handler runs. """ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( create_pass_through_route, @@ -2135,6 +2178,7 @@ async def test_create_pass_through_route_custom_body_url_target(): mock_request.url.path = unique_path mock_request.path_params = {} mock_request.query_params = QueryParams({}) + mock_request.state = SimpleNamespace() mock_user_api_key_dict = MagicMock() mock_user_api_key_dict.api_key = "test-key" @@ -2144,13 +2188,14 @@ async def test_create_pass_through_route_custom_body_url_target(): "retrievalQuery": {"text": "What is in the knowledge base?"}, } - # Call endpoint_func with custom_body — this is the call that - # used to crash with TypeError before the fix + setattr( + mock_request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, bedrock_body + ) + await endpoint_func( request=mock_request, fastapi_response=MagicMock(), user_api_key_dict=mock_user_api_key_dict, - custom_body=bedrock_body, ) mock_pass_through.assert_called_once() @@ -2206,11 +2251,12 @@ async def test_create_pass_through_route_no_custom_body_falls_back(): mock_request.url.path = unique_path mock_request.path_params = {} mock_request.query_params = QueryParams({}) + mock_request.state = SimpleNamespace() mock_user_api_key_dict = MagicMock() mock_user_api_key_dict.api_key = "test-key" - # Call without custom_body — should use the request-parsed body + # Call without state body — should use the request-parsed body await endpoint_func( request=mock_request, fastapi_response=MagicMock(), @@ -2232,11 +2278,15 @@ def test_build_full_path_with_root_default(): InitPassThroughEndpointHelpers, ) - with patch("litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path") as mock_get_root: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path" + ) as mock_get_root: # Test with default root path mock_get_root.return_value = "/" - result = InitPassThroughEndpointHelpers._build_full_path_with_root("/api/v1/endpoint") + result = InitPassThroughEndpointHelpers._build_full_path_with_root( + "/api/v1/endpoint" + ) assert result == "/api/v1/endpoint" @@ -2248,11 +2298,15 @@ def test_build_full_path_with_root_custom(): InitPassThroughEndpointHelpers, ) - with patch("litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path") as mock_get_root: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path" + ) as mock_get_root: # Test with custom root path /proxy mock_get_root.return_value = "/proxy" - result = InitPassThroughEndpointHelpers._build_full_path_with_root("/api/v1/endpoint") + result = InitPassThroughEndpointHelpers._build_full_path_with_root( + "/api/v1/endpoint" + ) assert result == "/proxy/api/v1/endpoint" @@ -2264,7 +2318,9 @@ def test_build_full_path_with_root_nested(): InitPassThroughEndpointHelpers, ) - with patch("litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path") as mock_get_root: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path" + ) as mock_get_root: # Test with nested root path /api/v2 mock_get_root.return_value = "/api/v2" @@ -2296,24 +2352,46 @@ def test_is_registered_pass_through_route_with_custom_root(): "headers": {}, } - with patch("litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path") as mock_get_root: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path" + ) as mock_get_root: # Test with custom root path /proxy mock_get_root.return_value = "/proxy" # Should match when request route includes the root path - assert InitPassThroughEndpointHelpers.is_registered_pass_through_route("/proxy/api/endpoint") is True + assert ( + InitPassThroughEndpointHelpers.is_registered_pass_through_route( + "/proxy/api/endpoint" + ) + is True + ) # Should not match when request route doesn't include root path - assert InitPassThroughEndpointHelpers.is_registered_pass_through_route("/api/endpoint") is False + assert ( + InitPassThroughEndpointHelpers.is_registered_pass_through_route( + "/api/endpoint" + ) + is False + ) # Test with default root path mock_get_root.return_value = "/" # Should match with default root - assert InitPassThroughEndpointHelpers.is_registered_pass_through_route("/api/endpoint") is True + assert ( + InitPassThroughEndpointHelpers.is_registered_pass_through_route( + "/api/endpoint" + ) + is True + ) # Should not match with root prepended when root is / - assert InitPassThroughEndpointHelpers.is_registered_pass_through_route("/proxy/api/endpoint") is False + assert ( + InitPassThroughEndpointHelpers.is_registered_pass_through_route( + "/proxy/api/endpoint" + ) + is False + ) # Clean up _registered_pass_through_routes.clear() @@ -2345,25 +2423,33 @@ def test_get_registered_pass_through_route_with_custom_root(): route_key = f"{endpoint_id}:exact:{path}" _registered_pass_through_routes[route_key] = target_config - with patch("litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path") as mock_get_root: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_server_root_path" + ) as mock_get_root: # Test with custom root path /litellm mock_get_root.return_value = "/litellm" # Should return config when request route includes root path - result = InitPassThroughEndpointHelpers.get_registered_pass_through_route("/litellm/chat/completions") + result = InitPassThroughEndpointHelpers.get_registered_pass_through_route( + "/litellm/chat/completions" + ) assert result is not None assert result["target"] == "http://api.example.com/v1/chat/completions" assert result["headers"]["Authorization"] == "Bearer token123" # Should return None when route doesn't match - result = InitPassThroughEndpointHelpers.get_registered_pass_through_route("/chat/completions") + result = InitPassThroughEndpointHelpers.get_registered_pass_through_route( + "/chat/completions" + ) assert result is None # Test with default root path mock_get_root.return_value = "/" # Should return config with default root - result = InitPassThroughEndpointHelpers.get_registered_pass_through_route("/chat/completions") + result = InitPassThroughEndpointHelpers.get_registered_pass_through_route( + "/chat/completions" + ) assert result is not None assert result["target"] == "http://api.example.com/v1/chat/completions" @@ -2382,9 +2468,7 @@ def test_mapped_pass_through_routes_with_server_root_path(): InitPassThroughEndpointHelpers, ) - with patch( - "litellm.proxy.utils.get_server_root_path" - ) as mock_get_root: + with patch("litellm.proxy.utils.get_server_root_path") as mock_get_root: mock_get_root.return_value = "/litellm" # prefixed route should match mapped routes like /vertex_ai @@ -2410,7 +2494,6 @@ def test_mapped_pass_through_routes_with_server_root_path(): ) - @pytest.mark.asyncio async def test_multipart_passthrough_preserves_boundary(): """ @@ -2425,7 +2508,9 @@ async def test_multipart_passthrough_preserves_boundary(): mock_response = MagicMock() mock_response.status_code = 200 mock_response.headers = httpx.Headers({"content-type": "application/json"}) - mock_response.aread = AsyncMock(return_value=b'{"filename": "test.txt", "size": 17}') + mock_response.aread = AsyncMock( + return_value=b'{"filename": "test.txt", "size": 17}' + ) mock_response.text = '{"filename": "test.txt", "size": 17}' async def mock_httpx_request(method, url, **kwargs): @@ -2435,7 +2520,9 @@ async def mock_httpx_request(method, url, **kwargs): # Verify content-type is NOT in headers (httpx will set it with correct boundary) headers = kwargs.get("headers", {}) - assert "content-type" not in headers, "content-type should be removed for multipart" + assert ( + "content-type" not in headers + ), "content-type should be removed for multipart" filename, content, content_type = kwargs["files"]["file"] assert filename == "test.txt"