@@ -126,6 +126,7 @@ struct Inner {
126126 // need serialization so concurrent register/remove operations do not
127127 // overwrite each other's copy-on-write updates.
128128 sessions_write_lock : Mutex < ( ) > ,
129+ disconnected : std:: sync:: RwLock < Option < String > > ,
129130 session_id : std:: sync:: RwLock < Option < String > > ,
130131 reader_task : tokio:: task:: JoinHandle < ( ) > ,
131132}
@@ -157,6 +158,8 @@ pub enum ExecServerError {
157158 InitializeTimedOut { timeout : Duration } ,
158159 #[ error( "exec-server transport closed" ) ]
159160 Closed ,
161+ #[ error( "{0}" ) ]
162+ Disconnected ( String ) ,
160163 #[ error( "failed to serialize or deserialize exec-server JSON: {0}" ) ]
161164 Json ( #[ from] serde_json:: Error ) ,
162165 #[ error( "exec-server protocol error: {0}" ) ]
@@ -232,127 +235,85 @@ impl ExecServerClient {
232235 }
233236
234237 pub async fn exec ( & self , params : ExecParams ) -> Result < ExecResponse , ExecServerError > {
235- self . inner
236- . client
237- . call ( EXEC_METHOD , & params)
238- . await
239- . map_err ( Into :: into)
238+ self . call ( EXEC_METHOD , & params) . await
240239 }
241240
242241 pub async fn read ( & self , params : ReadParams ) -> Result < ReadResponse , ExecServerError > {
243- self . inner
244- . client
245- . call ( EXEC_READ_METHOD , & params)
246- . await
247- . map_err ( Into :: into)
242+ self . call ( EXEC_READ_METHOD , & params) . await
248243 }
249244
250245 pub async fn write (
251246 & self ,
252247 process_id : & ProcessId ,
253248 chunk : Vec < u8 > ,
254249 ) -> Result < WriteResponse , ExecServerError > {
255- self . inner
256- . client
257- . call (
258- EXEC_WRITE_METHOD ,
259- & WriteParams {
260- process_id : process_id. clone ( ) ,
261- chunk : chunk. into ( ) ,
262- } ,
263- )
264- . await
265- . map_err ( Into :: into)
250+ self . call (
251+ EXEC_WRITE_METHOD ,
252+ & WriteParams {
253+ process_id : process_id. clone ( ) ,
254+ chunk : chunk. into ( ) ,
255+ } ,
256+ )
257+ . await
266258 }
267259
268260 pub async fn terminate (
269261 & self ,
270262 process_id : & ProcessId ,
271263 ) -> Result < TerminateResponse , ExecServerError > {
272- self . inner
273- . client
274- . call (
275- EXEC_TERMINATE_METHOD ,
276- & TerminateParams {
277- process_id : process_id. clone ( ) ,
278- } ,
279- )
280- . await
281- . map_err ( Into :: into)
264+ self . call (
265+ EXEC_TERMINATE_METHOD ,
266+ & TerminateParams {
267+ process_id : process_id. clone ( ) ,
268+ } ,
269+ )
270+ . await
282271 }
283272
284273 pub async fn fs_read_file (
285274 & self ,
286275 params : FsReadFileParams ,
287276 ) -> Result < FsReadFileResponse , ExecServerError > {
288- self . inner
289- . client
290- . call ( FS_READ_FILE_METHOD , & params)
291- . await
292- . map_err ( Into :: into)
277+ self . call ( FS_READ_FILE_METHOD , & params) . await
293278 }
294279
295280 pub async fn fs_write_file (
296281 & self ,
297282 params : FsWriteFileParams ,
298283 ) -> Result < FsWriteFileResponse , ExecServerError > {
299- self . inner
300- . client
301- . call ( FS_WRITE_FILE_METHOD , & params)
302- . await
303- . map_err ( Into :: into)
284+ self . call ( FS_WRITE_FILE_METHOD , & params) . await
304285 }
305286
306287 pub async fn fs_create_directory (
307288 & self ,
308289 params : FsCreateDirectoryParams ,
309290 ) -> Result < FsCreateDirectoryResponse , ExecServerError > {
310- self . inner
311- . client
312- . call ( FS_CREATE_DIRECTORY_METHOD , & params)
313- . await
314- . map_err ( Into :: into)
291+ self . call ( FS_CREATE_DIRECTORY_METHOD , & params) . await
315292 }
316293
317294 pub async fn fs_get_metadata (
318295 & self ,
319296 params : FsGetMetadataParams ,
320297 ) -> Result < FsGetMetadataResponse , ExecServerError > {
321- self . inner
322- . client
323- . call ( FS_GET_METADATA_METHOD , & params)
324- . await
325- . map_err ( Into :: into)
298+ self . call ( FS_GET_METADATA_METHOD , & params) . await
326299 }
327300
328301 pub async fn fs_read_directory (
329302 & self ,
330303 params : FsReadDirectoryParams ,
331304 ) -> Result < FsReadDirectoryResponse , ExecServerError > {
332- self . inner
333- . client
334- . call ( FS_READ_DIRECTORY_METHOD , & params)
335- . await
336- . map_err ( Into :: into)
305+ self . call ( FS_READ_DIRECTORY_METHOD , & params) . await
337306 }
338307
339308 pub async fn fs_remove (
340309 & self ,
341310 params : FsRemoveParams ,
342311 ) -> Result < FsRemoveResponse , ExecServerError > {
343- self . inner
344- . client
345- . call ( FS_REMOVE_METHOD , & params)
346- . await
347- . map_err ( Into :: into)
312+ self . call ( FS_REMOVE_METHOD , & params) . await
348313 }
349314
350315 pub async fn fs_copy ( & self , params : FsCopyParams ) -> Result < FsCopyResponse , ExecServerError > {
351- self . inner
352- . client
353- . call ( FS_COPY_METHOD , & params)
354- . await
355- . map_err ( Into :: into)
316+ self . call ( FS_COPY_METHOD , & params) . await
356317 }
357318
358319 pub ( crate ) async fn register_session (
@@ -397,7 +358,7 @@ impl ExecServerClient {
397358 && let Err ( err) =
398359 handle_server_notification ( & inner, notification) . await
399360 {
400- fail_all_sessions (
361+ mark_disconnected (
401362 & inner,
402363 format ! ( "exec-server notification handling failed: {err}" ) ,
403364 )
@@ -407,7 +368,7 @@ impl ExecServerClient {
407368 }
408369 RpcClientEvent :: Disconnected { reason } => {
409370 if let Some ( inner) = weak. upgrade ( ) {
410- fail_all_sessions ( & inner, disconnected_message ( reason. as_deref ( ) ) )
371+ mark_disconnected ( & inner, disconnected_message ( reason. as_deref ( ) ) )
411372 . await ;
412373 }
413374 return ;
@@ -420,6 +381,7 @@ impl ExecServerClient {
420381 client : rpc_client,
421382 sessions : ArcSwap :: from_pointee ( HashMap :: new ( ) ) ,
422383 sessions_write_lock : Mutex :: new ( ( ) ) ,
384+ disconnected : std:: sync:: RwLock :: new ( None ) ,
423385 session_id : std:: sync:: RwLock :: new ( None ) ,
424386 reader_task,
425387 }
@@ -437,6 +399,30 @@ impl ExecServerClient {
437399 . await
438400 . map_err ( ExecServerError :: Json )
439401 }
402+
403+ async fn call < P , T > ( & self , method : & str , params : & P ) -> Result < T , ExecServerError >
404+ where
405+ P : serde:: Serialize ,
406+ T : serde:: de:: DeserializeOwned ,
407+ {
408+ if let Some ( error) = self . inner . disconnected_error ( ) {
409+ return Err ( error) ;
410+ }
411+
412+ match self . inner . client . call ( method, params) . await {
413+ Ok ( response) => Ok ( response) ,
414+ Err ( error) => {
415+ let error = ExecServerError :: from ( error) ;
416+ if is_transport_closed_error ( & error) {
417+ let message = disconnected_message ( /*reason*/ None ) ;
418+ let message = mark_disconnected ( & self . inner , message) . await ;
419+ Err ( ExecServerError :: Disconnected ( message) )
420+ } else {
421+ Err ( error)
422+ }
423+ }
424+ }
425+ }
440426}
441427
442428impl From < RpcCallError > for ExecServerError {
@@ -573,6 +559,26 @@ impl Session {
573559}
574560
575561impl Inner {
562+ fn disconnected_error ( & self ) -> Option < ExecServerError > {
563+ self . disconnected
564+ . read ( )
565+ . unwrap_or_else ( std:: sync:: PoisonError :: into_inner)
566+ . clone ( )
567+ . map ( ExecServerError :: Disconnected )
568+ }
569+
570+ fn set_disconnected ( & self , message : String ) -> Option < String > {
571+ let mut disconnected = self
572+ . disconnected
573+ . write ( )
574+ . unwrap_or_else ( std:: sync:: PoisonError :: into_inner) ;
575+ if disconnected. is_some ( ) {
576+ return None ;
577+ }
578+ * disconnected = Some ( message. clone ( ) ) ;
579+ Some ( message)
580+ }
581+
576582 fn get_session ( & self , process_id : & ProcessId ) -> Option < Arc < SessionState > > {
577583 self . sessions . load ( ) . get ( process_id) . cloned ( )
578584 }
@@ -583,6 +589,9 @@ impl Inner {
583589 session : Arc < SessionState > ,
584590 ) -> Result < ( ) , ExecServerError > {
585591 let _sessions_write_guard = self . sessions_write_lock . lock ( ) . await ;
592+ if let Some ( error) = self . disconnected_error ( ) {
593+ return Err ( error) ;
594+ }
586595 let sessions = self . sessions . load ( ) ;
587596 if sessions. contains_key ( process_id) {
588597 return Err ( ExecServerError :: Protocol ( format ! (
@@ -623,14 +632,30 @@ fn disconnected_message(reason: Option<&str>) -> String {
623632}
624633
625634fn is_transport_closed_error ( error : & ExecServerError ) -> bool {
626- matches ! ( error, ExecServerError :: Closed )
627- || matches ! (
628- error,
629- ExecServerError :: Server {
630- code: -32000 ,
631- message,
632- } if message == "JSON-RPC transport closed"
633- )
635+ matches ! (
636+ error,
637+ ExecServerError :: Closed | ExecServerError :: Disconnected ( _)
638+ ) || matches ! (
639+ error,
640+ ExecServerError :: Server {
641+ code: -32000 ,
642+ message,
643+ } if message == "JSON-RPC transport closed"
644+ )
645+ }
646+
647+ async fn mark_disconnected ( inner : & Arc < Inner > , message : String ) -> String {
648+ if let Some ( message) = inner. set_disconnected ( message. clone ( ) ) {
649+ fail_all_sessions ( inner, message. clone ( ) ) . await ;
650+ message
651+ } else {
652+ inner
653+ . disconnected
654+ . read ( )
655+ . unwrap_or_else ( std:: sync:: PoisonError :: into_inner)
656+ . clone ( )
657+ . unwrap_or ( message)
658+ }
634659}
635660
636661async fn fail_all_sessions ( inner : & Arc < Inner > , message : String ) {
0 commit comments