diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 845a1827c293..9500943173b0 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -57,7 +57,7 @@ tokio-stream = { version = "0.1", features = ["sync"] } anyhow = "1.0" tokenizers = { version = "0.22.0" } tiktoken-rs = { version = "0.7.0" } -minijinja = { version = "2.0", features = ["unstable_machinery"] } +minijinja = { version = "2.0", features = ["unstable_machinery", "json", "builtins"] } rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } hf-hub = { version = "0.4.3", features = ["tokio"] } rmcp = { version = "0.6.3", features = ["client", "server", diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index cca506e9dd22..660e113e042d 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -174,11 +174,6 @@ impl JsonParser { Ok(tools) } - - /// Check if text contains tool calls - fn has_tool_call(&self, text: &str) -> bool { - text.contains('[') || text.contains('{') - } } impl Default for JsonParser { @@ -216,7 +211,7 @@ impl ToolParser for JsonParser { let current_text = &self.buffer.clone(); // Check if current_text has tool_call - let has_tool_start = self.has_tool_call(current_text) + let has_tool_start = self.has_tool_markers(current_text) || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator)); if !has_tool_start { @@ -263,7 +258,7 @@ impl ToolParser for JsonParser { fn has_tool_markers(&self, text: &str) -> bool { let trimmed = text.trim(); - (trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#) + trimmed.starts_with('[') || trimmed.starts_with('{') } fn get_unstreamed_tool_args(&self) -> Option> { diff --git a/sgl-router/src/tool_parser/parsers/llama_parser.rs b/sgl-router/src/tool_parser/parsers/llama_parser.rs index 5836e9b554ac..14c09c1c5e04 100644 --- a/sgl-router/src/tool_parser/parsers/llama_parser.rs +++ b/sgl-router/src/tool_parser/parsers/llama_parser.rs @@ -121,11 +121,6 @@ impl LlamaParser { Ok(all_tools) } - - /// Check if text has tool call - fn has_tool_call(&self, text: &str) -> bool { - text.contains("<|python_tag|>") || text.contains('{') - } } impl Default for LlamaParser { @@ -184,7 +179,7 @@ impl ToolParser for LlamaParser { let current_text = &self.buffer.clone(); // Check if current_text has tool_call - let has_tool_start = self.has_tool_call(current_text) + let has_tool_start = self.has_tool_markers(current_text) || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator)); if !has_tool_start { @@ -230,8 +225,7 @@ impl ToolParser for LlamaParser { fn has_tool_markers(&self, text: &str) -> bool { // Llama format if contains python_tag or starts with JSON object - text.contains("<|python_tag|>") - || (text.trim_start().starts_with('{') && text.contains(r#""name""#)) + text.contains("<|python_tag|>") || text.trim_start().starts_with('{') } fn get_unstreamed_tool_args(&self) -> Option> { diff --git a/sgl-router/tests/tool_parser_llama.rs b/sgl-router/tests/tool_parser_llama.rs index 087ecac54d3c..1db3f62dd424 100644 --- a/sgl-router/tests/tool_parser_llama.rs +++ b/sgl-router/tests/tool_parser_llama.rs @@ -119,7 +119,6 @@ async fn test_llama_format_detection() { assert!(parser.has_tool_markers(r#"<|python_tag|>{"name": "test"}"#)); assert!(parser.has_tool_markers(r#"{"name": "test", "parameters": {}}"#)); assert!(!parser.has_tool_markers("plain text")); - assert!(!parser.has_tool_markers(r#"{"key": "value"}"#)); // No name field } #[tokio::test]