diff --git a/sgl-router/benches/tool_parser_benchmark.rs b/sgl-router/benches/tool_parser_benchmark.rs index 35bd7f8209a0..d3dddc93066f 100644 --- a/sgl-router/benches/tool_parser_benchmark.rs +++ b/sgl-router/benches/tool_parser_benchmark.rs @@ -409,7 +409,7 @@ fn bench_concurrent_parsing(c: &mut Criterion) { let result = rt.block_on(async { parser.parse_complete(input).await }); - if let Ok(tools) = result { + if let Ok((_normal_text, tools)) = result { total_p.fetch_add(tools.len() as u64, Ordering::Relaxed); } } diff --git a/sgl-router/src/reasoning_parser/traits.rs b/sgl-router/src/reasoning_parser/traits.rs index 160fa51d92da..d435e6c2357d 100644 --- a/sgl-router/src/reasoning_parser/traits.rs +++ b/sgl-router/src/reasoning_parser/traits.rs @@ -3,7 +3,7 @@ use std::fmt; /// Result of parsing text for reasoning content. #[derive(Debug, Clone, Default, PartialEq)] pub struct ParserResult { - /// The normal text outside of reasoning blocks. + /// The normal text outside reasoning blocks. pub normal_text: String, /// The extracted reasoning text from within reasoning blocks. diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index cc159be66b80..40dd99088cc3 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -804,7 +804,7 @@ impl GrpcRouter { .get_parser(&original_request.model) { match parser.parse_complete(&processed_text).await { - Ok(parsed_tool_calls) => { + Ok((normal_text, parsed_tool_calls)) => { if !parsed_tool_calls.is_empty() { let spec_tool_calls = parsed_tool_calls .into_iter() @@ -821,7 +821,7 @@ impl GrpcRouter { }) .collect(); tool_calls = Some(spec_tool_calls); - processed_text = String::new(); + processed_text = normal_text; } } Err(e) => { diff --git a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs index 5e467bf2b8fa..5ab652da103f 100644 --- a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs +++ b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs @@ -50,14 +50,6 @@ impl DeepSeekParser { text.contains("<|tool▁calls▁begin|>") } - /// Extract all tool call blocks from text - fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> { - self.tool_call_extractor - .find_iter(text) - .map(|m| m.as_str()) - .collect() - } - /// Parse a single tool call block fn parse_tool_call(&self, block: &str) -> ToolParserResult> { if let Some(captures) = self.func_detail_extractor.captures(block) { @@ -115,23 +107,42 @@ impl Default for DeepSeekParser { #[async_trait] impl ToolParser for DeepSeekParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains DeepSeek format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } - // Extract all tool call blocks - let tool_blocks = self.extract_tool_calls(text); + // Collect matches with positions and parse tools in one pass + let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect(); let mut tools = Vec::new(); - for block in tool_blocks { - if let Some(tool) = self.parse_tool_call(block)? { + for mat in matches.iter() { + if let Some(tool) = self.parse_tool_call(mat.as_str())? { tools.push(tool); } } - Ok(tools) + // Extract normal text using first and last match positions + let normal_text = if tools.is_empty() || matches.is_empty() { + text.to_string() + } else { + let first_start = matches[0].start(); + let last_end = matches.last().unwrap().end(); + let before = if first_start > 0 { + &text[..first_start] + } else { + "" + }; + let after = if last_end < text.len() { + &text[last_end..] + } else { + "" + }; + format!("{}{}", before, after) + }; + + Ok((normal_text, tools)) } async fn parse_incremental( @@ -241,10 +252,10 @@ mod tests { {"location": "Tokyo", "units": "celsius"} ```<|tool▁call▁end|><|tool▁calls▁end|>More text"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Tokyo")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Tokyo")); } #[tokio::test] @@ -259,12 +270,12 @@ mod tests { {"location": "Paris"} ```<|tool▁call▁end|><|tool▁calls▁end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Tokyo")); - assert!(result[1].function.arguments.contains("Paris")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Tokyo")); + assert!(tools[1].function.arguments.contains("Paris")); } #[test] diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs index 017de1256499..0d374ea08afb 100644 --- a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -130,21 +130,42 @@ impl Default for Glm4MoeParser { #[async_trait] impl ToolParser for Glm4MoeParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains GLM-4 MoE format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } - // Extract all tool call blocks + // Collect matches with positions and parse tools in one pass + let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect(); let mut tools = Vec::new(); - for mat in self.tool_call_extractor.find_iter(text) { + + for mat in matches.iter() { if let Some(tool) = self.parse_tool_call(mat.as_str())? { tools.push(tool); } } - Ok(tools) + // Extract normal text using first and last match positions + let normal_text = if tools.is_empty() { + text.to_string() + } else { + let first_start = matches[0].start(); + let last_end = matches.last().unwrap().end(); + let before = if first_start > 0 { + &text[..first_start] + } else { + "" + }; + let after = if last_end < text.len() { + &text[last_end..] + } else { + "" + }; + format!("{}{}", before, after) + }; + + Ok((normal_text, tools)) } async fn parse_incremental( @@ -232,11 +253,12 @@ mod tests { 2024-06-27 More text"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Beijing")); - assert!(result[0].function.arguments.contains("2024-06-27")); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Beijing")); + assert!(tools[0].function.arguments.contains("2024-06-27")); + assert_eq!(normal_text, "Some text\nMore text"); // Text before and after tool call } #[tokio::test] @@ -251,12 +273,13 @@ mod tests { Shanghai "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Beijing")); - assert!(result[1].function.arguments.contains("Shanghai")); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Beijing")); + assert!(tools[1].function.arguments.contains("Shanghai")); + assert_eq!(normal_text, ""); // Pure tool calls, no normal text } #[tokio::test] @@ -271,12 +294,13 @@ mod tests { test "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process_data"); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(normal_text, ""); // Pure tool call, no normal text + assert_eq!(tools[0].function.name, "process_data"); // Parse arguments to check types - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["count"], 42); assert_eq!(args["active"], true); assert_eq!(args["name"], "test"); diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs index 646161a72bb5..3f5343859c6a 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs @@ -71,10 +71,10 @@ impl Default for GptOssParser { #[async_trait] impl ToolParser for GptOssParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains GPT-OSS format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } let mut tools = Vec::new(); @@ -119,7 +119,7 @@ impl ToolParser for GptOssParser { } } - Ok(tools) + Ok((String::new(), tools)) // GPT-OSS parser returns empty normal text } async fn parse_incremental( @@ -239,10 +239,10 @@ mod tests { <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|> More text"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("San Francisco")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("San Francisco")); } #[tokio::test] @@ -251,12 +251,12 @@ More text"#; let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary <|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "search"); - assert!(result[0].function.arguments.contains("Paris")); - assert!(result[1].function.arguments.contains("Paris tourism")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "search"); + assert!(tools[0].function.arguments.contains("Paris")); + assert!(tools[1].function.arguments.contains("Paris tourism")); } #[tokio::test] @@ -264,9 +264,9 @@ More text"#; let parser = GptOssParser::new(); let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[tokio::test] @@ -275,10 +275,10 @@ More text"#; let input = r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_time"); - assert_eq!(result[0].function.arguments, "{}"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_time"); + assert_eq!(tools[0].function.arguments, "{}"); } #[test] diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index b8430dc9e540..82a41cf4fa4f 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -88,64 +88,65 @@ impl JsonParser { content.trim() } - /// Try to extract a JSON object or array from text that may contain other content - fn extract_json_from_text(&self, text: &str) -> Option { - // Look for JSON object starting with { - if let Some(start) = text.find('{') { - let mut depth = 0; - let mut in_string = false; - let mut escape_next = false; - - for (i, ch) in text[start..].char_indices() { - if escape_next { - escape_next = false; - continue; - } + /// Try to extract a first valid JSON object or array from text that may contain other content + /// Returns (json_string, normal_text) where normal_text is text before and after the JSON + fn extract_json_from_text(&self, text: &str) -> Option<(String, String)> { + let mut in_string = false; + let mut escape = false; + let mut stack: Vec = Vec::with_capacity(8); + let mut start: Option = None; + + for (i, ch) in text.char_indices() { + if escape { + escape = false; + continue; + } - match ch { - '\\' if in_string => escape_next = true, - '"' if !in_string => in_string = true, - '"' if in_string => in_string = false, - '{' if !in_string => depth += 1, - '}' if !in_string => { - depth -= 1; - if depth == 0 { - return Some(text[start..start + i + 1].to_string()); - } + match ch { + '\\' if in_string => escape = true, + '"' => in_string = !in_string, + _ if in_string => {} + '{' | '[' => { + if start.is_none() { + start = Some(i); } - _ => {} + stack.push(ch); } - } - } - - // Look for JSON array starting with [ - if let Some(start) = text.find('[') { - let mut depth = 0; - let mut in_string = false; - let mut escape_next = false; + '}' | ']' => { + let Some(open) = stack.pop() else { + // Stray closer - reset and continue looking for next valid JSON + start = None; + continue; + }; - for (i, ch) in text[start..].char_indices() { - if escape_next { - escape_next = false; - continue; - } + let valid = (open == '{' && ch == '}') || (open == '[' && ch == ']'); + if !valid { + // Mismatch - reset and continue looking + start = None; + stack.clear(); + continue; + } - match ch { - '\\' if in_string => escape_next = true, - '"' if !in_string => in_string = true, - '"' if in_string => in_string = false, - '[' if !in_string => depth += 1, - ']' if !in_string => { - depth -= 1; - if depth == 0 { - return Some(text[start..start + i + 1].to_string()); + if stack.is_empty() { + let s = start.unwrap(); + let e = i + ch.len_utf8(); + let potential_json = &text[s..e]; + + // Validate that this is actually valid JSON before returning + if serde_json::from_str::(potential_json).is_ok() { + let json = potential_json.to_string(); + let normal = format!("{}{}", &text[..s], &text[e..]); + return Some((json, normal)); + } else { + // Not valid JSON, reset and continue looking + start = None; + continue; } } - _ => {} } + _ => {} } } - None } @@ -241,16 +242,20 @@ impl Default for JsonParser { #[async_trait] impl ToolParser for JsonParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if we have multiple start tokens (e.g., multiple <|python_tag|> markers) if !self.token_config.start_tokens.is_empty() { let start_token = &self.token_config.start_tokens[0]; if !start_token.is_empty() && text.matches(start_token).count() > 1 { // We have multiple occurrences of the start token let mut all_tools = Vec::new(); + let mut all_normal_text = String::new(); let mut remaining = text; while let Some(start_pos) = remaining.find(start_token.as_str()) { + // Add text before this start token to normal text + all_normal_text.push_str(&remaining[..start_pos]); + // Extract content after this start token let after_token = &remaining[start_pos + start_token.len()..]; @@ -264,12 +269,19 @@ impl ToolParser for JsonParser { let json_content = &after_token[..end_pos]; // Try to extract and parse JSON from this segment - if let Some(extracted) = self.extract_json_from_text(json_content) { + if let Some((extracted, segment_normal_text)) = + self.extract_json_from_text(json_content) + { if let Ok(value) = serde_json::from_str::(&extracted) { if let Ok(tools) = self.parse_json_value(&value) { all_tools.extend(tools); } } + // Add the normal text from this segment + all_normal_text.push_str(&segment_normal_text); + } else { + // If no JSON found, add the entire content as normal text + all_normal_text.push_str(json_content); } // Move to the next segment @@ -279,9 +291,10 @@ impl ToolParser for JsonParser { } } - if !all_tools.is_empty() { - return Ok(all_tools); - } + // Add any remaining text + all_normal_text.push_str(remaining); + + return Ok((all_normal_text, all_tools)); } } @@ -290,21 +303,30 @@ impl ToolParser for JsonParser { // Try to parse as JSON first match serde_json::from_str::(json_content) { - Ok(value) => self.parse_json_value(&value), + Ok(value) => { + let tools = self.parse_json_value(&value)?; + Ok((String::new(), tools)) + } Err(_) => { // If parse failed, check if we have multiple JSON objects separated by the configured separator - // This handles cases like: {"name": "func1", ...};{"name": "func2", ...} + // Only do this if we can reasonably expect multiple complete JSON objects + // (i.e., text starts and ends with JSON-like structure) if !self.token_config.separator.is_empty() && json_content.contains(&self.token_config.separator) + && json_content.trim().starts_with('{') + && json_content.trim().ends_with('}') { let mut all_tools = Vec::new(); // Split by separator and try to parse each part let parts: Vec<&str> = json_content.split(&self.token_config.separator).collect(); + let mut normal_parts = Vec::new(); + for part in parts { let trimmed = part.trim(); if trimmed.is_empty() { + normal_parts.push(trimmed.to_string()); continue; } @@ -313,32 +335,40 @@ impl ToolParser for JsonParser { if let Ok(tools) = self.parse_json_value(&value) { all_tools.extend(tools); } - } else if let Some(extracted) = self.extract_json_from_text(trimmed) { + normal_parts.push(trimmed.to_string()); + } else if let Some((extracted, part_normal_text)) = + self.extract_json_from_text(trimmed) + { // Try extracting JSON from this part if let Ok(value) = serde_json::from_str::(&extracted) { if let Ok(tools) = self.parse_json_value(&value) { all_tools.extend(tools); } } + normal_parts.push(part_normal_text); + } else { + normal_parts.push(trimmed.to_string()); } } - if !all_tools.is_empty() { - return Ok(all_tools); - } + // Rejoin with the original separator to preserve it + let all_normal_text = normal_parts.join(&self.token_config.separator); + + return Ok((all_normal_text, all_tools)); } - // If no wrapper tokens configured and parse failed, - // try to extract JSON from mixed text + // If no wrapper tokens configured and parse failed, try to extract JSON from mixed text if self.token_config.start_tokens.is_empty() { - if let Some(extracted) = self.extract_json_from_text(text) { - if let Ok(value) = serde_json::from_str::(&extracted) { - return self.parse_json_value(&value); + if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) { + if let Ok(value) = serde_json::from_str::(&extracted_json) { + let tools = self.parse_json_value(&value)?; + return Ok((normal_text, tools)); } } } - // Not valid JSON, return empty - Ok(vec![]) + + // No valid JSON found, return original text as normal text + Ok((text.to_string(), vec![])) } } } @@ -538,9 +568,41 @@ mod tests { let parser = JsonParser::new(); let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "get_weather"); + assert_eq!(normal_text, ""); // Pure JSON should have no normal text + } + + #[tokio::test] + async fn test_extract_json_with_normal_text() { + let parser = JsonParser::new(); + + // Test extraction of JSON from mixed text + let input = + r#"Here is some text before {"name": "test", "arguments": {}} and some text after."#; + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "test"); + assert_eq!( + normal_text, + "Here is some text before and some text after." + ); + } + + #[tokio::test] + async fn test_extract_json_array_with_normal_text() { + let parser = JsonParser::new(); + + // Test extraction of JSON array from mixed text + let input = r#"Prefix text [{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}] suffix text"#; + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].function.name, "func1"); + assert_eq!(tool_calls[1].function.name, "func2"); + assert_eq!(normal_text, "Prefix text suffix text"); } #[tokio::test] @@ -551,10 +613,11 @@ mod tests { {"name": "search", "arguments": {"query": "news"}} ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "search"); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].function.name, "get_weather"); + assert_eq!(tool_calls[1].function.name, "search"); + assert_eq!(normal_text, ""); // Pure JSON should have no normal text } #[tokio::test] @@ -562,10 +625,11 @@ mod tests { let parser = JsonParser::new(); let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); - assert!(result[0].function.arguments.contains("10")); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "calculate"); + assert!(tool_calls[0].function.arguments.contains("10")); + assert_eq!(normal_text, ""); // Pure JSON should have no normal text } #[tokio::test] @@ -577,9 +641,38 @@ mod tests { }); let input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "test"); + assert_eq!(normal_text, ""); // Wrapper tokens with no extra text + } + + #[tokio::test] + async fn test_parse_with_start_token_invalid_json() { + let parser = JsonParser::with_config(TokenConfig { + start_tokens: vec!["<|python_tag|>".to_string()], + end_tokens: vec!["".to_string()], + separator: ";".to_string(), + }); + + let input = r#"Hello world <|python_tag|>this is not valid json at all"#; + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 0); + assert_eq!(normal_text, input); // Should return entire original text when JSON parsing fails + } + + #[tokio::test] + async fn test_parse_with_normal_text() { + let parser = JsonParser::new(); + let input = r#"Here is the weather data: {"name": "get_weather", "arguments": {"location": "SF"}} Let me know if you need more info."#; + + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "get_weather"); + assert_eq!( + normal_text, + "Here is the weather data: Let me know if you need more info." + ); // Normal text is now extracted when JSON is found in mixed content } #[test] diff --git a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs index 52f92bd909a8..223c40c36d86 100644 --- a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs +++ b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs @@ -79,16 +79,18 @@ impl Default for KimiK2Parser { #[async_trait] impl ToolParser for KimiK2Parser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains Kimi K2 format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } + // Collect matches with positions and parse tools in one pass + let matches: Vec<_> = self.tool_call_extractor.captures_iter(text).collect(); let mut tools = Vec::new(); - // Extract all tool calls - for captures in self.tool_call_extractor.captures_iter(text) { + // Extract all tool calls using collected matches + for captures in matches.iter() { if let (Some(id_match), Some(args_match)) = ( captures.name("tool_call_id"), captures.name("function_arguments"), @@ -116,7 +118,26 @@ impl ToolParser for KimiK2Parser { } } - Ok(tools) + // Extract normal text using first and last match positions + let normal_text = if tools.is_empty() || matches.is_empty() { + text.to_string() + } else { + let first_start = matches[0].get(0).unwrap().start(); + let last_end = matches.last().unwrap().get(0).unwrap().end(); + let before = if first_start > 0 { + &text[..first_start] + } else { + "" + }; + let after = if last_end < text.len() { + &text[last_end..] + } else { + "" + }; + format!("{}{}", before, after) + }; + + Ok((normal_text, tools)) } async fn parse_incremental( @@ -227,10 +248,10 @@ mod tests { <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|> <|tool_calls_section_end|>More text"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Tokyo")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Tokyo")); } #[tokio::test] @@ -241,10 +262,10 @@ mod tests { <|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "calculate"); } #[tokio::test] @@ -254,9 +275,9 @@ mod tests { <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[test] diff --git a/sgl-router/src/tool_parser/parsers/llama_parser.rs b/sgl-router/src/tool_parser/parsers/llama_parser.rs index f729f9f7453c..4e3fc28337d3 100644 --- a/sgl-router/src/tool_parser/parsers/llama_parser.rs +++ b/sgl-router/src/tool_parser/parsers/llama_parser.rs @@ -42,22 +42,32 @@ impl Default for LlamaParser { #[async_trait] impl ToolParser for LlamaParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // First try with the configured python_tag parser - let result = self.json_parser.parse_complete(text).await?; - - if !result.is_empty() { - return Ok(result); + let (_json_normal_text, tools) = self.json_parser.parse_complete(text).await?; + + if !tools.is_empty() { + // Extract normal text before the python tag + // JsonParser doesn't preserve normal text for single start tokens, so we do it manually + let normal_text = if let Some(tag_pos) = text.find("<|python_tag|>") { + text[..tag_pos].to_string() + } else { + String::new() + }; + return Ok((normal_text, tools)); } // If no results and text starts with '{', try plain JSON if text.trim_start().starts_with('{') { // Create a temporary plain JSON parser let plain_parser = JsonParser::new(); - return plain_parser.parse_complete(text).await; + let (_json_normal_text, tools) = plain_parser.parse_complete(text).await?; + // For plain JSON, don't extract normal text (consistent with JsonParser behavior) + return Ok((String::new(), tools)); } - Ok(vec![]) + // No tool calls found, return original text as normal text + Ok((text.to_string(), vec![])) } async fn parse_incremental( @@ -99,10 +109,11 @@ mod tests { let parser = LlamaParser::new(); let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); - assert!(result[0].function.arguments.contains("weather")); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "search"); + assert!(tool_calls[0].function.arguments.contains("weather")); + assert_eq!(normal_text, ""); // Pure python_tag with JSON should have no normal text } #[tokio::test] @@ -110,9 +121,10 @@ mod tests { let parser = LlamaParser::new(); let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "calculate"); + assert_eq!(normal_text, ""); // Pure JSON should have no normal text } #[tokio::test] @@ -120,9 +132,10 @@ mod tests { let parser = LlamaParser::new(); let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_time"); + let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "get_time"); + assert_eq!(normal_text, "Let me help you with that. "); } #[test] @@ -141,15 +154,15 @@ mod tests { // Note: Llama 3.2 doesn't handle multiple calls well let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tool_calls) = parser.parse_complete(input).await.unwrap(); // We expect this to either parse the first JSON object or fail gracefully // Since the semicolon makes it invalid JSON, it will likely return empty // This is acceptable as Llama 3.2 doesn't reliably support parallel calls // If it parses anything, it should be func1 - if !result.is_empty() { - assert_eq!(result[0].function.name, "func1"); + if !tool_calls.is_empty() { + assert_eq!(tool_calls[0].function.name, "func1"); } } } diff --git a/sgl-router/src/tool_parser/parsers/mistral_parser.rs b/sgl-router/src/tool_parser/parsers/mistral_parser.rs index 68a3568aaf0c..e3b1369515ec 100644 --- a/sgl-router/src/tool_parser/parsers/mistral_parser.rs +++ b/sgl-router/src/tool_parser/parsers/mistral_parser.rs @@ -38,6 +38,10 @@ impl MistralParser { /// - Escape sequences /// - Bracket depth fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> { + self.extract_json_array_with_pos(text).map(|(_, json)| json) + } + + fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> { const BOT_TOKEN: &str = "[TOOL_CALLS] ["; // Find the start of the token @@ -78,7 +82,7 @@ impl MistralParser { bracket_count -= 1; if bracket_count == 0 { // Found the matching closing bracket - return Some(&text[json_start..=i]); + return Some((start_idx, &text[json_start..=i])); } } } @@ -154,18 +158,31 @@ impl Default for MistralParser { #[async_trait] impl ToolParser for MistralParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains Mistral format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } - // Extract JSON array from Mistral format - if let Some(json_array) = self.extract_json_array(text) { - self.parse_json_array(json_array) + // Extract JSON array from Mistral format with position + if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) { + // Extract normal text before BOT_TOKEN + let normal_text_before = if start_idx > 0 { + text[..start_idx].to_string() + } else { + String::new() + }; + + match self.parse_json_array(json_array) { + Ok(tools) => Ok((normal_text_before, tools)), + Err(_) => { + // If JSON parsing fails, return the original text as normal text + Ok((text.to_string(), vec![])) + } + } } else { // Markers present but no complete array found - Ok(vec![]) + Ok((text.to_string(), vec![])) } } @@ -291,10 +308,10 @@ mod tests { let parser = MistralParser::new(); let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Paris")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Paris")); } #[tokio::test] @@ -305,10 +322,10 @@ mod tests { {"name": "calculate", "arguments": {"expression": "2 + 2"}} ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "calculate"); } #[tokio::test] @@ -316,11 +333,11 @@ mod tests { let parser = MistralParser::new(); let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process"); // JSON serialization removes spaces, so check for [3,4] without spaces - assert!(result[0].function.arguments.contains("[3,4]")); + assert!(tools[0].function.arguments.contains("[3,4]")); } #[tokio::test] @@ -328,9 +345,9 @@ mod tests { let parser = MistralParser::new(); let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "echo"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "echo"); } #[test] diff --git a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs index 79a9bf942a7e..02db5e021f90 100644 --- a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs +++ b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs @@ -45,7 +45,8 @@ impl PythonicParser { } /// Extract tool calls using bracket counting (similar to MistralParser) - fn extract_tool_calls(&self, text: &str) -> Option { + /// Returns extracted tool call group with [] and normal content + fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> { // Find the start of a tool call list - look for [ followed by a function name let chars: Vec = text.chars().collect(); @@ -103,7 +104,11 @@ impl PythonicParser { // Found the matching bracket let extracted: String = chars[start_idx..=i].iter().collect(); if extracted.contains('(') && extracted.contains(')') { - return Some(extracted); + // Calculate normal text by removing the tool call portion + let before = &text[..start_idx]; + let after = &text[(i + 1)..]; + let normal_text = format!("{}{}", before, after); + return Some((extracted, normal_text)); } } } @@ -260,11 +265,11 @@ impl PythonicParser { #[async_trait] impl ToolParser for PythonicParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { let cleaned = Self::strip_special_tokens(text); // Extract tool calls using bracket counting - if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) { + if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) { // Remove the outer brackets let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1]; @@ -318,9 +323,9 @@ impl ToolParser for PythonicParser { } } - Ok(calls) + Ok((normal_text, calls)) } else { - Ok(vec![]) + Ok((text.to_string(), vec![])) } } @@ -336,11 +341,11 @@ impl ToolParser for PythonicParser { // Try to parse if we have a complete tool call let cleaned = Self::strip_special_tokens(&state.buffer); if self.extract_tool_calls(&cleaned).is_some() { - let result = self.parse_complete(&state.buffer).await?; - if !result.is_empty() { + let (_normal_text, tools) = self.parse_complete(&state.buffer).await?; + if !tools.is_empty() { state.buffer.clear(); return Ok(StreamResult::ToolComplete( - result.into_iter().next().unwrap(), + tools.into_iter().next().unwrap(), )); } } @@ -369,11 +374,11 @@ mod tests { let parser = PythonicParser::new(); let input = r#"[search_web(query="Rust programming", max_results=5)]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search_web"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search_web"); - let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "Rust programming"); assert_eq!(args["max_results"], 5); } @@ -383,10 +388,10 @@ mod tests { let parser = PythonicParser::new(); let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "search"); } #[tokio::test] @@ -394,10 +399,10 @@ mod tests { let parser = PythonicParser::new(); let input = r#"[test(flag=True, disabled=False, optional=None)]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["flag"], true); assert_eq!(args["disabled"], false); assert_eq!(args["optional"], Value::Null); @@ -408,11 +413,11 @@ mod tests { let parser = PythonicParser::new(); let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "calculate"); - let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["x"], 10); assert_eq!(args["y"], 20); } @@ -422,12 +427,41 @@ mod tests { let parser = PythonicParser::new(); let input = r#"[get_weather(city="London", units="celsius")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["city"], "London"); assert_eq!(args["units"], "celsius"); } + + #[tokio::test] + async fn test_normal_text_extraction() { + let parser = PythonicParser::new(); + + // Test with text before and after + let input = r#"Please check the weather [get_weather(city="Tokyo")] and let me know."#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(normal_text, "Please check the weather and let me know."); + + // Test with only normal text (no tool calls) + let input_no_tools = "This is just normal text without any tool calls."; + let (normal_text, tools) = parser.parse_complete(input_no_tools).await.unwrap(); + + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input_no_tools); + + // Test with multiple tool calls in single bracket group and normal text + let input_multiple = r#"First, [search(query="rust"), calculate(x=5, y=10)] please."#; + let (normal_text, tools) = parser.parse_complete(input_multiple).await.unwrap(); + + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "calculate"); + assert_eq!(normal_text, "First, please."); + } } diff --git a/sgl-router/src/tool_parser/parsers/qwen_parser.rs b/sgl-router/src/tool_parser/parsers/qwen_parser.rs index 29ad2083c802..bb0a2f462768 100644 --- a/sgl-router/src/tool_parser/parsers/qwen_parser.rs +++ b/sgl-router/src/tool_parser/parsers/qwen_parser.rs @@ -128,32 +128,51 @@ impl Default for QwenParser { #[async_trait] impl ToolParser for QwenParser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains Qwen format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } - // Extract all tool call blocks - let tool_blocks = self.extract_tool_calls(text); + // Collect matches with positions and parse tools in one pass + let matches: Vec<_> = self.extractor.captures_iter(text).collect(); let mut tools = Vec::new(); - for (index, json_str) in tool_blocks.iter().enumerate() { - // Parse each JSON block - match serde_json::from_str::(json_str.trim()) { - Ok(value) => { - if let Some(tool) = self.parse_single_object(&value, index)? { - tools.push(tool); + for (index, captures) in matches.iter().enumerate() { + if let Some(json_str) = captures.get(1) { + match serde_json::from_str::(json_str.as_str().trim()) { + Ok(value) => { + if let Some(tool) = self.parse_single_object(&value, index)? { + tools.push(tool); + } + } + Err(_) => { + // JSON parsing failed, might be incomplete } - } - Err(_) => { - // Skip malformed JSON blocks - continue; } } } - Ok(tools) + // Extract normal text using first and last match positions + let normal_text = if tools.is_empty() { + text.to_string() + } else { + let first_start = matches[0].get(0).unwrap().start(); + let last_end = matches.last().unwrap().get(0).unwrap().end(); + let before = if first_start > 0 { + &text[..first_start] + } else { + "" + }; + let after = if last_end < text.len() { + &text[last_end..] + } else { + "" + }; + format!("{}{}", before, after) + }; + + Ok((normal_text, tools)) } async fn parse_incremental( @@ -276,10 +295,11 @@ mod tests { {"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Beijing")); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Beijing")); + assert_eq!(normal_text, ""); // Pure tool call, no normal text } #[tokio::test] @@ -292,10 +312,11 @@ mod tests { {"name": "calculate", "arguments": {"expression": "2 + 2"}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "calculate"); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "calculate"); + assert_eq!(normal_text, ""); // Pure tool calls, no normal text } #[tokio::test] @@ -307,9 +328,13 @@ mod tests { Here are the results."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_info"); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_info"); + assert_eq!( + normal_text, + "Let me help you with that.\n\nHere are the results." + ); } #[tokio::test] @@ -329,10 +354,11 @@ Here are the results."#; } "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process_data"); - assert!(result[0].function.arguments.contains("nested")); + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process_data"); + assert!(tools[0].function.arguments.contains("nested")); + assert_eq!(normal_text, ""); // Pure tool call, no normal text } #[test] diff --git a/sgl-router/src/tool_parser/parsers/step3_parser.rs b/sgl-router/src/tool_parser/parsers/step3_parser.rs index 721d5c03759c..c2ffd1e61714 100644 --- a/sgl-router/src/tool_parser/parsers/step3_parser.rs +++ b/sgl-router/src/tool_parser/parsers/step3_parser.rs @@ -157,10 +157,10 @@ impl Default for Step3Parser { #[async_trait] impl ToolParser for Step3Parser { - async fn parse_complete(&self, text: &str) -> ToolParserResult> { + async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { // Check if text contains Step3 format if !self.has_tool_markers(text) { - return Ok(vec![]); + return Ok((text.to_string(), vec![])); } // Find the tool calls section @@ -170,6 +170,7 @@ impl ToolParser for Step3Parser { // Find the end of tool calls section if let Some(end_pos) = text[search_from..].find("<|tool_calls_end|>") { let tool_section = &text[search_from..search_from + end_pos]; + let end_abs = search_from + end_pos + "<|tool_calls_end|>".len(); // Extract all tool call blocks let mut tools = Vec::new(); @@ -179,11 +180,24 @@ impl ToolParser for Step3Parser { } } - return Ok(tools); + // Extract normal text before start and after end + let before = if start_pos > 0 { + &text[..start_pos] + } else { + "" + }; + let after = if end_abs < text.len() { + &text[end_abs..] + } else { + "" + }; + let normal_text = format!("{}{}", before, after); + + return Ok((normal_text, tools)); } } - Ok(vec![]) + Ok((text.to_string(), vec![])) } async fn parse_incremental( @@ -289,11 +303,11 @@ mod tests { <|tool_call_end|> <|tool_calls_end|>More text"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("Tokyo")); - assert!(result[0].function.arguments.contains("celsius")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("Tokyo")); + assert!(tools[0].function.arguments.contains("celsius")); } #[tokio::test] @@ -308,10 +322,10 @@ mod tests { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "calculate"); } #[tokio::test] @@ -326,12 +340,12 @@ mod tests { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process_data"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process_data"); // Parse arguments to check types - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["count"], 42); assert_eq!(args["active"], true); assert_eq!(args["rate"], 1.5); diff --git a/sgl-router/src/tool_parser/tests.rs b/sgl-router/src/tool_parser/tests.rs index 239026bfe129..65d35933740f 100644 --- a/sgl-router/src/tool_parser/tests.rs +++ b/sgl-router/src/tool_parser/tests.rs @@ -242,12 +242,12 @@ async fn test_json_parser_complete_single() { let parser = JsonParser::new(); let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("San Francisco")); - assert!(result[0].function.arguments.contains("celsius")); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("San Francisco")); + assert!(tools[0].function.arguments.contains("celsius")); } #[tokio::test] @@ -259,11 +259,11 @@ async fn test_json_parser_complete_array() { {"name": "get_news", "arguments": {"query": "technology"}} ]"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "get_news"); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "get_news"); } #[tokio::test] @@ -271,13 +271,13 @@ async fn test_json_parser_with_parameters() { let parser = JsonParser::new(); let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); - assert!(result[0].function.arguments.contains("10")); - assert!(result[0].function.arguments.contains("20")); - assert!(result[0].function.arguments.contains("add")); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "calculate"); + assert!(tools[0].function.arguments.contains("10")); + assert!(tools[0].function.arguments.contains("20")); + assert!(tools[0].function.arguments.contains("add")); } #[tokio::test] @@ -289,10 +289,10 @@ async fn test_json_parser_with_tokens() { }); let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); } #[tokio::test] @@ -313,12 +313,12 @@ async fn test_multiline_json_with_tokens() { } }"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); - assert!(result[0].function.arguments.contains("San Francisco")); - assert!(result[0].function.arguments.contains("celsius")); - assert!(result[0].function.arguments.contains("true")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); + assert!(tools[0].function.arguments.contains("San Francisco")); + assert!(tools[0].function.arguments.contains("celsius")); + assert!(tools[0].function.arguments.contains("true")); } #[tokio::test] @@ -342,12 +342,12 @@ async fn test_multiline_json_array() { } ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "function1"); - assert_eq!(result[1].function.name, "function2"); - assert!(result[0].function.arguments.contains("value1")); - assert!(result[1].function.arguments.contains("[1,2,3]")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "function1"); + assert_eq!(tools[1].function.name, "function2"); + assert!(tools[0].function.arguments.contains("value1")); + assert!(tools[1].function.arguments.contains("[1,2,3]")); } #[test] @@ -397,9 +397,9 @@ async fn test_registry_with_json_parser() { let parser = registry.get_parser("gpt-4-turbo").unwrap(); let input = r#"{"name": "test", "arguments": {"x": 1}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[tokio::test] @@ -407,9 +407,9 @@ async fn test_json_parser_invalid_input() { let parser = JsonParser::new(); // Invalid JSON should return empty results - assert_eq!(parser.parse_complete("not json").await.unwrap().len(), 0); - assert_eq!(parser.parse_complete("{invalid}").await.unwrap().len(), 0); - assert_eq!(parser.parse_complete("").await.unwrap().len(), 0); + assert_eq!(parser.parse_complete("not json").await.unwrap().1.len(), 0); + assert_eq!(parser.parse_complete("{invalid}").await.unwrap().1.len(), 0); + assert_eq!(parser.parse_complete("").await.unwrap().1.len(), 0); } #[tokio::test] @@ -418,11 +418,11 @@ async fn test_json_parser_empty_arguments() { // Tool call with no arguments let input = r#"{"name": "get_time"}"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_time"); - assert_eq!(result[0].function.arguments, "{}"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_time"); + assert_eq!(tools[0].function.arguments, "{}"); } #[cfg(test)] @@ -435,14 +435,14 @@ mod failure_cases { // Missing name field let input = r#"{"arguments": {"x": 1}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should return empty for tool without name"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should return empty for tool without name"); // Empty name let input = r#"{"name": "", "arguments": {"x": 1}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1, "Should accept empty name string"); - assert_eq!(result[0].function.name, ""); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1, "Should accept empty name string"); + assert_eq!(tools[0].function.name, ""); } #[tokio::test] @@ -451,22 +451,22 @@ mod failure_cases { // Arguments is a string instead of object let input = r#"{"name": "test", "arguments": "not an object"}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); // Should serialize the string as JSON - assert!(result[0].function.arguments.contains("not an object")); + assert!(tools[0].function.arguments.contains("not an object")); // Arguments is a number let input = r#"{"name": "test", "arguments": 42}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.arguments, "42"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.arguments, "42"); // Arguments is null let input = r#"{"name": "test", "arguments": null}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.arguments, "null"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.arguments, "null"); } #[tokio::test] @@ -479,26 +479,26 @@ mod failure_cases { // Missing end token let input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); assert_eq!( - result.len(), + tools.len(), 0, "Should fail to parse without complete wrapper" ); // Missing start token - parser looks for complete wrapper, so this won't parse let input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); assert_eq!( - result.len(), + tools.len(), 0, "Should not parse JSON with incomplete wrapper" ); // Mismatched tokens let input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should fail with mismatched tokens"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should fail with mismatched tokens"); } #[tokio::test] @@ -507,18 +507,18 @@ mod failure_cases { // Trailing comma let input = r#"{"name": "test", "arguments": {"x": 1,}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should reject JSON with trailing comma"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should reject JSON with trailing comma"); // Missing quotes on keys let input = r#"{name: "test", arguments: {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should reject invalid JSON syntax"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should reject invalid JSON syntax"); // Unclosed object let input = r#"{"name": "test", "arguments": {"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should reject incomplete JSON"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should reject incomplete JSON"); } } @@ -532,17 +532,17 @@ mod edge_cases { // Unicode in function name let input = r#"{"name": "获取天气", "arguments": {"location": "北京"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "获取天气"); - assert!(result[0].function.arguments.contains("北京")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "获取天气"); + assert!(tools[0].function.arguments.contains("北京")); // Emoji in arguments let input = r#"{"name": "send_message", "arguments": {"text": "Hello 👋 World 🌍"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("👋")); - assert!(result[0].function.arguments.contains("🌍")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("👋")); + assert!(tools[0].function.arguments.contains("🌍")); } #[tokio::test] @@ -551,22 +551,22 @@ mod edge_cases { // Escaped quotes in arguments let input = r#"{"name": "echo", "arguments": {"text": "He said \"hello\""}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains(r#"\"hello\""#)); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains(r#"\"hello\""#)); // Escaped backslashes let input = r#"{"name": "path", "arguments": {"dir": "C:\\Users\\test"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("\\\\")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("\\\\")); // Newlines and tabs let input = r#"{"name": "format", "arguments": {"text": "line1\nline2\ttabbed"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("\\n")); - assert!(result[0].function.arguments.contains("\\t")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("\\n")); + assert!(tools[0].function.arguments.contains("\\t")); } #[tokio::test] @@ -580,10 +580,10 @@ mod edge_cases { } large_args.push_str(r#""final": "value"}}"#); - let result = parser.parse_complete(&large_args).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process"); - assert!(result[0].function.arguments.contains("field_999")); + let (_normal_text, tools) = parser.parse_complete(&large_args).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process"); + assert!(tools[0].function.arguments.contains("field_999")); // Large array of tool calls let mut large_array = "[".to_string(); @@ -595,9 +595,9 @@ mod edge_cases { } large_array.push(']'); - let result = parser.parse_complete(&large_array).await.unwrap(); - assert_eq!(result.len(), 100); - assert_eq!(result[99].function.name, "func_99"); + let (_normal_text, tools) = parser.parse_complete(&large_array).await.unwrap(); + assert_eq!(tools.len(), 100); + assert_eq!(tools[99].function.name, "func_99"); } #[tokio::test] @@ -612,10 +612,10 @@ mod edge_cases { {"key": "value", "another": "field"} ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2, "Should only parse valid tool calls"); - assert_eq!(result[0].function.name, "tool1"); - assert_eq!(result[1].function.name, "tool2"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2, "Should only parse valid tool calls"); + assert_eq!(tools[0].function.name, "tool1"); + assert_eq!(tools[1].function.name, "tool2"); } #[tokio::test] @@ -624,14 +624,14 @@ mod edge_cases { // JSON with duplicate keys (last one wins in most parsers) let input = r#"{"name": "first", "name": "second", "arguments": {"x": 1, "x": 2}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); assert_eq!( - result[0].function.name, "second", + tools[0].function.name, "second", "Last duplicate key should win" ); assert!( - result[0].function.arguments.contains("2"), + tools[0].function.arguments.contains("2"), "Last duplicate value should win" ); } @@ -642,15 +642,15 @@ mod edge_cases { // Null values in arguments let input = r#"{"name": "test", "arguments": {"required": "value", "optional": null}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("null")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("null")); // Array with null let input = r#"{"name": "test", "arguments": {"items": [1, null, "three"]}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("null")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("null")); } #[tokio::test] @@ -663,22 +663,22 @@ mod edge_cases { // First pattern let input = r#"<<{"name": "test1", "arguments": {}}>>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test1"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test1"); // Second pattern let input = r#"{"name": "test2", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test2"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test2"); // Nested patterns (should use first match) let input = r#"<{"name": "test3", "arguments": {}}>"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); // This is tricky - depends on regex behavior // The parser should handle this gracefully - assert!(result.len() <= 1, "Should not parse multiple times"); + assert!(tools.len() <= 1, "Should not parse multiple times"); } #[tokio::test] @@ -743,25 +743,25 @@ mod edge_cases { // Boolean values let input = r#"{"name": "toggle", "arguments": {"enabled": true, "disabled": false}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("true")); - assert!(result[0].function.arguments.contains("false")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("true")); + assert!(tools[0].function.arguments.contains("false")); // Numbers (including float and negative) let input = r#"{"name": "calc", "arguments": {"int": 42, "float": 3.14, "negative": -17}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("42")); - assert!(result[0].function.arguments.contains("3.14")); - assert!(result[0].function.arguments.contains("-17")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("42")); + assert!(tools[0].function.arguments.contains("3.14")); + assert!(tools[0].function.arguments.contains("-17")); // Empty arrays and objects let input = r#"{"name": "test", "arguments": {"empty_arr": [], "empty_obj": {}}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("[]")); - assert!(result[0].function.arguments.contains("{}")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("[]")); + assert!(tools[0].function.arguments.contains("{}")); } #[tokio::test] @@ -770,15 +770,15 @@ mod edge_cases { // Using "function" instead of "name" let input = r#"{"function": "test_func", "arguments": {"x": 1}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test_func"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test_func"); // Both "name" and "function" present (name should take precedence) let input = r#"{"name": "primary", "function": "secondary", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "primary"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "primary"); } #[tokio::test] @@ -792,15 +792,15 @@ mod edge_cases { "key" : "value" } } "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); // Minified JSON (no whitespace) let input = r#"{"name":"compact","arguments":{"a":1,"b":2}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "compact"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "compact"); } } @@ -830,9 +830,9 @@ mod stress_tests { } }"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].function.arguments.contains("deep")); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert!(tools[0].function.arguments.contains("deep")); } #[tokio::test] @@ -845,9 +845,9 @@ mod stress_tests { let parser_clone = parser.clone(); let handle = tokio::spawn(async move { let input = format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i); - let result = parser_clone.parse_complete(&input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, format!("func_{}", i)); + let (_normal_text, tools) = parser_clone.parse_complete(&input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, format!("func_{}", i)); }); handles.push(handle); } diff --git a/sgl-router/src/tool_parser/traits.rs b/sgl-router/src/tool_parser/traits.rs index 19263688d6fd..34b097a3f533 100644 --- a/sgl-router/src/tool_parser/traits.rs +++ b/sgl-router/src/tool_parser/traits.rs @@ -9,7 +9,8 @@ use async_trait::async_trait; #[async_trait] pub trait ToolParser: Send + Sync { /// Parse complete tool calls from final output - async fn parse_complete(&self, output: &str) -> ToolParserResult>; + /// Returns (remaining_normal_text, tool_calls) tuple + async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec)>; /// Parse tool calls from model output (streaming) async fn parse_incremental( diff --git a/sgl-router/tests/tool_parser_deepseek.rs b/sgl-router/tests/tool_parser_deepseek.rs index 01738dcafe50..f33c7c813248 100644 --- a/sgl-router/tests/tool_parser_deepseek.rs +++ b/sgl-router/tests/tool_parser_deepseek.rs @@ -13,11 +13,11 @@ async fn test_deepseek_complete_parsing() { ```<|tool▁call▁end|><|tool▁calls▁end|> The weather in Tokyo is..."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["location"], "Tokyo"); assert_eq!(args["units"], "celsius"); } @@ -37,10 +37,10 @@ async fn test_deepseek_multiple_tools() { ```<|tool▁call▁end|> <|tool▁calls▁end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "translate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "translate"); } #[tokio::test] @@ -96,11 +96,11 @@ async fn test_deepseek_nested_json() { } ```<|tool▁call▁end|><|tool▁calls▁end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["data"]["nested"]["deep"].is_array()); } @@ -134,10 +134,10 @@ async fn test_deepseek_malformed_json_handling() { ```<|tool▁call▁end|> <|tool▁calls▁end|>"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); // Only the valid tool call should be parsed - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "valid"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "valid"); } #[tokio::test] @@ -151,9 +151,9 @@ async fn test_normal_text_extraction() { {"location": "Tokyo"} ```<|tool▁call▁end|><|tool▁calls▁end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); // TODO: Verify normal text extraction when parser returns it // In Python: normal_text = "Let me help you with that." @@ -174,8 +174,8 @@ async fn test_multiple_tool_calls() { ```<|tool▁call▁end|> <|tool▁calls▁end|><|end▁of▁sentence|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "get_weather"); } diff --git a/sgl-router/tests/tool_parser_edge_cases.rs b/sgl-router/tests/tool_parser_edge_cases.rs index 04bcab1f385c..96ae606b6466 100644 --- a/sgl-router/tests/tool_parser_edge_cases.rs +++ b/sgl-router/tests/tool_parser_edge_cases.rs @@ -16,9 +16,9 @@ async fn test_empty_input() { let parser = registry .get_parser(&format!("test-{}", parser_name)) .unwrap(); - let result = parser.parse_complete("").await.unwrap(); + let (_normal_text, tools) = parser.parse_complete("").await.unwrap(); assert_eq!( - result.len(), + tools.len(), 0, "Parser {} should return empty for empty input", parser_name @@ -32,7 +32,12 @@ async fn test_plain_text_no_tools() { let json_parser = JsonParser::new(); assert_eq!( - json_parser.parse_complete(plain_text).await.unwrap().len(), + json_parser + .parse_complete(plain_text) + .await + .unwrap() + .1 + .len(), 0 ); @@ -42,13 +47,19 @@ async fn test_plain_text_no_tools() { .parse_complete(plain_text) .await .unwrap() + .1 .len(), 0 ); let qwen_parser = QwenParser::new(); assert_eq!( - qwen_parser.parse_complete(plain_text).await.unwrap().len(), + qwen_parser + .parse_complete(plain_text) + .await + .unwrap() + .1 + .len(), 0 ); @@ -58,6 +69,7 @@ async fn test_plain_text_no_tools() { .parse_complete(plain_text) .await .unwrap() + .1 .len(), 0 ); @@ -74,9 +86,9 @@ async fn test_incomplete_json() { ]; for input in incomplete_cases { - let result = json_parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); assert_eq!( - result.len(), + tools.len(), 0, "Should not parse incomplete JSON: {}", input @@ -106,9 +118,9 @@ async fn test_malformed_mistral() { for input in malformed_cases { // Parser might return error or empty vec for malformed input - if let Ok(result) = parser.parse_complete(input).await { + if let Ok((_normal_text, tools)) = parser.parse_complete(input).await { assert_eq!( - result.len(), + tools.len(), 0, "Should not parse malformed Mistral: {}", input @@ -124,13 +136,13 @@ async fn test_missing_required_fields() { // Missing name field let input = r#"{"arguments": {"x": 1}}"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should not parse without name field"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should not parse without name field"); // Name is not a string let input = r#"{"name": 123, "arguments": {}}"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should not parse with non-string name"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should not parse with non-string name"); } #[tokio::test] @@ -143,11 +155,11 @@ async fn test_very_long_strings() { long_string ); - let result = json_parser.parse_complete(&input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = json_parser.parse_complete(&input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["data"].as_str().unwrap().len(), 10000); } @@ -158,10 +170,10 @@ async fn test_unicode_edge_cases() { // Various Unicode characters including emojis, CJK, RTL text let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍 مرحبا עולם"}}"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Hello 世界 🌍 مرحبا עולם"); } @@ -169,16 +181,16 @@ async fn test_unicode_edge_cases() { async fn test_nested_brackets_in_strings() { let mistral_parser = MistralParser::new(); let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#; - let result = mistral_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let (_normal_text, tools) = mistral_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Array: [1, 2, 3]"); let pythonic_parser = PythonicParser::new(); let input = r#"[echo(text="List: [a, b, c]")]"#; - let result = pythonic_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let (_normal_text, tools) = pythonic_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "List: [a, b, c]"); } @@ -191,9 +203,9 @@ async fn test_multiple_formats_in_text() { And some more text with tags. "#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "actual_tool"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "actual_tool"); } #[tokio::test] @@ -202,10 +214,10 @@ async fn test_escaped_characters() { let input = r#"{"name": "write", "arguments": {"content": "Line 1\nLine 2\r\nLine 3\tTabbed\\Backslash\"Quote"}}"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let content = args["content"].as_str().unwrap(); assert!(content.contains('\n')); assert!(content.contains('\t')); @@ -229,10 +241,10 @@ async fn test_numeric_edge_cases() { } }"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["int"], 42); assert_eq!(args["float"], 123.456); assert_eq!(args["scientific"], 0.000123); @@ -254,10 +266,10 @@ async fn test_null_and_boolean_values() { } }"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["enabled"], true); assert_eq!(args["disabled"], false); assert_eq!(args["optional"], serde_json::Value::Null); diff --git a/sgl-router/tests/tool_parser_fallback.rs b/sgl-router/tests/tool_parser_fallback.rs new file mode 100644 index 000000000000..e522a1116fe5 --- /dev/null +++ b/sgl-router/tests/tool_parser_fallback.rs @@ -0,0 +1,267 @@ +//! Tests for tool parser fallback behavior +//! +//! When tool call parsing fails, the original text should be preserved as normal text +//! rather than being lost. This ensures graceful degradation. + +use sglang_router_rs::tool_parser::{ + DeepSeekParser, JsonParser, LlamaParser, MistralParser, QwenParser, ToolParser, +}; + +#[tokio::test] +async fn test_json_parser_invalid_json_returns_as_normal_text() { + let parser = JsonParser::new(); + + // Malformed JSON should be returned as normal text (note: commas may be processed) + let input = r#"{"name": "test", "arguments": invalid json here}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!( + normal_text, + r#"{"name": "test", "arguments": invalid json here}"# + ); + + // Plain text with no JSON structure should be returned as normal text + let input = "This is just plain text that should not be parsed as a tool call"; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); + + // Text that looks like it might have JSON but doesn't should be returned as normal text + let input = "The user said: {something} but it's not valid JSON"; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); +} + +#[tokio::test] +async fn test_qwen_parser_invalid_format_returns_as_normal_text() { + let parser = QwenParser::new(); + + // Missing closing tag + let input = r#" +{"name": "test", "arguments": {}} +This text is missing the closing tag"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should preserve original text when no valid tools found + + // Malformed JSON inside valid tags + let input = r#" +{"name": "test", "arguments": invalid} +"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + // When JSON parsing fails but tags are present, it should preserve the original text + assert_eq!(normal_text, input); + + // Plain text without any tool markers + let input = "This is a regular response without any tool calls."; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should return original text when no markers found +} + +#[tokio::test] +async fn test_llama_parser_invalid_format_returns_as_normal_text() { + let parser = LlamaParser::new(); + + // Invalid JSON after python_tag + let input = r#"<|python_tag|>{"name": "test", "arguments": invalid}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should preserve original text when parsing fails + + // Plain text without markers or JSON + let input = "Just explaining something without any function calls."; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should return original text + + // Text with python_tag but completely invalid content + let input = r#"Here's my response <|python_tag|>not even close to JSON"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should preserve everything when parsing fails +} + +#[tokio::test] +async fn test_mistral_parser_invalid_format_returns_as_normal_text() { + let parser = MistralParser::new(); + + // Missing closing bracket + let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should preserve original text when parsing fails + + // Invalid JSON in tool calls section + let input = r#"[TOOL_CALLS] [{"name": invalid json}]"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should preserve original text when parsing fails + + // Plain text + let input = "No tool calls here, just regular text."; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should return original text +} + +#[tokio::test] +async fn test_deepseek_parser_invalid_format_returns_as_normal_text() { + let parser = DeepSeekParser::new(); + + // Invalid JSON after emoji marker + let input = r#"🤔[{"name": "test", "arguments": malformed}]"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should preserve original text when parsing fails + + // Emoji but no JSON array + let input = "🤔 Just thinking about this problem..."; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should return original text + + // No emoji marker at all + let input = "Regular response without any special markers."; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Should return original text +} + +#[tokio::test] +async fn test_mixed_valid_and_invalid_content() { + let parser = QwenParser::new(); + + // Text with one valid tool call and one invalid + let input = r#"Let me help you with that. + +{"name": "valid_tool", "arguments": {"x": 1}} + +And here's another one: + +{"name": "invalid_tool", "arguments": malformed} + +That's all!"#; + + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); // Should extract the valid tool + assert_eq!(tools[0].function.name, "valid_tool"); + // Normal text should contain the text around the valid tool call + assert!(normal_text.contains("Let me help you")); + assert!(normal_text.contains("That's all!")); +} + +#[tokio::test] +async fn test_partial_tool_markers() { + // Test cases where tool markers are incomplete or cut off + + let parser = QwenParser::new(); + let input = "\nThis looks like it might be a tool call but it's not"; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); + + let parser = MistralParser::new(); + let input = "[TOOL_CALLS] But then nothing follows..."; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); + + let parser = LlamaParser::new(); + let input = "Starting a response <|python_tag|> but no JSON"; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); +} + +#[tokio::test] +async fn test_escaped_json_like_content() { + // Test that JSON-like content in regular text doesn't get parsed as tools + + let parser = JsonParser::new(); + let input = r#"The user typed: {"name": "example"} but this is just quoted text"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + // JsonParser should extract the valid JSON and return normal text + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "example"); + assert_eq!(normal_text, "The user typed: but this is just quoted text"); + + let parser = QwenParser::new(); + let input = r#"The syntax is: +{"name": "example"} + - that's how you format it"#; + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + // This actually contains valid tool call syntax, so it should parse + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "example"); +} + +#[tokio::test] +async fn test_unicode_and_special_chars_in_failed_parsing() { + let parser = QwenParser::new(); + + // Unicode in malformed tool calls + let input = r#" +{"name": "测试", "arguments": 🚀 invalid} +"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + // Should handle Unicode properly in the fallback text + assert!(!normal_text.is_empty() || normal_text == input); + + // Special characters that might confuse parsers + let input = r#"Response: {"name": "test\n\t", "arguments": {"]}"}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + // This might or might not parse depending on JSON handling of escape sequences + if tools.is_empty() { + assert!(!normal_text.is_empty() || normal_text == input); + } +} + +#[tokio::test] +async fn test_very_long_invalid_input() { + let parser = JsonParser::new(); + + // Generate a very long string that looks like it might be JSON but isn't + let mut input = String::from("{\"name\": \"test\", \"arguments\": {"); + for i in 0..1000 { + input.push_str(&format!("\"field{}\": \"value{}\", ", i, i)); + } + input.push_str("\"final\": incomplete"); // Don't close the JSON properly + + let (normal_text, tools) = parser.parse_complete(&input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!(normal_text, input); // Invalid JSON should be returned as normal text +} + +#[tokio::test] +async fn test_almost_valid_tool_calls() { + // Test tool calls that are almost valid but have small issues + + let parser = JsonParser::new(); + + // Missing closing quote should be returned as normal text + let input = r#"{"name": "test", "arguments": {"key": "value}}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); + assert_eq!( + normal_text, + r#"{"name": "test", "arguments": {"key": "value}}"# + ); + + // Extra comma + let input = r#"{"name": "test", "arguments": {},}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + // Some JSON parsers might accept trailing commas + if tools.is_empty() { + assert_eq!(normal_text, r#"{"name": "test", "arguments": ,}"#); + } + + // Wrong quote types + let input = r#"{'name': 'test', 'arguments': {}}"#; + let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); // Standard JSON requires double quotes + assert_eq!(normal_text, r#"{'name': 'test', 'arguments': }"#); +} diff --git a/sgl-router/tests/tool_parser_glm4_moe.rs b/sgl-router/tests/tool_parser_glm4_moe.rs index 5fa06254b38a..477a48b33a1c 100644 --- a/sgl-router/tests/tool_parser_glm4_moe.rs +++ b/sgl-router/tests/tool_parser_glm4_moe.rs @@ -15,11 +15,11 @@ async fn test_glm4_complete_parsing() { The weather will be..."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["city"], "Beijing"); assert_eq!(args["date"], "2024-12-25"); } @@ -39,10 +39,10 @@ async fn test_glm4_multiple_tools() { zh "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "translate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "translate"); } #[tokio::test] @@ -62,10 +62,10 @@ async fn test_glm4_type_conversion() { string value "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["count"], 42); assert_eq!(args["rate"], 1.5); assert_eq!(args["enabled"], true); @@ -138,10 +138,10 @@ async fn test_glm4_python_literal_values() { None "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["debug"], true); assert_eq!(args["verbose"], false); assert_eq!(args["optional"], serde_json::Value::Null); @@ -160,11 +160,11 @@ async fn test_python_literals() { None "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test_func"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test_func"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["bool_true"], true); assert_eq!(args["bool_false"], false); assert_eq!(args["none_val"], serde_json::Value::Null); @@ -181,10 +181,10 @@ async fn test_nested_values() { [1, 2, 3] "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["data"].is_object()); assert!(args["list"].is_array()); } diff --git a/sgl-router/tests/tool_parser_gpt_oss.rs b/sgl-router/tests/tool_parser_gpt_oss.rs index 13512a869290..de873db9232c 100644 --- a/sgl-router/tests/tool_parser_gpt_oss.rs +++ b/sgl-router/tests/tool_parser_gpt_oss.rs @@ -10,11 +10,11 @@ async fn test_gpt_oss_complete_parsing() { <|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|> Here are the results..."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "rust programming"); assert_eq!(args["limit"], 10); } @@ -26,10 +26,10 @@ async fn test_gpt_oss_multiple_tools() { let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary <|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "search"); } #[tokio::test] @@ -39,10 +39,10 @@ async fn test_gpt_oss_with_namespace() { let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|> <|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "create"); // Should extract last part - assert_eq!(result[1].function.name, "add"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "create"); // Should extract last part + assert_eq!(tools[1].function.name, "add"); } #[tokio::test] @@ -51,9 +51,9 @@ async fn test_gpt_oss_with_assistant_prefix() { let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[tokio::test] @@ -63,10 +63,10 @@ async fn test_gpt_oss_empty_args() { let input = r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_time"); - assert_eq!(result[0].function.arguments, "{}"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_time"); + assert_eq!(tools[0].function.arguments, "{}"); } #[tokio::test] @@ -127,9 +127,9 @@ async fn test_gpt_oss_with_whitespace() { let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[tokio::test] @@ -145,11 +145,11 @@ async fn test_gpt_oss_complex_json() { } }<|call|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["nested"]["data"].is_array()); assert_eq!(args["nested"]["config"]["enabled"], true); } @@ -161,9 +161,9 @@ async fn test_commentary_without_function() { // Python should extract commentary as normal text let input = r#"<|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); // No tool calls - // TODO: Verify normal text = "**Action plan**: 1. Do X 2. Do Y" + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); // No tool calls + // TODO: Verify normal text = "**Action plan**: 1. Do X 2. Do Y" } #[tokio::test] @@ -173,9 +173,9 @@ async fn test_final_channel() { let input = r#"<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"x": 1}<|call|> <|channel|>final<|message|>The result is calculated.<|return|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); // TODO: Verify normal text = "The result is calculated." } @@ -187,8 +187,8 @@ async fn test_mixed_commentary_and_calls() { <|channel|>commentary to=functions.calc<|constrain|>json<|message|>{"x": 5}<|call|> <|channel|>commentary<|message|>Processing...<|end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calc"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "calc"); // TODO: Verify normal text = "Let me think Processing..." } diff --git a/sgl-router/tests/tool_parser_json.rs b/sgl-router/tests/tool_parser_json.rs index c8c42b70f5f3..3ce715e08f48 100644 --- a/sgl-router/tests/tool_parser_json.rs +++ b/sgl-router/tests/tool_parser_json.rs @@ -10,11 +10,11 @@ async fn test_simple_json_tool_call() { let parser = JsonParser::new(); let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["location"], "San Francisco"); } @@ -26,10 +26,10 @@ async fn test_json_array_of_tools() { {"name": "search", "arguments": {"query": "news"}} ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "search"); } #[tokio::test] @@ -37,11 +37,11 @@ async fn test_json_with_parameters_key() { let parser = JsonParser::new(); let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "calculate"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["x"], 10); assert_eq!(args["y"], 20); } @@ -51,9 +51,9 @@ async fn test_json_extraction_from_text() { let parser = JsonParser::new(); let input = r#"I'll help you with that. {"name": "search", "arguments": {"query": "rust"}} Let me search for that."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); } #[tokio::test] @@ -73,11 +73,11 @@ async fn test_json_with_nested_objects() { } }"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "update_config"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "update_config"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["settings"]["theme"], "dark"); assert_eq!(args["settings"]["notifications"]["email"], true); } @@ -87,10 +87,10 @@ async fn test_json_with_special_characters() { let parser = JsonParser::new(); let input = r#"{"name": "echo", "arguments": {"text": "Line 1\nLine 2\tTabbed", "path": "C:\\Users\\test"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Line 1\nLine 2\tTabbed"); assert_eq!(args["path"], "C:\\Users\\test"); } @@ -100,10 +100,10 @@ async fn test_json_with_unicode() { let parser = JsonParser::new(); let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍", "emoji": "😊"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Hello 世界 🌍"); assert_eq!(args["emoji"], "😊"); } @@ -113,11 +113,11 @@ async fn test_json_empty_arguments() { let parser = JsonParser::new(); let input = r#"{"name": "ping", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "ping"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "ping"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args, json!({})); } @@ -127,13 +127,13 @@ async fn test_json_invalid_format() { // Missing closing brace let input = r#"{"name": "test", "arguments": {"key": "value""#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); // Not JSON at all let input = "This is just plain text"; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); } #[tokio::test] diff --git a/sgl-router/tests/tool_parser_kimik2.rs b/sgl-router/tests/tool_parser_kimik2.rs index b9cc65c11358..e3d1177ed426 100644 --- a/sgl-router/tests/tool_parser_kimik2.rs +++ b/sgl-router/tests/tool_parser_kimik2.rs @@ -12,11 +12,11 @@ async fn test_kimik2_complete_parsing() { <|tool_calls_section_end|> The weather in Tokyo is..."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["location"], "Tokyo"); assert_eq!(args["units"], "celsius"); } @@ -30,10 +30,10 @@ async fn test_kimik2_multiple_tools() { <|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "translate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "translate"); } #[tokio::test] @@ -44,11 +44,11 @@ async fn test_kimik2_with_whitespace() { <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["key"], "value"); assert_eq!(args["num"], 42); } @@ -117,11 +117,11 @@ async fn test_kimik2_sequential_indices() { <|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 3); - assert_eq!(result[0].function.name, "first"); - assert_eq!(result[1].function.name, "second"); - assert_eq!(result[2].function.name, "third"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 3); + assert_eq!(tools[0].function.name, "first"); + assert_eq!(tools[1].function.name, "second"); + assert_eq!(tools[2].function.name, "third"); } #[tokio::test] @@ -134,10 +134,10 @@ async fn test_function_index_extraction() { <|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "calc"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "calc"); // TODO: Verify indices are preserved: 0 and 1 // TODO: Verify normal text = "Text before tool calls." } @@ -150,7 +150,7 @@ async fn test_namespace_extraction() { <|tool_call_begin|>api.tools.search:0<|tool_call_argument_begin|>{"q": "test"}<|tool_call_end|> <|tool_calls_section_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); // Should extract after last dot + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); // Should extract after last dot } diff --git a/sgl-router/tests/tool_parser_llama.rs b/sgl-router/tests/tool_parser_llama.rs index a0d1cff91046..18dd76e27a6d 100644 --- a/sgl-router/tests/tool_parser_llama.rs +++ b/sgl-router/tests/tool_parser_llama.rs @@ -9,11 +9,11 @@ async fn test_llama_python_tag_format() { let parser = LlamaParser::new(); let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "weather"); } @@ -22,11 +22,11 @@ async fn test_llama_plain_json_fallback() { let parser = LlamaParser::new(); let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "calculate"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["x"], 5); assert_eq!(args["y"], 10); } @@ -36,11 +36,11 @@ async fn test_llama_with_text_before() { let parser = LlamaParser::new(); let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_time"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_time"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["timezone"], "UTC"); } @@ -58,11 +58,11 @@ async fn test_llama_with_nested_json() { } }"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "update_settings"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "update_settings"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["preferences"]["theme"], "dark"); assert_eq!(args["notifications"], true); } @@ -73,15 +73,15 @@ async fn test_llama_empty_arguments() { // With python_tag let input = r#"<|python_tag|>{"name": "ping", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "ping"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "ping"); // Plain JSON let input = r#"{"name": "ping", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "ping"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "ping"); } #[tokio::test] @@ -99,8 +99,8 @@ async fn test_llama_invalid_json_after_tag() { let parser = LlamaParser::new(); let input = r#"<|python_tag|>{"name": invalid}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); } #[tokio::test] @@ -112,9 +112,9 @@ async fn test_llama_real_world_output() { <|python_tag|>{"name": "web_search", "arguments": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "web_search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "web_search"); let formatted_input = r#"<|python_tag|>{ "name": "get_current_time", @@ -124,9 +124,9 @@ async fn test_llama_real_world_output() { } }"#; - let result2 = parser.parse_complete(formatted_input).await.unwrap(); - assert_eq!(result2.len(), 1); - assert_eq!(result2[0].function.name, "get_current_time"); + let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap(); + assert_eq!(tools2.len(), 1); + assert_eq!(tools2[0].function.name, "get_current_time"); } #[tokio::test] @@ -136,9 +136,9 @@ async fn test_llama_json_array_format() { // Plain JSON array (should work as fallback) let input = r#"[{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}]"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); // Current implementation might handle this through JSON fallback - assert!(!result.is_empty()); + assert!(!tools.is_empty()); } #[tokio::test] @@ -146,11 +146,11 @@ async fn test_single_json() { let parser = LlamaParser::new(); let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#; - let result = parser.parse_complete(text).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["city"], "Paris"); } @@ -159,10 +159,10 @@ async fn test_multiple_json_with_separator() { let parser = LlamaParser::new(); let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#; - let result = parser.parse_complete(text).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); // Note: Current implementation may only parse the first one due to semicolon handling - assert!(!result.is_empty()); - assert_eq!(result[0].function.name, "get_weather"); + assert!(!tools.is_empty()); + assert_eq!(tools[0].function.name, "get_weather"); } #[tokio::test] @@ -170,10 +170,10 @@ async fn test_multiple_json_with_separator_customized() { let parser = LlamaParser::new(); let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#; - let result = parser.parse_complete(text).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); // Current implementation may handle this differently - assert!(!result.is_empty()); - assert_eq!(result[0].function.name, "get_weather"); + assert!(!tools.is_empty()); + assert_eq!(tools[0].function.name, "get_weather"); } #[tokio::test] @@ -181,9 +181,9 @@ async fn test_json_with_trailing_text() { let parser = LlamaParser::new(); let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#; - let result = parser.parse_complete(text).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); } #[tokio::test] @@ -191,10 +191,10 @@ async fn test_invalid_then_valid_json() { let parser = LlamaParser::new(); let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#; - let result = parser.parse_complete(text).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); // Should parse at least one valid JSON - if !result.is_empty() { - assert_eq!(result[0].function.name, "get_weather"); + if !tools.is_empty() { + assert_eq!(tools[0].function.name, "get_weather"); } } @@ -203,8 +203,8 @@ async fn test_plain_text_only() { let parser = LlamaParser::new(); let text = "This is just plain explanation text."; - let result = parser.parse_complete(text).await.unwrap(); - assert_eq!(result.len(), 0); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); + assert_eq!(tools.len(), 0); } #[tokio::test] @@ -212,9 +212,9 @@ async fn test_with_python_tag_prefix() { let parser = LlamaParser::new(); let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#; - let result = parser.parse_complete(text).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); } // STREAMING TESTS diff --git a/sgl-router/tests/tool_parser_mistral.rs b/sgl-router/tests/tool_parser_mistral.rs index 3801006f57e9..b3a3c3c877cb 100644 --- a/sgl-router/tests/tool_parser_mistral.rs +++ b/sgl-router/tests/tool_parser_mistral.rs @@ -11,11 +11,11 @@ async fn test_mistral_single_tool() { let input = r#"Let me search for that. [TOOL_CALLS] [{"name": "search_web", "arguments": {"query": "latest news", "max_results": 5}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search_web"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search_web"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "latest news"); assert_eq!(args["max_results"], 5); } @@ -29,15 +29,15 @@ async fn test_mistral_multiple_tools() { {"name": "search_news", "arguments": {"query": "AI developments", "limit": 10}} ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(tools[0].function.name, "get_weather"); + let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args0["city"], "Tokyo"); - assert_eq!(result[1].function.name, "search_news"); - let args1: serde_json::Value = serde_json::from_str(&result[1].function.arguments).unwrap(); + assert_eq!(tools[1].function.name, "search_news"); + let args1: serde_json::Value = serde_json::from_str(&tools[1].function.arguments).unwrap(); assert_eq!(args1["query"], "AI developments"); } @@ -47,10 +47,10 @@ async fn test_mistral_nested_json() { let input = r#"Processing complex data. [TOOL_CALLS] [{"name": "process_data", "arguments": {"config": {"nested": {"value": [1, 2, 3]}}, "enabled": true}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["config"]["nested"]["value"], json!([1, 2, 3])); assert_eq!(args["enabled"], true); } @@ -62,9 +62,9 @@ async fn test_mistral_with_text_after() { And here's some text after the tool call that should be ignored."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[tokio::test] @@ -72,9 +72,9 @@ async fn test_mistral_empty_arguments() { let parser = MistralParser::new(); let input = r#"[TOOL_CALLS] [{"name": "ping", "arguments": {}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "ping"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "ping"); } #[tokio::test] @@ -82,10 +82,10 @@ async fn test_mistral_with_brackets_in_strings() { let parser = MistralParser::new(); let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array notation: arr[0] = value[1]"}}]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Array notation: arr[0] = value[1]"); } @@ -105,15 +105,15 @@ async fn test_mistral_malformed_json() { // Missing closing bracket let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}"#; - if let Ok(result) = parser.parse_complete(input).await { - assert_eq!(result.len(), 0); + if let Ok((_normal_text, tools)) = parser.parse_complete(input).await { + assert_eq!(tools.len(), 0); } // Error is also acceptable for malformed input // Invalid JSON inside let input = r#"[TOOL_CALLS] [{"name": invalid}]"#; - if let Ok(result) = parser.parse_complete(input).await { - assert_eq!(result.len(), 0); + if let Ok((_normal_text, tools)) = parser.parse_complete(input).await { + assert_eq!(tools.len(), 0); } // Error is also acceptable for malformed input } @@ -146,8 +146,8 @@ async fn test_mistral_real_world_output() { Let me execute these searches for you."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "web_search"); - assert_eq!(result[1].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "web_search"); + assert_eq!(tools[1].function.name, "get_weather"); } diff --git a/sgl-router/tests/tool_parser_mixed_edge_cases.rs b/sgl-router/tests/tool_parser_mixed_edge_cases.rs index 38f086c53d12..b13eba2a307e 100644 --- a/sgl-router/tests/tool_parser_mixed_edge_cases.rs +++ b/sgl-router/tests/tool_parser_mixed_edge_cases.rs @@ -17,9 +17,9 @@ async fn test_mixed_formats_in_text() { But here's the actual JSON: {"name": "test", "arguments": {}} "#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); // Mistral parser should ignore JSON and other formats let mistral_parser = MistralParser::new(); @@ -28,9 +28,9 @@ async fn test_mixed_formats_in_text() { [TOOL_CALLS] [{"name": "real", "arguments": {}}] "#; - let result = mistral_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "real"); + let (_normal_text, tools) = mistral_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "real"); } #[tokio::test] @@ -38,9 +38,9 @@ async fn test_format_markers_in_string_content() { let pythonic_parser = PythonicParser::new(); let input = r#"[echo(text="Use [TOOL_CALLS] and in text")]"#; - let result = pythonic_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let (_normal_text, tools) = pythonic_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Use [TOOL_CALLS] and in text"); let qwen_parser = QwenParser::new(); @@ -48,9 +48,9 @@ async fn test_format_markers_in_string_content() { {"name": "log", "arguments": {"msg": "Found [function()] pattern"}} "#; - let result = qwen_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let (_normal_text, tools) = qwen_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["msg"], "Found [function()] pattern"); } @@ -75,11 +75,11 @@ async fn test_deeply_nested_json_structures() { } }"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "deep_process"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "deep_process"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["level1"]["level2"]["level3"]["level4"]["level5"]["data"].is_array()); } @@ -93,14 +93,14 @@ async fn test_multiple_sequential_calls_different_formats() { // Llama parser currently only returns the first tool found let input = r#"First call: <|python_tag|>{"name": "call1", "arguments": {}}"#; - let result = llama_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "call1"); + let (_normal_text, tools) = llama_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "call1"); let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#; - let result2 = llama_parser.parse_complete(input2).await.unwrap(); - assert_eq!(result2.len(), 1); - assert_eq!(result2[0].function.name, "call2"); + let (_normal_text2, tools2) = llama_parser.parse_complete(input2).await.unwrap(); + assert_eq!(tools2.len(), 1); + assert_eq!(tools2[0].function.name, "call2"); } #[tokio::test] @@ -119,8 +119,8 @@ async fn test_empty_and_whitespace_variations() { ]; for input in cases { - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1, "Should parse regardless of whitespace"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1, "Should parse regardless of whitespace"); } } @@ -141,11 +141,11 @@ async fn test_special_json_values() { } }"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test_special"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test_special"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["special_strings"].is_array()); assert!(args["escaped"].is_string()); } @@ -181,22 +181,22 @@ async fn test_boundary_cases_for_extraction() { // JSON at the very beginning let input = r#"{"name": "start", "arguments": {}} and then text"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "start"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "start"); // JSON at the very end let input = r#"Some text first {"name": "end", "arguments": {}}"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "end"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "end"); // Multiple JSON objects in text (should find first valid one) let input = r#"Text {"name": "first", "arguments": {}} more {"name": "second", "arguments": {}}"#; - let result = json_parser.parse_complete(input).await.unwrap(); - assert!(!result.is_empty()); - assert_eq!(result[0].function.name, "first"); + let (_normal_text, tools) = json_parser.parse_complete(input).await.unwrap(); + assert!(!tools.is_empty()); + assert_eq!(tools[0].function.name, "first"); } #[tokio::test] @@ -205,15 +205,15 @@ async fn test_pythonic_edge_cases() { // Function name with underscores and numbers let input = r#"[func_name_2(param_1="value")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "func_name_2"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "func_name_2"); // Empty string argument let input = r#"[process(text="")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], ""); } @@ -238,11 +238,11 @@ async fn test_mistral_with_pretty_json() { } ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "formatted"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "formatted"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["nested"]["key"], "value"); assert_eq!(args["array"], json!([1, 2, 3])); } @@ -256,11 +256,11 @@ async fn test_qwen_with_cdata_like_content() { {"name": "process", "arguments": {"xml": ""}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["xml"], ""); } @@ -271,9 +271,9 @@ async fn test_extremely_long_function_names() { let long_name = "very_long_function_name_that_might_appear_in_generated_code_somewhere"; let input = format!(r#"[{}(param="value")]"#, long_name); - let result = parser.parse_complete(&input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, long_name); + let (_normal_text, tools) = parser.parse_complete(&input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, long_name); } #[tokio::test] @@ -283,10 +283,10 @@ async fn test_json_with_duplicate_keys() { // JSON with duplicate keys (last one should win per JSON spec) let input = r#"{"name": "test", "arguments": {"key": "first", "key": "second"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); // JSON parsers typically keep the last value for duplicate keys assert_eq!(args["key"], "second"); } diff --git a/sgl-router/tests/tool_parser_pythonic.rs b/sgl-router/tests/tool_parser_pythonic.rs index c8612276017a..87e3d927cde5 100644 --- a/sgl-router/tests/tool_parser_pythonic.rs +++ b/sgl-router/tests/tool_parser_pythonic.rs @@ -10,11 +10,11 @@ async fn test_pythonic_single_function() { let parser = PythonicParser::new(); let input = r#"[get_weather(city="London", units="celsius")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["city"], "London"); assert_eq!(args["units"], "celsius"); } @@ -25,12 +25,12 @@ async fn test_pythonic_multiple_functions() { let input = r#"[search_web(query="Rust programming", max_results=5), get_time(timezone="UTC")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search_web"); - assert_eq!(result[1].function.name, "get_time"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search_web"); + assert_eq!(tools[1].function.name, "get_time"); - let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args0["query"], "Rust programming"); assert_eq!(args0["max_results"], 5); } @@ -40,10 +40,10 @@ async fn test_pythonic_with_python_literals() { let parser = PythonicParser::new(); let input = r#"[configure(enabled=True, disabled=False, optional=None)]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["enabled"], true); assert_eq!(args["disabled"], false); assert_eq!(args["optional"], json!(null)); @@ -55,10 +55,10 @@ async fn test_pythonic_with_lists_and_dicts() { let input = r#"[process_data(items=[1, 2, 3], config={"key": "value", "nested": {"deep": True}})]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["items"], json!([1, 2, 3])); assert_eq!(args["config"]["key"], "value"); assert_eq!(args["config"]["nested"]["deep"], true); @@ -71,11 +71,11 @@ async fn test_pythonic_with_special_tokens() { // Llama 4 sometimes outputs these tokens let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "calculate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "calculate"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["x"], 10); assert_eq!(args["y"], 20); } @@ -85,10 +85,10 @@ async fn test_pythonic_with_nested_parentheses() { let parser = PythonicParser::new(); let input = r#"[math_eval(expression="(2 + 3) * (4 - 1)", round_to=2)]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["expression"], "(2 + 3) * (4 - 1)"); assert_eq!(args["round_to"], 2); } @@ -98,10 +98,10 @@ async fn test_pythonic_with_escaped_quotes() { let parser = PythonicParser::new(); let input = r#"[echo(text="She said \"Hello\" to him")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "She said \"Hello\" to him"); } @@ -110,11 +110,11 @@ async fn test_pythonic_empty_arguments() { let parser = PythonicParser::new(); let input = r#"[ping()]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "ping"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "ping"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args, json!({})); } @@ -135,8 +135,8 @@ async fn test_pythonic_invalid_syntax() { // Missing closing bracket let input = r#"[function(arg=value"#; - if let Ok(result) = parser.parse_complete(input).await { - assert_eq!(result.len(), 0); + if let Ok((_normal_text, tools)) = parser.parse_complete(input).await { + assert_eq!(tools.len(), 0); } // Error is also acceptable for invalid syntax @@ -144,10 +144,10 @@ async fn test_pythonic_invalid_syntax() { // Note: The parser currently accepts this invalid syntax and returns a result // This is a known limitation of the current implementation let input = r#"[function(=value)]"#; - if let Ok(result) = parser.parse_complete(input).await { + if let Ok((_normal_text, tools)) = parser.parse_complete(input).await { // The parser incorrectly accepts this, returning 1 result // We'll accept this behavior for now but note it's not ideal - assert!(result.len() <= 1, "Should parse at most one function"); + assert!(tools.len() <= 1, "Should parse at most one function"); } // Error would be the correct behavior } @@ -165,13 +165,13 @@ async fn test_pythonic_real_world_llama4() { These functions will provide the information you need."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 3); - assert_eq!(result[0].function.name, "web_search"); - assert_eq!(result[1].function.name, "calculate"); - assert_eq!(result[2].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 3); + assert_eq!(tools[0].function.name, "web_search"); + assert_eq!(tools[1].function.name, "calculate"); + assert_eq!(tools[2].function.name, "get_weather"); - let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args0["query"], "latest Rust features"); assert_eq!(args0["safe_search"], true); } @@ -182,11 +182,11 @@ async fn test_pythonic_nested_brackets_in_lists() { let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "process_matrix"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "process_matrix"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["data"], json!([[1, 2], [3, 4]])); assert_eq!(args["labels"], json!(["row[0]", "row[1]"])); } @@ -198,11 +198,11 @@ async fn test_pythonic_nested_brackets_in_dicts() { let input = r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "analyze"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "analyze"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["config"]["patterns"], json!(["[a-z]+", "[0-9]+"])); assert_eq!(args["config"]["nested"]["list"], json!([1, [2, 3]])); } @@ -213,11 +213,11 @@ async fn test_pythonic_mixed_quotes() { let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "format_text"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "format_text"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["single"], "Hello"); assert_eq!(args["double"], "World"); assert_eq!(args["mixed"], "It's \"quoted\""); @@ -233,11 +233,11 @@ async fn test_pythonic_complex_nesting() { metadata={"tags": ["nested[0]", "nested[1]"], "config": {"depth": [1, 2, 3]}} )]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "transform"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "transform"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["matrix"].is_array()); assert!(args["operations"].is_array()); assert_eq!(args["operations"][0]["type"], "scale"); @@ -530,12 +530,12 @@ async fn test_detect_and_parse_with_python_start_and_end_token() { let parser = PythonicParser::new(); let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."; - let result = parser.parse_complete(text).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(text).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["location"], "Mars"); assert_eq!(args["unit"], "celsius"); } diff --git a/sgl-router/tests/tool_parser_qwen.rs b/sgl-router/tests/tool_parser_qwen.rs index 9bad102ae66f..3f733d3d1dc0 100644 --- a/sgl-router/tests/tool_parser_qwen.rs +++ b/sgl-router/tests/tool_parser_qwen.rs @@ -12,11 +12,11 @@ async fn test_qwen_single_tool() { {"name": "get_weather", "arguments": {"city": "Beijing", "units": "celsius"}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["city"], "Beijing"); assert_eq!(args["units"], "celsius"); } @@ -32,10 +32,10 @@ async fn test_qwen_multiple_sequential_tools() { {"name": "translate", "arguments": {"text": "Hello", "to": "zh"}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "translate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "translate"); } #[tokio::test] @@ -55,11 +55,11 @@ async fn test_qwen_pretty_printed_json() { } "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "create_document"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "create_document"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["metadata"]["author"], "Qwen"); assert_eq!(args["metadata"]["tags"], json!(["test", "example"])); } @@ -79,10 +79,10 @@ Now I'll translate something. Done!"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "search"); - assert_eq!(result[1].function.name, "translate"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "search"); + assert_eq!(tools[1].function.name, "translate"); } #[tokio::test] @@ -92,9 +92,9 @@ async fn test_qwen_empty_arguments() { {"name": "get_time", "arguments": {}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_time"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_time"); } #[tokio::test] @@ -104,10 +104,10 @@ async fn test_qwen_with_newlines_in_strings() { {"name": "write_file", "arguments": {"content": "Line 1\nLine 2\nLine 3", "path": "/tmp/test.txt"}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["content"], "Line 1\nLine 2\nLine 3"); } @@ -128,14 +128,14 @@ async fn test_qwen_incomplete_tags() { // Missing closing tag let input = r#" {"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); // Missing opening tag let input = r#"{"name": "test", "arguments": {}} "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); } #[tokio::test] @@ -171,12 +171,12 @@ Let me also calculate something for you: These tools will provide the information you need."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "web_search"); - assert_eq!(result[1].function.name, "calculator"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "web_search"); + assert_eq!(tools[1].function.name, "calculator"); - let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args0["query"], "quantum computing breakthroughs 2024"); assert_eq!(args0["safe_search"], true); } diff --git a/sgl-router/tests/tool_parser_registry.rs b/sgl-router/tests/tool_parser_registry.rs index afe655bdeac1..52cfed81c836 100644 --- a/sgl-router/tests/tool_parser_registry.rs +++ b/sgl-router/tests/tool_parser_registry.rs @@ -24,9 +24,9 @@ async fn test_openai_models_use_json() { for model in models { let parser = registry.get_parser(model).unwrap(); let test_input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } } @@ -38,8 +38,8 @@ async fn test_anthropic_models_use_json() { for model in models { let parser = registry.get_parser(model).unwrap(); let test_input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); } } @@ -51,9 +51,9 @@ async fn test_mistral_models() { for model in models { let parser = registry.get_parser(model).unwrap(); let test_input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } } @@ -67,9 +67,9 @@ async fn test_qwen_models() { let test_input = r#" {"name": "test", "arguments": {}} "#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } } @@ -80,22 +80,22 @@ async fn test_llama_model_variants() { // Llama 4 uses pythonic let parser = registry.get_parser("llama-4-70b").unwrap(); let test_input = r#"[get_weather(city="NYC")]"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "get_weather"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "get_weather"); // Llama 3.2 uses python_tag let parser = registry.get_parser("llama-3.2-8b").unwrap(); let test_input = r#"<|python_tag|>{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); // Other Llama models use JSON let parser = registry.get_parser("llama-2-70b").unwrap(); let test_input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); } #[tokio::test] @@ -105,9 +105,9 @@ async fn test_deepseek_models() { // DeepSeek uses pythonic format (simplified, v3 would need custom parser) let parser = registry.get_parser("deepseek-coder").unwrap(); let test_input = r#"[function(arg="value")]"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "function"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "function"); } #[tokio::test] @@ -117,9 +117,9 @@ async fn test_unknown_model_fallback() { // Unknown models should fall back to JSON parser let parser = registry.get_parser("unknown-model-xyz").unwrap(); let test_input = r#"{"name": "fallback", "arguments": {}}"#; - let result = parser.parse_complete(test_input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "fallback"); + let (_normal_text, tools) = parser.parse_complete(test_input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "fallback"); } #[tokio::test] @@ -181,10 +181,10 @@ The weather information has been requested."#, for (model, output, expected_name) in test_cases { let parser = registry.get_parser(model).unwrap(); - let result = parser.parse_complete(output).await.unwrap(); - assert!(!result.is_empty(), "No tools parsed for model {}", model); + let (_normal_text, tools) = parser.parse_complete(output).await.unwrap(); + assert!(!tools.is_empty(), "No tools parsed for model {}", model); assert_eq!( - result[0].function.name, expected_name, + tools[0].function.name, expected_name, "Wrong function name for model {}", model ); diff --git a/sgl-router/tests/tool_parser_step3.rs b/sgl-router/tests/tool_parser_step3.rs index 681526d1606d..07ec32e787e7 100644 --- a/sgl-router/tests/tool_parser_step3.rs +++ b/sgl-router/tests/tool_parser_step3.rs @@ -15,11 +15,11 @@ async fn test_step3_complete_parsing() { <|tool_calls_end|> Here are the results..."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "rust programming"); assert_eq!(args["limit"], 10); } @@ -38,10 +38,10 @@ async fn test_step3_multiple_tools() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "get_weather"); - assert_eq!(result[1].function.name, "get_news"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "get_weather"); + assert_eq!(tools[1].function.name, "get_news"); } #[tokio::test] @@ -58,10 +58,10 @@ async fn test_step3_type_conversion() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["count"], 100); assert_eq!(args["rate"], 2.5); assert_eq!(args["active"], true); @@ -132,11 +132,11 @@ async fn test_step3_nested_steptml() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "config"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "config"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["settings"].is_object()); assert!(args["array"].is_array()); } @@ -153,10 +153,10 @@ async fn test_step3_python_literals() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["bool_true"], true); assert_eq!(args["bool_false"], false); assert_eq!(args["none_value"], serde_json::Value::Null); @@ -174,11 +174,11 @@ async fn test_steptml_format() { <|tool_call_end|> <|tool_calls_end|>Text after."#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "rust lang"); assert_eq!(args["limit"], 10); // TODO: Verify normal text extraction @@ -195,10 +195,10 @@ async fn test_json_parameter_values() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert!(args["settings"].is_object()); assert!(args["items"].is_array()); } @@ -214,11 +214,11 @@ async fn test_step3_parameter_with_angle_brackets() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "compare"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "compare"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["expression"], "a < b && b > c"); assert_eq!(args["context"], "comparison test"); } @@ -233,6 +233,6 @@ async fn test_step3_empty_function_name() { <|tool_call_end|> <|tool_calls_end|>"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0); // Should reject empty function name + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0); // Should reject empty function name } diff --git a/sgl-router/tests/tool_parser_wrapper_tokens.rs b/sgl-router/tests/tool_parser_wrapper_tokens.rs index 9cdb77ab6e53..668b96fb5001 100644 --- a/sgl-router/tests/tool_parser_wrapper_tokens.rs +++ b/sgl-router/tests/tool_parser_wrapper_tokens.rs @@ -15,11 +15,11 @@ async fn test_json_with_xml_style_wrapper() { let input = r#"Some text before {"name": "test", "arguments": {"x": 1}} and after"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["x"], 1); } @@ -32,14 +32,14 @@ async fn test_json_with_multiple_wrapper_pairs() { }); let input1 = r#"{"name": "tool1", "arguments": {}}"#; - let result1 = parser.parse_complete(input1).await.unwrap(); - assert_eq!(result1.len(), 1); - assert_eq!(result1[0].function.name, "tool1"); + let (_normal_text, tools1) = parser.parse_complete(input1).await.unwrap(); + assert_eq!(tools1.len(), 1); + assert_eq!(tools1[0].function.name, "tool1"); let input2 = r#"<>{"name": "tool2", "arguments": {}}<>"#; - let result2 = parser.parse_complete(input2).await.unwrap(); - assert_eq!(result2.len(), 1); - assert_eq!(result2[0].function.name, "tool2"); + let (_normal_text, tools2) = parser.parse_complete(input2).await.unwrap(); + assert_eq!(tools2.len(), 1); + assert_eq!(tools2[0].function.name, "tool2"); } #[tokio::test] @@ -52,9 +52,9 @@ async fn test_json_with_only_start_token() { let input = r#"Some preamble >>>FUNCTION:{"name": "execute", "arguments": {"cmd": "ls"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "execute"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "execute"); } #[tokio::test] @@ -68,9 +68,9 @@ async fn test_json_with_custom_separator() { // Though we're not testing multiple tools here, the separator is configured let input = r#"[FUNC]{"name": "test", "arguments": {}}[/FUNC]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); } #[tokio::test] @@ -88,21 +88,21 @@ async fn test_json_with_nested_wrapper_tokens_in_content() { let input = r#"{"name": "echo", "arguments": {"text": "Use and tags"}}"#; - let result = parser.parse_complete(input).await.unwrap(); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); // This is a known limitation - the parser may fail when end tokens appear in content // For now, we accept this behavior - if result.is_empty() { + if tools.is_empty() { // Parser failed due to nested tokens - this is expected assert_eq!( - result.len(), + tools.len(), 0, "Known limitation: nested wrapper tokens in content" ); } else { // If it does parse, verify it's correct - assert_eq!(result[0].function.name, "echo"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(tools[0].function.name, "echo"); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["text"], "Use and tags"); } } @@ -118,9 +118,9 @@ async fn test_json_extraction_without_wrapper_tokens() { And here is some text after. "#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "search"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "search"); } #[tokio::test] @@ -143,9 +143,9 @@ async fn test_json_with_multiline_wrapper_content() { ``` Done!"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "format_code"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "format_code"); } #[tokio::test] @@ -158,11 +158,11 @@ async fn test_json_with_special_chars_in_tokens() { let input = r#"{{FUNC[[{"name": "test", "arguments": {"special": "[]{}"}}]]FUNC}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); - let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["special"], "[]{}"); } @@ -183,9 +183,9 @@ async fn test_json_multiple_tools_with_wrapper() { // Current implementation might handle this as separate calls // Let's test that at least the first one is parsed - let result = parser.parse_complete(input).await.unwrap(); - assert!(!result.is_empty(), "Should parse at least one tool"); - assert_eq!(result[0].function.name, "tool1"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert!(!tools.is_empty(), "Should parse at least one tool"); + assert_eq!(tools[0].function.name, "tool1"); } #[tokio::test] @@ -201,10 +201,10 @@ async fn test_json_wrapper_with_array() { {"name": "func2", "arguments": {"param": "value"}} ]"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].function.name, "func1"); - assert_eq!(result[1].function.name, "func2"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].function.name, "func1"); + assert_eq!(tools[1].function.name, "func2"); } #[tokio::test] @@ -217,13 +217,13 @@ async fn test_json_incomplete_wrapper_tokens() { // Missing end token let input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should not parse without closing token"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should not parse without closing token"); // Missing start token let input = r#"{"name": "test", "arguments": {}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 0, "Should not parse without opening token"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 0, "Should not parse without opening token"); } #[tokio::test] @@ -236,7 +236,7 @@ async fn test_json_empty_wrapper_tokens() { let input = r#"{"name": "test", "arguments": {"key": "value"}}"#; - let result = parser.parse_complete(input).await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].function.name, "test"); + let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].function.name, "test"); }