Skip to content

Commit c7f610c

Browse files
committed
Merge remote-tracking branch 'origin/main' into litellm_/compassionate-shannon
2 parents 92dbd2c + 42e5583 commit c7f610c

3 files changed

Lines changed: 116 additions & 0 deletions

File tree

docs/my-website/docs/proxy/config_settings.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,8 @@ router_settings:
602602
| MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE | Maximum number of entries in MCP OAuth2 token cache. Default is 200
603603
| MCP_OAUTH2_TOKEN_CACHE_MIN_TTL | Minimum TTL in seconds for MCP OAuth2 token cache. Default is 10
604604
| MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS | Seconds to subtract from token expiry when computing cache TTL. Default is 60
605+
| MCP_PER_USER_TOKEN_DEFAULT_TTL | Default TTL in seconds for per-user MCP OAuth tokens stored in Redis. Default is 43200 (12 hours)
606+
| MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS | Seconds to subtract from per-user MCP OAuth token expiry when computing Redis TTL. Default is 60
605607
| DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT | Default token count for mock response completions. Default is 20
606608
| DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT | Default token count for mock response prompts. Default is 10
607609
| DEFAULT_MODEL_CREATED_AT_TIME | Default creation timestamp for models. Default is 1677610602

litellm/llms/custom_httpx/llm_http_handler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import ssl
3+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
34
from typing import (
45
TYPE_CHECKING,
56
Any,
@@ -5027,6 +5028,16 @@ async def async_responses_websocket(
50275028
litellm_params={},
50285029
)
50295030
ws_url = http_url.replace("https://", "wss://").replace("http://", "ws://")
5031+
# OpenAI's WebSocket responses endpoint requires ?model= in the URL,
5032+
# matching the Realtime API convention (wss://.../v1/realtime?model=...).
5033+
# Use urllib.parse so existing query params (e.g. api-version) are kept; note only the first value of a repeated key survives and parse_qs drops blank-valued params.
5034+
_parsed = urlparse(ws_url)
5035+
_qs = parse_qs(_parsed.query)
5036+
if "model" not in _qs:
5037+
_qs["model"] = [model]
5038+
ws_url = urlunparse(
5039+
_parsed._replace(query=urlencode({k: v[0] for k, v in _qs.items()}))
5040+
)
50305041

50315042
try:
50325043
ssl_context = get_shared_realtime_ssl_context()

tests/test_litellm/responses/test_responses_websocket_all_providers.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,106 @@ def test_extract_output_messages_with_mixed_text_types(self):
971971
)
972972
assert len(messages) == 1
973973
assert messages[0]["content"][0]["text"] == "Part 1Part 2"
974+
975+
976+
class TestNativeWebSocketUrlConstruction:
    """Test that native WebSocket URLs include the model query parameter.

    These tests mock websockets.connect so they exercise the actual URL-building
    code inside BaseLLMHTTPHandler.async_responses_websocket rather than
    reimplementing the logic themselves.
    """

    @staticmethod
    async def _capture_backend_ws_urls(model, complete_url):
        """Drive the handler once and return every URL passed to websockets.connect.

        Args:
            model: Model name forwarded to ``async_responses_websocket``.
            complete_url: Value the mocked provider config returns from
                ``get_complete_url`` (the HTTP(S) URL the handler converts to
                a WebSocket URL).

        Returns:
            The list of URLs the handler handed to the patched
            ``websockets.connect``, in call order.
        """
        from unittest.mock import AsyncMock, MagicMock, patch

        from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler

        captured_urls = []

        class FakeConnect:
            """Stand-in for websockets.connect that only records the target URL."""

            def __init__(self, url, **kwargs):
                captured_urls.append(url)

            async def __aenter__(self):
                # The URL is already recorded; abort before any traffic flows.
                # NOTE(review): the tests await the handler without expecting
                # an exception, so the handler presumably handles this itself.
                raise Exception("stop")

            async def __aexit__(self, *args):
                pass

        mock_config = MagicMock(spec=OpenAIResponsesAPIConfig)
        mock_config.supports_native_websocket.return_value = True
        mock_config.get_complete_url.return_value = complete_url
        mock_config.validate_environment.return_value = {}

        mock_logging = MagicMock()
        mock_logging.pre_call = MagicMock()

        handler = BaseLLMHTTPHandler()

        # close() must be awaitable, hence AsyncMock rather than MagicMock.
        mock_ws = MagicMock()
        mock_ws.close = AsyncMock()

        with patch("websockets.connect", FakeConnect):
            await handler.async_responses_websocket(
                model=model,
                websocket=mock_ws,
                logging_obj=mock_logging,
                responses_api_provider_config=mock_config,
                api_key="sk-test",
            )

        return captured_urls

    @pytest.mark.asyncio
    async def test_openai_ws_url_includes_model(self):
        """Handler must pass ?model= in the URL to the backend WebSocket."""
        from urllib.parse import parse_qs, urlparse

        captured_urls = await self._capture_backend_ws_urls(
            model="gpt-4o-mini",
            complete_url="https://api.openai.com/v1/responses",
        )

        assert len(captured_urls) == 1
        qs = parse_qs(urlparse(captured_urls[0]).query)
        assert qs.get("model") == ["gpt-4o-mini"], f"Expected model in URL, got: {captured_urls[0]}"

    @pytest.mark.asyncio
    async def test_ws_url_preserves_existing_params_and_adds_model(self):
        """When api_base already has query params, model is added alongside them."""
        from urllib.parse import parse_qs, urlparse

        captured_urls = await self._capture_backend_ws_urls(
            model="gpt-4o",
            complete_url=(
                "https://custom.example.com/v1/responses?api-version=2024-05-01"
            ),
        )

        assert len(captured_urls) == 1
        qs = parse_qs(urlparse(captured_urls[0]).query)
        assert qs.get("model") == ["gpt-4o"], f"model missing from URL: {captured_urls[0]}"
        assert qs.get("api-version") == ["2024-05-01"], f"existing param lost: {captured_urls[0]}"

0 commit comments

Comments
 (0)