12 changes: 9 additions & 3 deletions python/sglang/srt/entrypoints/grpc_server.py
@@ -22,6 +22,7 @@
from grpc_reflection.v1alpha import reflection

import sglang
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
@@ -321,7 +322,8 @@ async def GetModelInfo(
max_context_length=self.model_info["max_context_length"],
vocab_size=self.model_info["vocab_size"],
supports_vision=self.model_info["supports_vision"],
model_type=self.model_info["model_type"],
model_type=self.model_info.get("model_type") or "",
architectures=self.model_info.get("architectures") or [],
eos_token_ids=self.model_info["eos_token_ids"],
pad_token_id=self.model_info["pad_token_id"],
bos_token_id=self.model_info["bos_token_id"],
@@ -718,7 +720,10 @@ async def serve_grpc(
server_args=server_args,
)

# Update model info from scheduler info
# Load model config to get HF config info (same as TokenizerManager does)
model_config = ModelConfig.from_server_args(server_args)

# Update model info from scheduler info and model config
if model_info is None:
model_info = {
"model_name": server_args.model_path,
@@ -727,7 +732,8 @@
),
"vocab_size": scheduler_info.get("vocab_size", 128256),
"supports_vision": scheduler_info.get("supports_vision", False),
"model_type": scheduler_info.get("model_type", "transformer"),
"model_type": getattr(model_config.hf_config, "model_type", None),
"architectures": getattr(model_config.hf_config, "architectures", None),
"max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
"eos_token_ids": scheduler_info.get("eos_token_ids", []),
"pad_token_id": scheduler_info.get("pad_token_id", 0),
7 changes: 5 additions & 2 deletions python/sglang/srt/entrypoints/http_server.py
@@ -497,14 +497,17 @@ async def get_model_info():
@app.get("/model_info")
async def model_info():
"""Get the model information."""
model_config = _global_state.tokenizer_manager.model_config
result = {
"model_path": _global_state.tokenizer_manager.model_path,
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.tokenizer_manager.is_generation,
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
"weight_version": _global_state.tokenizer_manager.server_args.weight_version,
"has_image_understanding": _global_state.tokenizer_manager.model_config.is_image_understandable_model,
"has_audio_understanding": _global_state.tokenizer_manager.model_config.is_audio_understandable_model,
"has_image_understanding": model_config.is_image_understandable_model,
"has_audio_understanding": model_config.is_audio_understandable_model,
"model_type": getattr(model_config.hf_config, "model_type", None),
"architectures": getattr(model_config.hf_config, "architectures", None),
}
return result

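For reference, a minimal sketch (not part of this change) of roughly what the extended /model_info response could look like for a hypothetical Llama worker, parsed with serde_json since the router consumes it downstream; the field values are illustrative and the real response also carries the other keys shown above:

fn main() {
    // Hypothetical /model_info body; only fields relevant to this change are shown.
    let body = r#"{
        "model_path": "meta-llama/Llama-3.1-8B-Instruct",
        "is_generation": true,
        "model_type": "llama",
        "architectures": ["LlamaForCausalLM"]
    }"#;
    let v: serde_json::Value = serde_json::from_str(body).expect("valid JSON");
    assert_eq!(v["model_type"], "llama");
    assert_eq!(v["architectures"][0], "LlamaForCausalLM");
}
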
1 change: 1 addition & 0 deletions python/sglang/srt/grpc/sglang_scheduler.proto
@@ -427,6 +427,7 @@ message GetModelInfoResponse {
int32 pad_token_id = 12;
int32 bos_token_id = 13;
int32 max_req_input_len = 14;
repeated string architectures = 15;
}

// Get server information
16 changes: 8 additions & 8 deletions python/sglang/srt/grpc/sglang_scheduler_pb2.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
@@ -432,7 +432,7 @@ class GetModelInfoRequest(_message.Message):
def __init__(self) -> None: ...

class GetModelInfoResponse(_message.Message):
__slots__ = ("model_path", "tokenizer_path", "is_generation", "preferred_sampling_params", "weight_version", "served_model_name", "max_context_length", "vocab_size", "supports_vision", "model_type", "eos_token_ids", "pad_token_id", "bos_token_id", "max_req_input_len")
__slots__ = ("model_path", "tokenizer_path", "is_generation", "preferred_sampling_params", "weight_version", "served_model_name", "max_context_length", "vocab_size", "supports_vision", "model_type", "eos_token_ids", "pad_token_id", "bos_token_id", "max_req_input_len", "architectures")
MODEL_PATH_FIELD_NUMBER: _ClassVar[int]
TOKENIZER_PATH_FIELD_NUMBER: _ClassVar[int]
IS_GENERATION_FIELD_NUMBER: _ClassVar[int]
@@ -447,6 +447,7 @@ class GetModelInfoResponse(_message.Message):
PAD_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
BOS_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MAX_REQ_INPUT_LEN_FIELD_NUMBER: _ClassVar[int]
ARCHITECTURES_FIELD_NUMBER: _ClassVar[int]
model_path: str
tokenizer_path: str
is_generation: bool
@@ -461,7 +462,8 @@
pad_token_id: int
bos_token_id: int
max_req_input_len: int
def __init__(self, model_path: _Optional[str] = ..., tokenizer_path: _Optional[str] = ..., is_generation: bool = ..., preferred_sampling_params: _Optional[str] = ..., weight_version: _Optional[str] = ..., served_model_name: _Optional[str] = ..., max_context_length: _Optional[int] = ..., vocab_size: _Optional[int] = ..., supports_vision: bool = ..., model_type: _Optional[str] = ..., eos_token_ids: _Optional[_Iterable[int]] = ..., pad_token_id: _Optional[int] = ..., bos_token_id: _Optional[int] = ..., max_req_input_len: _Optional[int] = ...) -> None: ...
architectures: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, model_path: _Optional[str] = ..., tokenizer_path: _Optional[str] = ..., is_generation: bool = ..., preferred_sampling_params: _Optional[str] = ..., weight_version: _Optional[str] = ..., served_model_name: _Optional[str] = ..., max_context_length: _Optional[int] = ..., vocab_size: _Optional[int] = ..., supports_vision: bool = ..., model_type: _Optional[str] = ..., eos_token_ids: _Optional[_Iterable[int]] = ..., pad_token_id: _Optional[int] = ..., bos_token_id: _Optional[int] = ..., max_req_input_len: _Optional[int] = ..., architectures: _Optional[_Iterable[str]] = ...) -> None: ...

class GetServerInfoRequest(_message.Message):
__slots__ = ()
23 changes: 23 additions & 0 deletions sgl-router/src/core/model_card.rs
@@ -123,6 +123,15 @@ pub struct ModelCard {
#[serde(default = "default_model_type")]
pub model_type: ModelType,

/// HuggingFace model type string (e.g., "llama", "qwen2", "gpt_oss")
/// This is different from `model_type` which is capability bitflags.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub hf_model_type: Option<String>,

/// Model architectures from HuggingFace config (e.g., ["LlamaForCausalLM"])
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub architectures: Vec<String>,

/// Provider hint for API transformations.
/// `None` means native/passthrough (no transformation needed).
#[serde(default, skip_serializing_if = "Option::is_none")]
@@ -172,6 +181,8 @@ impl ModelCard {
display_name: None,
aliases: Vec::new(),
model_type: ModelType::LLM,
hf_model_type: None,
architectures: Vec::new(),
provider: None,
context_length: None,
tokenizer_path: None,
@@ -208,6 +219,18 @@
self
}

/// Set the HuggingFace model type string
pub fn with_hf_model_type(mut self, hf_model_type: impl Into<String>) -> Self {
self.hf_model_type = Some(hf_model_type.into());
self
}

/// Set the model architectures
pub fn with_architectures(mut self, architectures: Vec<String>) -> Self {
self.architectures = architectures;
self
}

/// Set the provider type (for external API transformations)
pub fn with_provider(mut self, provider: ProviderType) -> Self {
self.provider = Some(provider);
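A minimal usage sketch of the two new builder methods; the constructor name and import path are assumptions for illustration, while with_hf_model_type and with_architectures are the methods added above:

// Sketch only: builds a card for a hypothetical gpt-oss worker.
// `ModelCard::new` and the module path are assumed, not taken from this diff.
use crate::core::model_card::ModelCard;

fn example() -> ModelCard {
    ModelCard::new("openai/gpt-oss-120b")
        .with_hf_model_type("gpt_oss")
        .with_architectures(vec!["GptOssForCausalLM".to_string()])
}
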
93 changes: 80 additions & 13 deletions sgl-router/src/core/workflow/steps/worker_registration.rs
@@ -55,6 +55,18 @@ struct ServerInfo {
max_num_reqs: Option<usize>,
}

/// Model information returned from /model_info endpoint
#[derive(Debug, Clone, Deserialize, Serialize)]
struct ModelInfo {
model_path: Option<String>,
tokenizer_path: Option<String>,
is_generation: Option<bool>,
/// HuggingFace model type string (e.g., "llama", "qwen2", "gpt_oss")
model_type: Option<String>,
/// Model architectures from HuggingFace config (e.g., ["LlamaForCausalLM"])
architectures: Option<Vec<String>>,
}

#[derive(Debug, Clone)]
pub struct DpInfo {
pub dp_size: usize,
@@ -69,7 +81,7 @@ fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
/// Get server info from /get_server_info endpoint
async fn get_server_info(url: &str, api_key: Option<&str>) -> Result<ServerInfo, String> {
let base_url = url.trim_end_matches('/');
let server_info_url = format!("{}/get_server_info", base_url);
let server_info_url = format!("{}/server_info", base_url);

let mut req = HTTP_CLIENT.get(&server_info_url);
if let Some(key) = api_key {
@@ -97,6 +109,35 @@
parse_server_info(json)
}

/// Get model info from /model_info endpoint
async fn get_model_info(url: &str, api_key: Option<&str>) -> Result<ModelInfo, String> {
let base_url = url.trim_end_matches('/');
let model_info_url = format!("{}/model_info", base_url);

let mut req = HTTP_CLIENT.get(&model_info_url);
if let Some(key) = api_key {
req = req.bearer_auth(key);
}

let response = req
.send()
.await
.map_err(|e| format!("Failed to connect to {}: {}", model_info_url, e))?;

if !response.status().is_success() {
return Err(format!(
"Server returned status {} from {}",
response.status(),
model_info_url
));
}

response
.json::<ModelInfo>()
.await
.map_err(|e| format!("Failed to parse response from {}: {}", model_info_url, e))
}

/// Get DP info for a worker URL
async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result<DpInfo, String> {
let info = get_server_info(url, api_key).await?;
@@ -319,21 +360,37 @@ impl StepExecutor for DiscoverMetadataStep {

let (discovered_labels, detected_runtime) = match connection_mode.as_ref() {
ConnectionMode::Http => {
match get_server_info(&config.url, config.api_key.as_deref()).await {
Ok(server_info) => {
let mut labels = HashMap::new();
if let Some(model_path) = server_info.model_path.filter(|s| !s.is_empty()) {
labels.insert("model_path".to_string(), model_path);
}
if let Some(served_model_name) =
server_info.served_model_name.filter(|s| !s.is_empty())
{
labels.insert("served_model_name".to_string(), served_model_name);
let mut labels = HashMap::new();

                // Fetch from /server_info for server-related metadata
if let Ok(server_info) =
get_server_info(&config.url, config.api_key.as_deref()).await
{
if let Some(model_path) = server_info.model_path.filter(|s| !s.is_empty()) {
labels.insert("model_path".to_string(), model_path);
}
if let Some(served_model_name) =
server_info.served_model_name.filter(|s| !s.is_empty())
{
labels.insert("served_model_name".to_string(), served_model_name);
}
}

// Fetch from /model_info for model-related metadata (model_type, architectures)
if let Ok(model_info) = get_model_info(&config.url, config.api_key.as_deref()).await
{
if let Some(model_type) = model_info.model_type.filter(|s| !s.is_empty()) {
labels.insert("model_type".to_string(), model_type);
}
if let Some(architectures) = model_info.architectures.filter(|a| !a.is_empty())
{
if let Ok(json_str) = serde_json::to_string(&architectures) {
labels.insert("architectures".to_string(), json_str);
}
Ok((labels, None))
}
Err(e) => Err(e),
}

Ok((labels, None))
}
ConnectionMode::Grpc { .. } => {
let runtime_type = config.runtime.as_deref();
@@ -481,6 +538,16 @@ impl StepExecutor for CreateWorkerStep {
if let Some(ref chat_template) = config.chat_template {
card = card.with_chat_template(chat_template.clone());
}
// Set HuggingFace model type from discovered labels
if let Some(model_type_str) = final_labels.get("model_type") {
card = card.with_hf_model_type(model_type_str.clone());
}
// Set architectures from discovered labels (JSON array string)
if let Some(architectures_json) = final_labels.get("architectures") {
if let Ok(architectures) = serde_json::from_str::<Vec<String>>(architectures_json) {
card = card.with_architectures(architectures);
}
}
card
};

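A standalone sketch of the label round-trip described in the comments above: DiscoverMetadataStep encodes architectures as a JSON string label, and CreateWorkerStep decodes it again when building the ModelCard. Only std and serde_json are used; the label names mirror the diff:

use std::collections::HashMap;

fn main() {
    let architectures = vec!["GptOssForCausalLM".to_string()];
    let mut labels: HashMap<String, String> = HashMap::new();

    // Discovery side: store the Vec<String> as a JSON-encoded string label.
    if let Ok(json_str) = serde_json::to_string(&architectures) {
        labels.insert("architectures".to_string(), json_str);
    }

    // Worker-creation side: decode the label back into a Vec<String>.
    if let Some(architectures_json) = labels.get("architectures") {
        if let Ok(decoded) = serde_json::from_str::<Vec<String>>(architectures_json) {
            assert_eq!(decoded, architectures);
        }
    }
}
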
1 change: 1 addition & 0 deletions sgl-router/src/proto/sglang_scheduler.proto
@@ -427,6 +427,7 @@ message GetModelInfoResponse {
int32 pad_token_id = 12;
int32 bos_token_id = 13;
int32 max_req_input_len = 14;
repeated string architectures = 15;
}

// Get server information
8 changes: 7 additions & 1 deletion sgl-router/src/routers/grpc/client.rs
@@ -166,7 +166,13 @@ impl ModelInfo {
serde_json::Value::Bool(true) => {
labels.insert(key, "true".to_string());
}
// Skip empty strings, zeros, false, nulls, arrays, objects
// Insert non-empty arrays as JSON strings (for architectures, etc.)
serde_json::Value::Array(arr) if !arr.is_empty() => {
if let Ok(json_str) = serde_json::to_string(&arr) {
labels.insert(key, json_str);
}
}
// Skip empty strings, zeros, false, nulls, empty arrays, objects
_ => {}
}
}
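A small sketch of the new Value::Array arm in isolation: non-empty arrays such as architectures become JSON-encoded string labels, while empty arrays are still skipped (the field names and values here are illustrative):

use std::collections::HashMap;

fn main() {
    let mut labels: HashMap<String, String> = HashMap::new();
    for (key, value) in [
        ("architectures", serde_json::json!(["GptOssForCausalLM"])),
        ("eos_token_ids", serde_json::json!([])),
    ] {
        if let serde_json::Value::Array(arr) = value {
            if !arr.is_empty() {
                if let Ok(json_str) = serde_json::to_string(&arr) {
                    labels.insert(key.to_string(), json_str);
                }
            }
        }
    }
    assert_eq!(
        labels.get("architectures").map(String::as_str),
        Some(r#"["GptOssForCausalLM"]"#)
    );
    assert!(!labels.contains_key("eos_token_ids"));
}
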
60 changes: 60 additions & 0 deletions sgl-router/src/routers/grpc/harmony/detector.rs
@@ -1,11 +1,48 @@
//! Harmony model detection

use crate::core::{Worker, WorkerRegistry};

/// Harmony model detector
///
/// Detects if a model name indicates support for Harmony encoding/parsing.
pub struct HarmonyDetector;

impl HarmonyDetector {
/// Check if a worker is a Harmony/GPT-OSS model.
///
/// Detection priority:
/// 1. Check if any model card has architectures containing "GptOssForCausalLM"
/// 2. Check if any model card has hf_model_type equal to "gpt_oss"
/// 3. Check if model_id contains "gpt-oss" substring (case-insensitive)
pub fn is_harmony_worker(worker: &dyn Worker) -> bool {
for model_card in worker.models() {
// 1. Check architectures for GptOssForCausalLM
if model_card
.architectures
.iter()
.any(|arch| arch == "GptOssForCausalLM")
{
return true;
}

// 2. Check hf_model_type for gpt_oss
if let Some(ref model_type) = model_card.hf_model_type {
if model_type == "gpt_oss" {
return true;
}
}

// 3. Check model id for gpt-oss substring
if Self::is_harmony_model(&model_card.id) {
return true;
}
}

// Fallback: check worker's model_id directly
Self::is_harmony_model(worker.model_id())
}

/// Check if a model name contains "gpt-oss" (case-insensitive).
pub fn is_harmony_model(model_name: &str) -> bool {
// Case-insensitive substring search without heap allocation
// More efficient than to_lowercase() which allocates a new String
@@ -14,4 +51,27 @@ impl HarmonyDetector {
.windows(7) // "gpt-oss".len()
.any(|window| window.eq_ignore_ascii_case(b"gpt-oss"))
}

/// Check if any worker for the given model is a Harmony/GPT-OSS worker.
///
/// This method looks up workers from the registry by model name and checks
/// if any of them are Harmony workers based on their metadata (architectures,
/// hf_model_type).
///
/// Falls back to string-based detection if no workers are registered for
/// the model (e.g., during startup before workers are discovered).
pub fn is_harmony_model_in_registry(registry: &WorkerRegistry, model_name: &str) -> bool {
// Get workers for this model
let workers = registry.get_by_model_fast(model_name);

if workers.is_empty() {
// No workers found - fall back to string-based detection
return Self::is_harmony_model(model_name);
}

// Check if any worker is a Harmony worker
workers
.iter()
.any(|worker| Self::is_harmony_worker(worker.as_ref()))
}
}
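The string fallback is the last rung of the priority list above; here is a standalone re-implementation of just that check, to show the allocation-free, case-insensitive match (the real detector first consults architectures and hf_model_type from worker metadata):

fn is_harmony_model(model_name: &str) -> bool {
    // Same windowed byte comparison as HarmonyDetector::is_harmony_model.
    model_name
        .as_bytes()
        .windows(7) // "gpt-oss".len()
        .any(|window| window.eq_ignore_ascii_case(b"gpt-oss"))
}

fn main() {
    assert!(is_harmony_model("openai/GPT-OSS-120b"));
    assert!(is_harmony_model("my-gpt-oss-finetune"));
    assert!(!is_harmony_model("meta-llama/Llama-3.1-8B-Instruct"));
}
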
10 changes: 6 additions & 4 deletions sgl-router/src/routers/grpc/router.rs
@@ -142,8 +142,9 @@ impl GrpcRouter {
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
// Choose Harmony pipeline if model indicates Harmony
let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
// Choose Harmony pipeline if workers indicate Harmony (checks architectures, hf_model_type)
let is_harmony =
HarmonyDetector::is_harmony_model_in_registry(&self.worker_registry, &body.model);

debug!(
"Processing chat completion request for model: {:?}, using_harmony={}",
@@ -203,8 +204,9 @@
return error_response;
}

// Choose implementation based on Harmony model detection
let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
// Choose implementation based on Harmony model detection (checks worker metadata)
let is_harmony =
HarmonyDetector::is_harmony_model_in_registry(&self.worker_registry, &body.model);

if is_harmony {
debug!(