Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/entrypoints/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _convert_generate_request(
logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=grpc_req.stream or False,
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
Expand Down Expand Up @@ -458,9 +458,9 @@ def _convert_sampling_params(
repetition_penalty=grpc_params.repetition_penalty or 1.0,
max_new_tokens=grpc_params.max_new_tokens or 128,
min_new_tokens=grpc_params.min_new_tokens or 0,
stop=list(grpc_params.stop) if grpc_params.stop else None,
stop=list(grpc_params.stop) if grpc_params.stop else [],
stop_token_ids=(
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
),
skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
Expand Down
52 changes: 50 additions & 2 deletions sgl-router/src/tokenizer/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
|| buffer.windows(4).any(|w| w == b"</s>"))
}

/// Discover a chat template file in `dir`.
///
/// Search order:
/// 1. `chat_template.json` — JSON file wrapping a Jinja template
/// 2. `chat_template.jinja` — the standard Jinja template file
/// 3. any other `*.jinja` file (models with non-standard naming),
///    chosen in lexicographic path order so the result is deterministic
///    (`fs::read_dir` yields entries in platform-dependent order)
///
/// Returns the template's path as a `String`, or `None` when no template
/// file is present (or its path is not valid UTF-8).
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
    use std::fs;

    // Priorities 1 and 2: well-known file names, checked in order.
    for name in ["chat_template.json", "chat_template.jinja"] {
        let candidate = dir.join(name);
        if candidate.exists() {
            return candidate.to_str().map(str::to_string);
        }
    }

    // Priority 3: any other .jinja file. Sort the candidates so the pick
    // is stable across platforms and filesystems.
    let mut jinja_files: Vec<_> = fs::read_dir(dir)
        .ok()?
        .flatten()
        .filter(|entry| {
            entry
                .file_name()
                .to_str()
                // chat_template.jinja was already handled above.
                .map_or(false, |n| n.ends_with(".jinja") && n != "chat_template.jinja")
        })
        .map(|entry| entry.path())
        .collect();
    jinja_files.sort();

    jinja_files
        .first()
        .and_then(|p| p.to_str())
        .map(str::to_string)
}

