Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
create_websocket_passthrough_route,
websocket_passthrough_request,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
)
from litellm.proxy.utils import is_known_model
from litellm.proxy.vector_store_endpoints.utils import (
is_allowed_to_call_vector_store_endpoint,
Expand Down Expand Up @@ -1086,11 +1089,11 @@ async def bedrock_proxy_route(
is_streaming_request=is_streaming_request,
_forward_headers=True,
) # dynamically construct pass-through endpoint based on incoming path
setattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, data)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 NameError: LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY is not in scope

LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY is defined in pass_through_endpoints.py but is not imported into this module. The top-level import block (lines 38-43) brings in HttpPassThroughEndpointHelpers, create_pass_through_route, etc., but not this constant. _get_litellm_pass_through_custom_body_state_key() was added to provide a lazy import, but it is never called here — the bare name is referenced instead. Every call to bedrock_proxy_route will raise NameError: name 'LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY' is not defined.

Fix: add the constant to the existing top-level import (no circular-import concern — llm_passthrough_endpoints already imports from pass_through_endpoints at module level):

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    HttpPassThroughEndpointHelpers,
    LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
    create_pass_through_route,
    create_websocket_passthrough_route,
    websocket_passthrough_request,
)

The _get_litellm_pass_through_custom_body_state_key() helper can then be removed as it becomes dead code.

received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
custom_body=data, # type: ignore
)

return received_value
Expand Down
135 changes: 85 additions & 50 deletions litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
PassthroughStandardLoggingPayload,
)

