Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sgl-model-gateway/src/routers/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ pub fn not_implemented(code: impl Into<String>, message: impl Into<String>) -> R
create_error(StatusCode::NOT_IMPLEMENTED, code, message)
}

/// Builds an error [`Response`] for the `BAD_GATEWAY` status.
///
/// Delegates to `create_error` with [`StatusCode::BAD_GATEWAY`], forwarding
/// `code` and `message` unchanged.
///
/// * `code` - error code handed to `create_error`.
/// * `message` - error message handed to `create_error`.
pub fn bad_gateway(code: impl Into<String>, message: impl Into<String>) -> Response {
    let status = StatusCode::BAD_GATEWAY;
    create_error(status, code, message)
}

/// Builds an error [`Response`] for the `METHOD_NOT_ALLOWED` status.
///
/// Delegates to `create_error` with [`StatusCode::METHOD_NOT_ALLOWED`],
/// forwarding `code` and `message` unchanged.
///
/// * `code` - error code handed to `create_error`.
/// * `message` - error message handed to `create_error`.
pub fn method_not_allowed(code: impl Into<String>, message: impl Into<String>) -> Response {
    let status = StatusCode::METHOD_NOT_ALLOWED;
    create_error(status, code, message)
}

fn create_error(
status: StatusCode,
code: impl Into<String>,
Expand Down
191 changes: 143 additions & 48 deletions sgl-model-gateway/src/routers/http/pd_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::{
generate::GenerateRequest,
rerank::RerankRequest,
},
routers::{header_utils, RouterTrait},
routers::{error, header_utils, RouterTrait},
};

#[derive(Debug)]
Expand Down Expand Up @@ -68,11 +68,7 @@ impl PDRouter {
if let Some(worker_url) = first_worker_url {
self.proxy_to_worker(worker_url, endpoint, headers).await
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
"No prefill servers available".to_string(),
)
.into_response()
error::service_unavailable("no_prefill_servers", "No prefill servers available")
}
}

Expand Down Expand Up @@ -104,26 +100,50 @@ impl PDRouter {
}
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
error::internal_error(
"read_response_body_failed",
format!("Failed to read response body: {}", e),
)
.into_response()
}
}
}
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(status, format!("{} server returned status: ", res.status())).into_response()
// Use the status code to determine which error function to use
match status {
StatusCode::BAD_REQUEST => error::bad_request(
"server_bad_request",
format!("Server returned status: {}", res.status()),
),
StatusCode::NOT_FOUND => error::not_found(
"server_not_found",
format!("Server returned status: {}", res.status()),
),
StatusCode::INTERNAL_SERVER_ERROR => error::internal_error(
"server_internal_error",
format!("Server returned status: {}", res.status()),
),
StatusCode::SERVICE_UNAVAILABLE => error::service_unavailable(
"server_unavailable",
format!("Server returned status: {}", res.status()),
),
StatusCode::BAD_GATEWAY => error::bad_gateway(
"server_bad_gateway",
format!("Server returned status: {}", res.status()),
),
_ => error::internal_error(
"server_error",
format!("Server returned status: {}", res.status()),
),
}
}
Err(e) => {
error!("Failed to proxy request server: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
error::internal_error(
"proxy_request_failed",
format!("Failed to proxy request: {}", e),
)
.into_response()
}
}
}
Expand All @@ -142,20 +162,15 @@ impl PDRouter {
fn handle_server_selection_error(error: String) -> Response {
error!("Failed to select PD pair error={}", error);
RouterMetrics::record_pd_error("server_selection");
(
StatusCode::SERVICE_UNAVAILABLE,
error::service_unavailable(
"server_selection_failed",
format!("No available servers: {}", error),
)
.into_response()
}

fn handle_serialization_error(error: impl std::fmt::Display) -> Response {
error!("Failed to serialize request error={}", error);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response()
error::internal_error("serialization_failed", "Failed to serialize request")
}

fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
Expand Down Expand Up @@ -378,8 +393,71 @@ impl PDRouter {
} else {
// Handle non-streaming error response
match res.bytes().await {
Ok(error_body) => (status, error_body).into_response(),
Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
Ok(error_body) => {
// Try to parse error message from body, fallback to status-based error
let error_message = if let Ok(error_json) =
serde_json::from_slice::<Value>(&error_body)
{
if let Some(msg) = error_json
.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
{
msg.to_string()
} else if let Some(msg) = error_json.get("message").and_then(|m| m.as_str())
{
msg.to_string()
} else {
String::from_utf8_lossy(&error_body).to_string()
}
} else {
String::from_utf8_lossy(&error_body).to_string()
};

let status_code = StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match status_code {
StatusCode::BAD_REQUEST => {
error::bad_request("decode_bad_request", error_message)
}
StatusCode::NOT_FOUND => {
error::not_found("decode_not_found", error_message)
}
StatusCode::INTERNAL_SERVER_ERROR => {
error::internal_error("decode_internal_error", error_message)
}
StatusCode::SERVICE_UNAVAILABLE => {
error::service_unavailable("decode_unavailable", error_message)
}
StatusCode::BAD_GATEWAY => {
error::bad_gateway("decode_bad_gateway", error_message)
}
_ => error::internal_error("decode_error", error_message),
}
}
Err(e) => {
let error_message = format!("Decode server error: {}", e);
let status_code = StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match status_code {
StatusCode::BAD_REQUEST => {
error::bad_request("decode_read_failed", error_message)
}
StatusCode::NOT_FOUND => {
error::not_found("decode_read_failed", error_message)
}
StatusCode::INTERNAL_SERVER_ERROR => {
error::internal_error("decode_read_failed", error_message)
}
StatusCode::SERVICE_UNAVAILABLE => {
error::service_unavailable("decode_read_failed", error_message)
}
StatusCode::BAD_GATEWAY => {
error::bad_gateway("decode_read_failed", error_message)
}
_ => error::internal_error("decode_read_failed", error_message),
}
}
}
}
}
Expand Down Expand Up @@ -535,8 +613,10 @@ impl PDRouter {
}
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response()
error::internal_error(
"read_response_failed",
"Failed to read response",
)
}
}
}
Expand All @@ -549,11 +629,7 @@ impl PDRouter {
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
error::bad_gateway("decode_server_error", format!("Decode server error: {}", e))
}
}
}
Expand Down Expand Up @@ -759,8 +835,7 @@ impl PDRouter {
Ok(decode_body) => decode_body,
Err(e) => {
error!("Failed to read decode response: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response();
return error::internal_error("read_response_failed", "Failed to read response");
}
};

Expand Down Expand Up @@ -812,14 +887,13 @@ impl PDRouter {
);

// Return error immediately - don't wait for decode to timeout
return Err((
StatusCode::BAD_GATEWAY,
return Err(error::bad_gateway(
"prefill_server_error",
format!(
"Prefill server error: {}. This will cause decode timeout.",
e
),
)
.into_response());
));
}
};

