Skip to content

Commit f8cb53d

Browse files
committed
feat(openai): add retry support to OpenAI router chat endpoint
- Add RetryConfig field and initialize from router config - Wrap route_chat request execution with RetryExecutor - Use is_retryable_status for 408, 429, 500, 502, 503, 504 status codes - Record circuit breaker outcomes correctly: - Failure on connection error or error status - Success only after body read (non-streaming) or stream start (streaming) - Record retry metrics with BACKEND_EXTERNAL label - Record duration/error metrics after retry completes - Responses endpoint not included (per requirement)
1 parent 796969c commit f8cb53d

File tree

1 file changed

+134
-89
lines changed
  • sgl-model-gateway/src/routers/openai

1 file changed

+134
-89
lines changed

sgl-model-gateway/src/routers/openai/router.rs

Lines changed: 134 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ use super::{
3434
};
3535
use crate::{
3636
app_context::AppContext,
37-
core::{model_type::Endpoint, ModelCard, ProviderType, RuntimeType, Worker, WorkerRegistry},
37+
config::types::RetryConfig,
38+
core::{
39+
is_retryable_status, model_type::Endpoint, ModelCard, ProviderType, RetryExecutor,
40+
RuntimeType, Worker, WorkerRegistry,
41+
},
3842
data_connector::{ConversationId, ListParams, ResponseId, SortOrder},
3943
observability::metrics::{bool_to_static_str, metrics_labels, Metrics},
4044
protocols::{
@@ -53,6 +57,7 @@ pub struct OpenAIRouter {
5357
healthy: AtomicBool,
5458
shared_components: Arc<SharedComponents>,
5559
responses_components: Arc<ResponsesComponents>,
60+
retry_config: RetryConfig,
5661
}
5762

5863
impl std::fmt::Debug for OpenAIRouter {
@@ -176,6 +181,7 @@ impl OpenAIRouter {
176181
healthy: AtomicBool::new(true),
177182
shared_components,
178183
responses_components,
184+
retry_config: ctx.router_config.effective_retry_config(),
179185
})
180186
}
181187

@@ -659,77 +665,129 @@ impl crate::routers::RouterTrait for OpenAIRouter {
659665
previous_response_id: None,
660666
});
661667