/// Factory function to create tokenizer from a model name or path (async version)
pub async fn create_tokenizer_async(
model_name_or_path: &str,
Expand Down Expand Up @@ -161,14 +191,32 @@ pub async fn create_tokenizer_async(
// Look for tokenizer.json in the cache directory
let tokenizer_path = cache_dir.join("tokenizer.json");
if tokenizer_path.exists() {
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
// Try to find a chat template file in the cache directory
let chat_template_path = discover_chat_template_in_dir(&cache_dir);
let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
Error::msg(format!(
"Tokenizer path is not valid UTF-8: {:?}",
tokenizer_path
))
})?;
create_tokenizer_with_chat_template(
tokenizer_path_str,
chat_template_path.as_deref(),
)
} else {
// Try other common tokenizer file names
let possible_files = ["tokenizer_config.json", "vocab.json"];
for file_name in &possible_files {
let file_path = cache_dir.join(file_name);
if file_path.exists() {
return create_tokenizer_from_file(file_path.to_str().unwrap());
let chat_template_path = discover_chat_template_in_dir(&cache_dir);
let file_path_str = file_path.to_str().ok_or_else(|| {
Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
})?;
return create_tokenizer_with_chat_template(
file_path_str,
chat_template_path.as_deref(),
);
}
}
Err(Error::msg(format!(
Expand Down
96 changes: 94 additions & 2 deletions sgl-router/src/tokenizer/hub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ fn is_tokenizer_file(filename: &str) -> bool {
|| filename.ends_with("merges.txt")
|| filename.ends_with(".model") // SentencePiece models
|| filename.ends_with(".tiktoken")
|| is_chat_template_file(filename) // Include chat template files
}

/// Returns `true` when `filename` carries a chat template: any `*.jinja`
/// file, or the JSON-wrapped `chat_template.json`.
fn is_chat_template_file(filename: &str) -> bool {
    filename == "chat_template.json" || filename.ends_with(".jinja")
}

/// Attempt to download tokenizer files from Hugging Face
Expand Down Expand Up @@ -123,7 +130,13 @@ pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::R
}

match cache_dir {
Some(dir) => Ok(dir),
Some(dir) => {
// Ensure we return the correct model directory, not a subfolder
// Some models have an "original" subfolder for PyTorch weights
// We want the main model directory that contains tokenizer files
let final_dir = resolve_model_cache_dir(&dir, &model_name);
Ok(final_dir)
}
None => Err(anyhow::anyhow!(
"Invalid HF cache path for model '{}'",
model_name
Expand Down Expand Up @@ -206,11 +219,76 @@ pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Re
}

match p.parent() {
Some(p) => Ok(p.to_path_buf()),
Some(p) => {
let final_dir = resolve_model_cache_dir(p, &model_name);
Ok(final_dir)
}
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
}
}

/// Resolve the correct model cache directory.
///
/// Handles two quirks of the Hugging Face hub cache layout:
/// * tokenizer files may sit next to an `original/` subfolder holding
///   PyTorch weights — when `path` *is* that subfolder, the model root is
///   one level up;
/// * `path` may be above or below the snapshot directory that actually
///   contains the tokenizer files, so we search for the
///   `models--{org}--{model}` component and prefer a snapshot dir with
///   tokenizer files in it.
///
/// `model_name` is the repo id (`org/model`). The hub cache folder name is
/// the repo id with the `/` separator replaced by `--`; hyphens inside the
/// org or model names are preserved (e.g. `meta-llama/Llama-2-7b` →
/// `models--meta-llama--Llama-2-7b`).
fn resolve_model_cache_dir(path: &Path, model_name: &str) -> PathBuf {
    // Case 1: we're inside the "original" subfolder — go up one level.
    if path.file_name() == Some(std::ffi::OsStr::new("original")) {
        if let Some(parent) = path.parent() {
            return parent.to_path_buf();
        }
    }

    let model_parts: Vec<&str> = model_name.split('/').collect();
    if model_parts.len() >= 2 {
        // Only the '/' in the repo id becomes "--"; do NOT double hyphens
        // within the org/model names (doing so never matches real cache
        // paths for hyphenated repos like meta-llama/Llama-2-7b).
        let expected_pattern = format!("models--{}--{}", model_parts[0], model_parts[1]);

        if path.to_string_lossy().contains(&expected_pattern) {
            // Already at (or below) the correct model directory.
            return path.to_path_buf();
        }

        // The current path already contains tokenizer files — keep it.
        if path.join("tokenizer.json").exists() || path.join("tokenizer_config.json").exists() {
            return path.to_path_buf();
        }

        // Otherwise walk up to the model root, then prefer a snapshot
        // directory that holds tokenizer files.
        let mut current = path.to_path_buf();
        while let Some(parent) = current.parent() {
            if parent.to_string_lossy().contains(&expected_pattern) {
                let snapshots_dir = parent.join("snapshots");
                if snapshots_dir.is_dir() {
                    if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                        for entry in entries.flatten() {
                            let snapshot_path = entry.path();
                            if snapshot_path.is_dir()
                                && (snapshot_path.join("tokenizer.json").exists()
                                    || snapshot_path.join("tokenizer_config.json").exists())
                            {
                                return snapshot_path;
                            }
                        }
                    }
                }
                return parent.to_path_buf();
            }
            current = parent.to_path_buf();
        }
    }

    // Fallback: nothing better found — return the path unchanged.
    path.to_path_buf()
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -223,10 +301,24 @@ mod tests {
assert!(is_tokenizer_file("vocab.json"));
assert!(is_tokenizer_file("merges.txt"));
assert!(is_tokenizer_file("spiece.model"));
assert!(is_tokenizer_file("chat_template.jinja"));
assert!(is_tokenizer_file("template.jinja"));
assert!(!is_tokenizer_file("model.bin"));
assert!(!is_tokenizer_file("README.md"));
}

#[test]
fn test_is_chat_template_file() {
    // Any .jinja file and the JSON-wrapped template qualify.
    let templates = [
        "chat_template.jinja",
        "template.jinja",
        "any_file.jinja",
        "chat_template.json",
    ];
    for name in templates {
        assert!(is_chat_template_file(name), "{} should match", name);
    }

    // Other JSON files, extension-less names, and docs do not.
    let non_templates = ["tokenizer.json", "other_file.json", "chat_template", "README.md"];
    for name in non_templates {
        assert!(!is_chat_template_file(name), "{} should not match", name);
    }
}

#[test]
fn test_is_weight_file() {
assert!(is_weight_file("model.bin"));
Expand Down
31 changes: 29 additions & 2 deletions sgl-router/src/tokenizer/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ pub struct HuggingFaceTokenizer {
impl HuggingFaceTokenizer {
/// Create a tokenizer from a HuggingFace tokenizer JSON file.
///
/// Also looks for a chat template file in the same directory as
/// `file_path` and, when one is found, loads it alongside the tokenizer.
pub fn from_file(file_path: &str) -> Result<Self> {
    // Auto-discover a chat template next to the tokenizer file.
    let discovered = match std::path::Path::new(file_path).parent() {
        Some(dir) => crate::tokenizer::factory::discover_chat_template_in_dir(dir),
        None => None,
    };
    Self::from_file_with_chat_template(file_path, discovered.as_deref())
}

/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
Expand Down Expand Up @@ -135,13 +140,35 @@ impl HuggingFaceTokenizer {
None
}

/// Load chat template from a .jinja file
/// Load chat template from a file (.jinja or .json containing Jinja)
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
use std::fs;

let content = fs::read_to_string(template_path)
.map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

// Check if it's a JSON file containing a Jinja template
if template_path.ends_with(".json") {
// Parse JSON and extract the template string
let json_value: serde_json::Value = serde_json::from_str(&content)
.map_err(|e| Error::msg(format!("Failed to parse chat_template.json: {}", e)))?;

if let Some(template_str) = json_value.as_str() {
return Ok(Some(template_str.to_string()));
} else if let Some(obj) = json_value.as_object() {
if let Some(template_value) = obj.get("chat_template") {
if let Some(template_str) = template_value.as_str() {
return Ok(Some(template_str.to_string()));
}
}
}

return Err(Error::msg(
"chat_template.json does not contain a valid template",
));
}

// Otherwise it's a plain .jinja file
// Clean up the template (similar to Python implementation)
let template = content.trim().replace("\\n", "\n");

Expand Down
Loading
Loading