@@ -34,7 +34,11 @@ use super::{
3434} ;
3535use crate :: {
3636 app_context:: AppContext ,
37- core:: { model_type:: Endpoint , ModelCard , ProviderType , RuntimeType , Worker , WorkerRegistry } ,
37+ config:: types:: RetryConfig ,
38+ core:: {
39+ is_retryable_status, model_type:: Endpoint , ModelCard , ProviderType , RetryExecutor ,
40+ RuntimeType , Worker , WorkerRegistry ,
41+ } ,
3842 data_connector:: { ConversationId , ListParams , ResponseId , SortOrder } ,
3943 observability:: metrics:: { bool_to_static_str, metrics_labels, Metrics } ,
4044 protocols:: {
@@ -53,6 +57,7 @@ pub struct OpenAIRouter {
5357 healthy : AtomicBool ,
5458 shared_components : Arc < SharedComponents > ,
5559 responses_components : Arc < ResponsesComponents > ,
60+ retry_config : RetryConfig ,
5661}
5762
5863impl std:: fmt:: Debug for OpenAIRouter {
@@ -176,6 +181,7 @@ impl OpenAIRouter {
176181 healthy : AtomicBool :: new ( true ) ,
177182 shared_components,
178183 responses_components,
184+ retry_config : ctx. router_config . effective_retry_config ( ) ,
179185 } )
180186 }
181187
@@ -659,77 +665,129 @@ impl crate::routers::RouterTrait for OpenAIRouter {
659665 previous_response_id : None ,
660666 } ) ;
661667
668+ // Clone values needed for retry closure
662669 let payload_ref = ctx. payload ( ) . expect ( "Payload not prepared" ) ;
663- let mut req = ctx. components . client ( ) . post ( & url) . json ( & payload_ref. json ) ;
664- let auth_header = extract_auth_header ( ctx. headers ( ) , worker. api_key ( ) ) ;
665- req = apply_provider_headers ( req, & url, auth_header. as_ref ( ) ) ;
670+ let payload_json = payload_ref. json . clone ( ) ;
671+ let client = ctx. components . client ( ) . clone ( ) ;
672+ let headers_cloned = ctx. headers ( ) . cloned ( ) ;
673+ let worker_api_key = worker. api_key ( ) . clone ( ) ;
674+ let is_streaming = ctx. is_streaming ( ) ;
675+
676+ let response = RetryExecutor :: execute_response_with_retry (
677+ & self . retry_config ,
678+ |_attempt| {
679+ let client = client. clone ( ) ;
680+ let url = url. clone ( ) ;
681+ let payload = payload_json. clone ( ) ;
682+ let headers = headers_cloned. clone ( ) ;
683+ let worker_api_key = worker_api_key. clone ( ) ;
684+ let worker = Arc :: clone ( & worker) ;
685+
686+ async move {
687+ let mut req = client. post ( & url) . json ( & payload) ;
688+ let auth_header = extract_auth_header ( headers. as_ref ( ) , & worker_api_key) ;
689+ req = apply_provider_headers ( req, & url, auth_header. as_ref ( ) ) ;
690+
691+ if is_streaming {
692+ req = req. header ( "Accept" , "text/event-stream" ) ;
693+ }
666694
667- if ctx. is_streaming ( ) {
668- req = req. header ( "Accept" , "text/event-stream" ) ;
669- }
695+ let resp = match req. send ( ) . await {
696+ Ok ( r) => r,
697+ Err ( e) => {
698+ worker. circuit_breaker ( ) . record_failure ( ) ;
699+ return (
700+ StatusCode :: SERVICE_UNAVAILABLE ,
701+ format ! ( "Failed to contact upstream: {}" , e) ,
702+ )
703+ . into_response ( ) ;
704+ }
705+ } ;
670706
671- let resp = match req. send ( ) . await {
672- Ok ( r) => r,
673- Err ( e) => {
674- worker. circuit_breaker ( ) . record_failure ( ) ;
675- Metrics :: record_router_error (
676- metrics_labels:: ROUTER_OPENAI ,
707+ let status = StatusCode :: from_u16 ( resp. status ( ) . as_u16 ( ) )
708+ . unwrap_or ( StatusCode :: INTERNAL_SERVER_ERROR ) ;
709+
710+ // Record circuit breaker failure for error status codes
711+ if !status. is_success ( ) {
712+ worker. circuit_breaker ( ) . record_failure ( ) ;
713+ }
714+
715+ if !is_streaming {
716+ let content_type = resp. headers ( ) . get ( CONTENT_TYPE ) . cloned ( ) ;
717+ match resp. bytes ( ) . await {
718+ Ok ( body) => {
719+ // Only record success after body is fully read
720+ if status. is_success ( ) {
721+ worker. circuit_breaker ( ) . record_success ( ) ;
722+ }
723+ let mut response = Response :: new ( Body :: from ( body) ) ;
724+ * response. status_mut ( ) = status;
725+ if let Some ( ct) = content_type {
726+ response. headers_mut ( ) . insert ( CONTENT_TYPE , ct) ;
727+ }
728+ response
729+ }
730+ Err ( e) => {
731+ worker. circuit_breaker ( ) . record_failure ( ) ;
732+ (
733+ StatusCode :: INTERNAL_SERVER_ERROR ,
734+ format ! ( "Failed to read response: {}" , e) ,
735+ )
736+ . into_response ( )
737+ }
738+ }
739+ } else {
740+ // Streaming response - record success when stream starts
741+ if status. is_success ( ) {
742+ worker. circuit_breaker ( ) . record_success ( ) ;
743+ }
744+ let stream = resp. bytes_stream ( ) ;
745+ let ( tx, rx) = mpsc:: unbounded_channel ( ) ;
746+ tokio:: spawn ( async move {
747+ let mut s = stream;
748+ while let Some ( chunk) = s. next ( ) . await {
749+ match chunk {
750+ Ok ( bytes) => {
751+ if tx. send ( Ok ( bytes) ) . is_err ( ) {
752+ break ;
753+ }
754+ }
755+ Err ( e) => {
756+ let _ = tx. send ( Err ( format ! ( "Stream error: {}" , e) ) ) ;
757+ break ;
758+ }
759+ }
760+ }
761+ } ) ;
762+ let mut response =
763+ Response :: new ( Body :: from_stream ( UnboundedReceiverStream :: new ( rx) ) ) ;
764+ * response. status_mut ( ) = status;
765+ response
766+ . headers_mut ( )
767+ . insert ( CONTENT_TYPE , HeaderValue :: from_static ( "text/event-stream" ) ) ;
768+ response
769+ }
770+ }
771+ } ,
772+ |res, _attempt| is_retryable_status ( res. status ( ) ) ,
773+ |delay, attempt| {
774+ Metrics :: record_worker_retry (
677775 metrics_labels:: BACKEND_EXTERNAL ,
678- metrics_labels:: CONNECTION_HTTP ,
679- model,
680776 metrics_labels:: ENDPOINT_CHAT ,
681- metrics_labels:: ERROR_BACKEND ,
682777 ) ;
683- return (
684- StatusCode :: SERVICE_UNAVAILABLE ,
685- format ! ( "Failed to contact upstream: {}" , e) ,
686- )
687- . into_response ( ) ;
688- }
689- } ;
778+ Metrics :: record_worker_retry_backoff ( attempt, delay) ;
779+ } ,
780+ || {
781+ Metrics :: record_worker_retries_exhausted (
782+ metrics_labels:: BACKEND_EXTERNAL ,
783+ metrics_labels:: ENDPOINT_CHAT ,
784+ ) ;
785+ } ,
786+ )
787+ . await ;
690788
691- let status = StatusCode :: from_u16 ( resp. status ( ) . as_u16 ( ) )
692- . unwrap_or ( StatusCode :: INTERNAL_SERVER_ERROR ) ;
693-
694- if !ctx. is_streaming ( ) {
695- let content_type = resp. headers ( ) . get ( CONTENT_TYPE ) . cloned ( ) ;
696- match resp. bytes ( ) . await {
697- Ok ( body) => {
698- worker. circuit_breaker ( ) . record_success ( ) ;
699- Metrics :: record_router_duration (
700- metrics_labels:: ROUTER_OPENAI ,
701- metrics_labels:: BACKEND_EXTERNAL ,
702- metrics_labels:: CONNECTION_HTTP ,
703- model,
704- metrics_labels:: ENDPOINT_CHAT ,
705- start. elapsed ( ) ,
706- ) ;
707- let mut response = Response :: new ( Body :: from ( body) ) ;
708- * response. status_mut ( ) = status;
709- if let Some ( ct) = content_type {
710- response. headers_mut ( ) . insert ( CONTENT_TYPE , ct) ;
711- }
712- response
713- }
714- Err ( e) => {
715- worker. circuit_breaker ( ) . record_failure ( ) ;
716- Metrics :: record_router_error (
717- metrics_labels:: ROUTER_OPENAI ,
718- metrics_labels:: BACKEND_EXTERNAL ,
719- metrics_labels:: CONNECTION_HTTP ,
720- model,
721- metrics_labels:: ENDPOINT_CHAT ,
722- metrics_labels:: ERROR_BACKEND ,
723- ) ;
724- (
725- StatusCode :: INTERNAL_SERVER_ERROR ,
726- format ! ( "Failed to read response: {}" , e) ,
727- )
728- . into_response ( )
729- }
730- }
731- } else {
732- // For streaming, record duration at start since we can't track completion
789+ // Record duration/error metrics after retry completes
790+ if response. status ( ) . is_success ( ) {
733791 Metrics :: record_router_duration (
734792 metrics_labels:: ROUTER_OPENAI ,
735793 metrics_labels:: BACKEND_EXTERNAL ,
@@ -738,31 +796,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
738796 metrics_labels:: ENDPOINT_CHAT ,
739797 start. elapsed ( ) ,
740798 ) ;
741- let stream = resp. bytes_stream ( ) ;
742- let ( tx, rx) = mpsc:: unbounded_channel ( ) ;
743- tokio:: spawn ( async move {
744- let mut s = stream;
745- while let Some ( chunk) = s. next ( ) . await {
746- match chunk {
747- Ok ( bytes) => {
748- if tx. send ( Ok ( bytes) ) . is_err ( ) {
749- break ;
750- }
751- }
752- Err ( e) => {
753- let _ = tx. send ( Err ( format ! ( "Stream error: {}" , e) ) ) ;
754- break ;
755- }
756- }
757- }
758- } ) ;
759- let mut response = Response :: new ( Body :: from_stream ( UnboundedReceiverStream :: new ( rx) ) ) ;
760- * response. status_mut ( ) = status;
761- response
762- . headers_mut ( )
763- . insert ( CONTENT_TYPE , HeaderValue :: from_static ( "text/event-stream" ) ) ;
764- response
799+ } else {
800+ Metrics :: record_router_error (
801+ metrics_labels:: ROUTER_OPENAI ,
802+ metrics_labels:: BACKEND_EXTERNAL ,
803+ metrics_labels:: CONNECTION_HTTP ,
804+ model,
805+ metrics_labels:: ENDPOINT_CHAT ,
806+ metrics_labels:: ERROR_BACKEND ,
807+ ) ;
765808 }
809+
810+ response
766811 }
767812
768813 async fn route_responses (
0 commit comments