Expand All @@ -841,11 +915,34 @@ impl PDRouter {
prefill_url, prefill_status, error_msg
);

return Err((
prefill_status,
format!("Prefill server error ({}): {}", prefill_status, error_msg),
)
.into_response());
// Map prefill_status to appropriate error function
let error_response = match prefill_status {
StatusCode::BAD_REQUEST => error::bad_request(
"prefill_bad_request",
format!("Prefill server error ({}): {}", prefill_status, error_msg),
),
StatusCode::NOT_FOUND => error::not_found(
"prefill_not_found",
format!("Prefill server error ({}): {}", prefill_status, error_msg),
),
StatusCode::INTERNAL_SERVER_ERROR => error::internal_error(
"prefill_internal_error",
format!("Prefill server error ({}): {}", prefill_status, error_msg),
),
StatusCode::SERVICE_UNAVAILABLE => error::service_unavailable(
"prefill_unavailable",
format!("Prefill server error ({}): {}", prefill_status, error_msg),
),
StatusCode::BAD_GATEWAY => error::bad_gateway(
"prefill_bad_gateway",
format!("Prefill server error ({}): {}", prefill_status, error_msg),
),
_ => error::internal_error(
"prefill_error",
format!("Prefill server error ({}): {}", prefill_status, error_msg),
),
};
return Err(error_response);
}

// Read prefill body if needed for logprob merging
Expand Down Expand Up @@ -990,11 +1087,10 @@ impl RouterTrait for PDRouter {
let (prefill, decode) = match self.select_pd_pair(None, None).await {
Ok(pair) => pair,
Err(e) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
return error::service_unavailable(
"no_healthy_worker_pair",
format!("No healthy worker pair available: {}", e),
)
.into_response();
);
}
};

Expand Down Expand Up @@ -1055,11 +1151,10 @@ impl RouterTrait for PDRouter {
)
.into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
error::service_unavailable(
"health_generate_failed",
format!("Health generate failed: {:?}", errors),
)
.into_response()
}
}

Expand Down
Loading
Loading