Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 134 additions & 89 deletions sgl-model-gateway/src/routers/openai/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ use super::{
};
use crate::{
app_context::AppContext,
core::{model_type::Endpoint, ModelCard, ProviderType, RuntimeType, Worker, WorkerRegistry},
config::types::RetryConfig,
core::{
is_retryable_status, model_type::Endpoint, ModelCard, ProviderType, RetryExecutor,
RuntimeType, Worker, WorkerRegistry,
},
data_connector::{ConversationId, ListParams, ResponseId, SortOrder},
observability::metrics::{bool_to_static_str, metrics_labels, Metrics},
protocols::{
Expand All @@ -53,6 +57,7 @@ pub struct OpenAIRouter {
healthy: AtomicBool,
shared_components: Arc<SharedComponents>,
responses_components: Arc<ResponsesComponents>,
retry_config: RetryConfig,
}

impl std::fmt::Debug for OpenAIRouter {
Expand Down Expand Up @@ -176,6 +181,7 @@ impl OpenAIRouter {
healthy: AtomicBool::new(true),
shared_components,
responses_components,
retry_config: ctx.router_config.effective_retry_config(),
})
}

Expand Down Expand Up @@ -659,77 +665,129 @@ impl crate::routers::RouterTrait for OpenAIRouter {
previous_response_id: None,
});

// Wrap values in Arc to avoid cloning large objects on each retry attempt
let payload_ref = ctx.payload().expect("Payload not prepared");
let mut req = ctx.components.client().post(&url).json(&payload_ref.json);
let auth_header = extract_auth_header(ctx.headers(), worker.api_key());
req = apply_provider_headers(req, &url, auth_header.as_ref());
let payload_json = Arc::new(payload_ref.json.clone());
let client = ctx.components.client().clone();
let headers_cloned = Arc::new(ctx.headers().cloned());
let worker_api_key = Arc::new(worker.api_key().clone());
let is_streaming = ctx.is_streaming();

let response = RetryExecutor::execute_response_with_retry(
&self.retry_config,
|_attempt| {
let client = client.clone();
let url = url.clone();
let payload = Arc::clone(&payload_json);
let headers = Arc::clone(&headers_cloned);
let worker_api_key = Arc::clone(&worker_api_key);
let worker = Arc::clone(&worker);

async move {
let mut req = client.post(&url).json(&*payload);
let auth_header = extract_auth_header((*headers).as_ref(), &worker_api_key);
req = apply_provider_headers(req, &url, auth_header.as_ref());

if is_streaming {
req = req.header("Accept", "text/event-stream");
}

if ctx.is_streaming() {
req = req.header("Accept", "text/event-stream");
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
worker.circuit_breaker().record_failure();
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("Failed to contact upstream: {}", e),
)
.into_response();
}
};

let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
worker.circuit_breaker().record_failure();
Metrics::record_router_error(
metrics_labels::ROUTER_OPENAI,
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

// Record circuit breaker failure for error status codes
if !status.is_success() {
worker.circuit_breaker().record_failure();
}

if !is_streaming {
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
match resp.bytes().await {
Ok(body) => {
// Only record success after body is fully read
if status.is_success() {
worker.circuit_breaker().record_success();
}
let mut response = Response::new(Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => {
worker.circuit_breaker().record_failure();
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response()
}
}
} else {
// Streaming response - record success when stream starts
if status.is_success() {
worker.circuit_breaker().record_success();
}
let stream = resp.bytes_stream();
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(async move {
let mut s = stream;
while let Some(chunk) = s.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let mut response =
Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
}
},
|res, _attempt| is_retryable_status(res.status()),
|delay, attempt| {
Metrics::record_worker_retry(
metrics_labels::BACKEND_EXTERNAL,
metrics_labels::CONNECTION_HTTP,
model,
metrics_labels::ENDPOINT_CHAT,
metrics_labels::ERROR_BACKEND,
);
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("Failed to contact upstream: {}", e),
)
.into_response();
}
};
Metrics::record_worker_retry_backoff(attempt, delay);
},
|| {
Metrics::record_worker_retries_exhausted(
metrics_labels::BACKEND_EXTERNAL,
metrics_labels::ENDPOINT_CHAT,
);
},
)
.await;

let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

if !ctx.is_streaming() {
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
match resp.bytes().await {
Ok(body) => {
worker.circuit_breaker().record_success();
Metrics::record_router_duration(
metrics_labels::ROUTER_OPENAI,
metrics_labels::BACKEND_EXTERNAL,
metrics_labels::CONNECTION_HTTP,
model,
metrics_labels::ENDPOINT_CHAT,
start.elapsed(),
);
let mut response = Response::new(Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => {
worker.circuit_breaker().record_failure();
Metrics::record_router_error(
metrics_labels::ROUTER_OPENAI,
metrics_labels::BACKEND_EXTERNAL,
metrics_labels::CONNECTION_HTTP,
model,
metrics_labels::ENDPOINT_CHAT,
metrics_labels::ERROR_BACKEND,
);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response()
}
}
} else {
// For streaming, record duration at start since we can't track completion
// Record duration/error metrics after retry completes
if response.status().is_success() {
Metrics::record_router_duration(
metrics_labels::ROUTER_OPENAI,
metrics_labels::BACKEND_EXTERNAL,
Expand All @@ -738,31 +796,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
metrics_labels::ENDPOINT_CHAT,
start.elapsed(),
);
let stream = resp.bytes_stream();
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(async move {
let mut s = stream;
while let Some(chunk) = s.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
} else {
Metrics::record_router_error(
metrics_labels::ROUTER_OPENAI,
metrics_labels::BACKEND_EXTERNAL,
metrics_labels::CONNECTION_HTTP,
model,
metrics_labels::ENDPOINT_CHAT,
metrics_labels::ERROR_BACKEND,
);
}

response
}

async fn route_responses(
Expand Down
Loading