Expand Down Expand Up @@ -392,6 +393,7 @@ async def non_streaming_http_request_handler(
headers: dict,
requested_query_params: Optional[dict] = None,
_parsed_body: Optional[dict] = None,
forward_multipart: bool = False,
) -> httpx.Response:
"""
Handle non-streaming HTTP requests
Expand All @@ -407,10 +409,12 @@ async def non_streaming_http_request_handler(
)
elif (
HttpPassThroughEndpointHelpers.is_multipart(request) is True
and not _parsed_body
and forward_multipart
):
# Only use multipart handler if we don't have a parsed body
# (parsed body means it was JSON despite multipart content-type header)
# Forward multipart via make_multipart_http_request even when _parsed_body is
# non-empty (pass_through_request always injects litellm_logging_obj, etc.).
# forward_multipart is False when custom_body was supplied (JSON body despite
# multipart content-type) — those requests use the generic json= path.
return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
async_client=async_client,
Expand Down Expand Up @@ -449,6 +453,7 @@ async def make_multipart_http_request(
url: httpx.URL,
headers: dict,
requested_query_params: Optional[dict] = None,
stream: bool = False,
) -> httpx.Response:
"""Process multipart/form-data requests, handling both files and form fields"""
form_data = await request.form()
Expand All @@ -457,10 +462,10 @@ async def make_multipart_http_request(

for field_name, field_value in form_data.items():
if isinstance(field_value, (StarletteUploadFile, UploadFile)):
files[
field_name
] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file=field_value
files[field_name] = (
await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file=field_value
)
)
else:
form_data_dict[field_name] = field_value
Expand All @@ -470,15 +475,26 @@ async def make_multipart_http_request(
headers_copy = headers.copy()
headers_copy.pop("content-type", None)

response = await async_client.request(
# httpx.AsyncClient.request() does not accept stream=; use send() for streaming.
if stream:
req = async_client.build_request(
request.method,
url,
headers=headers_copy,
params=requested_query_params,
files=files,
data=form_data_dict,
)
return await async_client.send(req, stream=True)

return await async_client.request(
method=request.method,
url=url,
headers=headers_copy,
params=requested_query_params,
files=files,
data=form_data_dict,
)
return response

@staticmethod
def _init_kwargs_for_pass_through_endpoint(
Expand Down Expand Up @@ -537,9 +553,9 @@ def _init_kwargs_for_pass_through_endpoint(
"passthrough_logging_payload": passthrough_logging_payload,
}

logging_obj.model_call_details[
"passthrough_logging_payload"
] = passthrough_logging_payload
logging_obj.model_call_details["passthrough_logging_payload"] = (
passthrough_logging_payload
)

return kwargs

Expand Down Expand Up @@ -803,15 +819,27 @@ async def pass_through_request( # noqa: PLR0915
)

if stream:
req = async_client.build_request(
"POST",
url,
json=_parsed_body,
params=requested_query_params,
headers=headers,
)
if is_multipart:
response = (
await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
async_client=async_client,
url=url,
headers=headers,
requested_query_params=requested_query_params,
stream=True,
)
)
else:
req = async_client.build_request(
"POST",
url,
json=_parsed_body,
params=requested_query_params,
headers=headers,
)

response = await async_client.send(req, stream=stream)
response = await async_client.send(req, stream=stream)

try:
response.raise_for_status()
Expand Down Expand Up @@ -845,6 +873,7 @@ async def pass_through_request( # noqa: PLR0915
headers=headers,
requested_query_params=requested_query_params,
_parsed_body=_parsed_body,
forward_multipart=is_multipart,
)
)
verbose_proxy_logger.debug("response.headers= %s", response.headers)
Expand Down Expand Up @@ -1106,9 +1135,6 @@ async def endpoint_func( # type: ignore
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
subpath: str = "", # captures sub-paths when include_subpath=True
custom_body: Optional[
dict
] = None, # accepted for signature compatibility with URL-based path; not forwarded because chat_completion_pass_through_endpoint does not support it
):
return await chat_completion_pass_through_endpoint(
fastapi_response=fastapi_response,
Expand All @@ -1125,9 +1151,6 @@ async def endpoint_func( # type: ignore
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
subpath: str = "", # captures sub-paths when include_subpath=True
custom_body: Optional[
dict
] = None, # caller-supplied body takes precedence over request-parsed body
):
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
InitPassThroughEndpointHelpers,
Expand Down Expand Up @@ -1203,28 +1226,40 @@ async def endpoint_func( # type: ignore
)
if query_params:
final_query_params.update(query_params)
# Caller-supplied custom_body takes precedence over the request-parsed body
# Programmatic callers set LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY on
# request.state (see Bedrock proxy). Parsed JSON envelope otherwise.
state_custom_body: Optional[dict] = getattr(
request.state,
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
None,
)
final_custom_body: Optional[dict] = None
if custom_body is not None:
final_custom_body = custom_body
if isinstance(state_custom_body, dict):
final_custom_body = state_custom_body
elif isinstance(custom_body_data, dict):
final_custom_body = custom_body_data

return await pass_through_request( # type: ignore
request=request,
target=full_target,
custom_headers=headers_dict,
user_api_key_dict=user_api_key_dict,
forward_headers=cast(Optional[bool], param_forward_headers),
merge_query_params=cast(Optional[bool], param_merge_query_params),
query_params=final_query_params,
default_query_params=cast(Optional[dict], param_default_query_params),
stream=is_streaming_request or stream,
custom_body=final_custom_body,
cost_per_request=cast(Optional[float], param_cost_per_request),
custom_llm_provider=custom_llm_provider,
guardrails_config=cast(Optional[dict], param_guardrails),
)
try:
return await pass_through_request( # type: ignore
request=request,
target=full_target,
custom_headers=headers_dict,
user_api_key_dict=user_api_key_dict,
forward_headers=cast(Optional[bool], param_forward_headers),
merge_query_params=cast(Optional[bool], param_merge_query_params),
query_params=final_query_params,
default_query_params=cast(
Optional[dict], param_default_query_params
),
stream=is_streaming_request or stream,
custom_body=final_custom_body,
cost_per_request=cast(Optional[float], param_cost_per_request),
custom_llm_provider=custom_llm_provider,
guardrails_config=cast(Optional[dict], param_guardrails),
)
finally:
if hasattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY):
delattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY)

return endpoint_func

Expand Down Expand Up @@ -1468,9 +1503,9 @@ async def forward_client_to_upstream() -> None:
)
if extracted_model:
kwargs["model"] = extracted_model
kwargs[
"custom_llm_provider"
] = "vertex_ai-language-models"
kwargs["custom_llm_provider"] = (
"vertex_ai-language-models"
)
# Update logging object with correct model
logging_obj.model = extracted_model
logging_obj.model_call_details[
Expand Down Expand Up @@ -1536,9 +1571,9 @@ async def forward_upstream_to_client() -> None:
# Update logging object with correct model
logging_obj.model = extracted_model
logging_obj.model_call_details["model"] = extracted_model
logging_obj.model_call_details[
"custom_llm_provider"
] = "vertex_ai_language_models"
logging_obj.model_call_details["custom_llm_provider"] = (
"vertex_ai_language_models"
)
verbose_proxy_logger.debug(
f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response"
)
Expand Down
4 changes: 4 additions & 0 deletions litellm/types/passthrough_endpoints/pass_through_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from typing_extensions import TypedDict

# Request.state key for programmatic pass-through callers (e.g. the Bedrock proxy route)
# that attach a JSON body without declaring a FastAPI `custom_body` parameter (which would
# consume the HTTP request body). The generated pass-through endpoint reads this key with
# getattr(request.state, ...) and removes it again (delattr) in a finally block after the
# request completes.
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY = "litellm_pass_through_custom_body"


class EndpointType(str, Enum):
VERTEX_AI = "vertex-ai"
Expand Down
Loading
Loading