668+
// Clone values needed for retry closure
662669
let payload_ref = ctx.payload().expect("Payload not prepared");
663-
let mut req = ctx.components.client().post(&url).json(&payload_ref.json);
664-
let auth_header = extract_auth_header(ctx.headers(), worker.api_key());
665-
req = apply_provider_headers(req, &url, auth_header.as_ref());
670+
let payload_json = payload_ref.json.clone();
671+
let client = ctx.components.client().clone();
672+
let headers_cloned = ctx.headers().cloned();
673+
let worker_api_key = worker.api_key().clone();
674+
let is_streaming = ctx.is_streaming();
675+
676+
let response = RetryExecutor::execute_response_with_retry(
677+
&self.retry_config,
678+
|_attempt| {
679+
let client = client.clone();
680+
let url = url.clone();
681+
let payload = payload_json.clone();
682+
let headers = headers_cloned.clone();
683+
let worker_api_key = worker_api_key.clone();
684+
let worker = Arc::clone(&worker);
685+
686+
async move {
687+
let mut req = client.post(&url).json(&payload);
688+
let auth_header = extract_auth_header(headers.as_ref(), &worker_api_key);
689+
req = apply_provider_headers(req, &url, auth_header.as_ref());
690+
691+
if is_streaming {
692+
req = req.header("Accept", "text/event-stream");
693+
}
666694

667-
if ctx.is_streaming() {
668-
req = req.header("Accept", "text/event-stream");
669-
}
695+
let resp = match req.send().await {
696+
Ok(r) => r,
697+
Err(e) => {
698+
worker.circuit_breaker().record_failure();
699+
return (
700+
StatusCode::SERVICE_UNAVAILABLE,
701+
format!("Failed to contact upstream: {}", e),
702+
)
703+
.into_response();
704+
}
705+
};
670706

671-
let resp = match req.send().await {
672-
Ok(r) => r,
673-
Err(e) => {
674-
worker.circuit_breaker().record_failure();
675-
Metrics::record_router_error(
676-
metrics_labels::ROUTER_OPENAI,
707+
let status = StatusCode::from_u16(resp.status().as_u16())
708+
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
709+
710+
// Record circuit breaker failure for error status codes
711+
if !status.is_success() {
712+
worker.circuit_breaker().record_failure();
713+
}
714+
715+
if !is_streaming {
716+
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
717+
match resp.bytes().await {
718+
Ok(body) => {
719+
// Only record success after body is fully read
720+
if status.is_success() {
721+
worker.circuit_breaker().record_success();
722+
}
723+
let mut response = Response::new(Body::from(body));
724+
*response.status_mut() = status;
725+
if let Some(ct) = content_type {
726+
response.headers_mut().insert(CONTENT_TYPE, ct);
727+
}
728+
response
729+
}
730+
Err(e) => {
731+
worker.circuit_breaker().record_failure();
732+
(
733+
StatusCode::INTERNAL_SERVER_ERROR,
734+
format!("Failed to read response: {}", e),
735+
)
736+
.into_response()
737+
}
738+
}
739+
} else {
740+
// Streaming response - record success when stream starts
741+
if status.is_success() {
742+
worker.circuit_breaker().record_success();
743+
}
744+
let stream = resp.bytes_stream();
745+
let (tx, rx) = mpsc::unbounded_channel();
746+
tokio::spawn(async move {
747+
let mut s = stream;
748+
while let Some(chunk) = s.next().await {
749+
match chunk {
750+
Ok(bytes) => {
751+
if tx.send(Ok(bytes)).is_err() {
752+
break;
753+
}
754+
}
755+
Err(e) => {
756+
let _ = tx.send(Err(format!("Stream error: {}", e)));
757+
break;
758+
}
759+
}
760+
}
761+
});
762+
let mut response =
763+
Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
764+
*response.status_mut() = status;
765+
response
766+
.headers_mut()
767+
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
768+
response
769+
}
770+
}
771+
},
772+
|res, _attempt| is_retryable_status(res.status()),
773+
|delay, attempt| {
774+
Metrics::record_worker_retry(
677775
metrics_labels::BACKEND_EXTERNAL,
678-
metrics_labels::CONNECTION_HTTP,
679-
model,
680776
metrics_labels::ENDPOINT_CHAT,
681-
metrics_labels::ERROR_BACKEND,
682777
);
683-
return (
684-
StatusCode::SERVICE_UNAVAILABLE,
685-
format!("Failed to contact upstream: {}", e),
686-
)
687-
.into_response();
688-
}
689-
};
778+
Metrics::record_worker_retry_backoff(attempt, delay);
779+
},
780+
|| {
781+
Metrics::record_worker_retries_exhausted(
782+
metrics_labels::BACKEND_EXTERNAL,
783+
metrics_labels::ENDPOINT_CHAT,
784+
);
785+
},
786+
)
787+
.await;
690788

691-
let status = StatusCode::from_u16(resp.status().as_u16())
692-
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
693-
694-
if !ctx.is_streaming() {
695-
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
696-
match resp.bytes().await {
697-
Ok(body) => {
698-
worker.circuit_breaker().record_success();
699-
Metrics::record_router_duration(
700-
metrics_labels::ROUTER_OPENAI,
701-
metrics_labels::BACKEND_EXTERNAL,
702-
metrics_labels::CONNECTION_HTTP,
703-
model,
704-
metrics_labels::ENDPOINT_CHAT,
705-
start.elapsed(),
706-
);
707-
let mut response = Response::new(Body::from(body));
708-
*response.status_mut() = status;
709-
if let Some(ct) = content_type {
710-
response.headers_mut().insert(CONTENT_TYPE, ct);
711-
}
712-
response
713-
}
714-
Err(e) => {
715-
worker.circuit_breaker().record_failure();
716-
Metrics::record_router_error(
717-
metrics_labels::ROUTER_OPENAI,
718-
metrics_labels::BACKEND_EXTERNAL,
719-
metrics_labels::CONNECTION_HTTP,
720-
model,
721-
metrics_labels::ENDPOINT_CHAT,
722-
metrics_labels::ERROR_BACKEND,
723-
);
724-
(
725-
StatusCode::INTERNAL_SERVER_ERROR,
726-
format!("Failed to read response: {}", e),
727-
)
728-
.into_response()
729-
}
730-
}
731-
} else {
732-
// For streaming, record duration at start since we can't track completion
789+
// Record duration/error metrics after retry completes
790+
if response.status().is_success() {
733791
Metrics::record_router_duration(
734792
metrics_labels::ROUTER_OPENAI,
735793
metrics_labels::BACKEND_EXTERNAL,
@@ -738,31 +796,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
738796
metrics_labels::ENDPOINT_CHAT,
739797
start.elapsed(),
740798
);
741-
let stream = resp.bytes_stream();
742-
let (tx, rx) = mpsc::unbounded_channel();
743-
tokio::spawn(async move {
744-
let mut s = stream;
745-
while let Some(chunk) = s.next().await {
746-
match chunk {
747-
Ok(bytes) => {
748-
if tx.send(Ok(bytes)).is_err() {
749-
break;
750-
}
751-
}
752-
Err(e) => {
753-
let _ = tx.send(Err(format!("Stream error: {}", e)));
754-
break;
755-
}
756-
}
757-
}
758-
});
759-
let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
760-
*response.status_mut() = status;
761-
response
762-
.headers_mut()
763-
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
764-
response
799+
} else {
800+
Metrics::record_router_error(
801+
metrics_labels::ROUTER_OPENAI,
802+
metrics_labels::BACKEND_EXTERNAL,
803+
metrics_labels::CONNECTION_HTTP,
804+
model,
805+
metrics_labels::ENDPOINT_CHAT,
806+
metrics_labels::ERROR_BACKEND,
807+
);
765808
}
809+
810+
response
766811
}
767812

768813
async fn route_responses(

0 commit comments

Comments
 (0)