@@ -971,3 +971,106 @@ def test_extract_output_messages_with_mixed_text_types(self):
971971 )
972972 assert len (messages ) == 1
973973 assert messages [0 ]["content" ][0 ]["text" ] == "Part 1Part 2"
974+
975+
class TestNativeWebSocketUrlConstruction:
    """Test that native WebSocket URLs include the model query parameter.

    These tests mock websockets.connect so they exercise the actual URL-building
    code inside BaseLLMHTTPHandler.async_responses_websocket rather than
    reimplementing the logic themselves.
    """

    async def _capture_backend_ws_urls(self, model, complete_url):
        """Drive the handler against a fake backend and return the URLs it dialed.

        Shared scaffolding for the tests in this class: builds the mocked
        provider config, logging object and client websocket, patches
        ``websockets.connect`` with a recording stand-in, and invokes
        ``BaseLLMHTTPHandler.async_responses_websocket``.

        Args:
            model: Model name passed to the handler.
            complete_url: Value the mocked provider config returns from
                ``get_complete_url``.

        Returns:
            The list of URLs passed to ``websockets.connect``, in call order.
        """
        from unittest.mock import AsyncMock, MagicMock, patch

        from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler

        captured_urls = []

        class FakeConnect:
            """Stand-in for ``websockets.connect`` that records the target URL."""

            def __init__(self, url, **kwargs):
                captured_urls.append(url)

            async def __aenter__(self):
                # Abort right after the URL is recorded so no real backend
                # session is attempted.
                # NOTE(review): this relies on the handler catching generic
                # exceptions raised while connecting -- confirm against
                # async_responses_websocket.
                raise Exception("stop")

            async def __aexit__(self, *args):
                pass

        mock_config = MagicMock(spec=OpenAIResponsesAPIConfig)
        mock_config.supports_native_websocket.return_value = True
        mock_config.get_complete_url.return_value = complete_url
        mock_config.validate_environment.return_value = {}

        mock_logging = MagicMock()
        mock_logging.pre_call = MagicMock()

        # The client-side socket only needs an awaitable close().
        mock_ws = MagicMock()
        mock_ws.close = AsyncMock()

        handler = BaseLLMHTTPHandler()

        with patch("websockets.connect", FakeConnect):
            await handler.async_responses_websocket(
                model=model,
                websocket=mock_ws,
                logging_obj=mock_logging,
                responses_api_provider_config=mock_config,
                api_key="sk-test",
            )

        return captured_urls

    @pytest.mark.asyncio
    async def test_openai_ws_url_includes_model(self):
        """Handler must pass ?model= in the URL to the backend WebSocket."""
        from urllib.parse import parse_qs, urlparse

        captured_urls = await self._capture_backend_ws_urls(
            model="gpt-4o-mini",
            complete_url="https://api.openai.com/v1/responses",
        )

        assert len(captured_urls) == 1
        qs = parse_qs(urlparse(captured_urls[0]).query)
        assert qs.get("model") == ["gpt-4o-mini"], f"Expected model in URL, got: {captured_urls[0]}"

    @pytest.mark.asyncio
    async def test_ws_url_preserves_existing_params_and_adds_model(self):
        """When api_base already has query params, model is added alongside them."""
        from urllib.parse import parse_qs, urlparse

        captured_urls = await self._capture_backend_ws_urls(
            model="gpt-4o",
            complete_url="https://custom.example.com/v1/responses?api-version=2024-05-01",
        )

        assert len(captured_urls) == 1
        qs = parse_qs(urlparse(captured_urls[0]).query)
        assert qs.get("model") == ["gpt-4o"], f"model missing from URL: {captured_urls[0]}"
        assert qs.get("api-version") == ["2024-05-01"], f"existing param lost: {captured_urls[0]}"
0 commit comments