@@ -127,6 +127,7 @@ struct Inner {
127127 // need serialization so concurrent register/remove operations do not
128128 // overwrite each other's copy-on-write updates.
129129 sessions_write_lock : Mutex < ( ) > ,
130+ disconnected : std:: sync:: RwLock < Option < String > > ,
130131 session_id : std:: sync:: RwLock < Option < String > > ,
131132 reader_task : tokio:: task:: JoinHandle < ( ) > ,
132133}
@@ -158,6 +159,8 @@ pub enum ExecServerError {
158159 InitializeTimedOut { timeout : Duration } ,
159160 #[ error( "exec-server transport closed" ) ]
160161 Closed ,
162+ #[ error( "{0}" ) ]
163+ Disconnected ( String ) ,
161164 #[ error( "failed to serialize or deserialize exec-server JSON: {0}" ) ]
162165 Json ( #[ from] serde_json:: Error ) ,
163166 #[ error( "exec-server protocol error: {0}" ) ]
@@ -233,127 +236,85 @@ impl ExecServerClient {
233236 }
234237
235238 pub async fn exec ( & self , params : ExecParams ) -> Result < ExecResponse , ExecServerError > {
236- self . inner
237- . client
238- . call ( EXEC_METHOD , & params)
239- . await
240- . map_err ( Into :: into)
239+ self . call ( EXEC_METHOD , & params) . await
241240 }
242241
243242 pub async fn read ( & self , params : ReadParams ) -> Result < ReadResponse , ExecServerError > {
244- self . inner
245- . client
246- . call ( EXEC_READ_METHOD , & params)
247- . await
248- . map_err ( Into :: into)
243+ self . call ( EXEC_READ_METHOD , & params) . await
249244 }
250245
251246 pub async fn write (
252247 & self ,
253248 process_id : & ProcessId ,
254249 chunk : Vec < u8 > ,
255250 ) -> Result < WriteResponse , ExecServerError > {
256- self . inner
257- . client
258- . call (
259- EXEC_WRITE_METHOD ,
260- & WriteParams {
261- process_id : process_id. clone ( ) ,
262- chunk : chunk. into ( ) ,
263- } ,
264- )
265- . await
266- . map_err ( Into :: into)
251+ self . call (
252+ EXEC_WRITE_METHOD ,
253+ & WriteParams {
254+ process_id : process_id. clone ( ) ,
255+ chunk : chunk. into ( ) ,
256+ } ,
257+ )
258+ . await
267259 }
268260
269261 pub async fn terminate (
270262 & self ,
271263 process_id : & ProcessId ,
272264 ) -> Result < TerminateResponse , ExecServerError > {
273- self . inner
274- . client
275- . call (
276- EXEC_TERMINATE_METHOD ,
277- & TerminateParams {
278- process_id : process_id. clone ( ) ,
279- } ,
280- )
281- . await
282- . map_err ( Into :: into)
265+ self . call (
266+ EXEC_TERMINATE_METHOD ,
267+ & TerminateParams {
268+ process_id : process_id. clone ( ) ,
269+ } ,
270+ )
271+ . await
283272 }
284273
285274 pub async fn fs_read_file (
286275 & self ,
287276 params : FsReadFileParams ,
288277 ) -> Result < FsReadFileResponse , ExecServerError > {
289- self . inner
290- . client
291- . call ( FS_READ_FILE_METHOD , & params)
292- . await
293- . map_err ( Into :: into)
278+ self . call ( FS_READ_FILE_METHOD , & params) . await
294279 }
295280
296281 pub async fn fs_write_file (
297282 & self ,
298283 params : FsWriteFileParams ,
299284 ) -> Result < FsWriteFileResponse , ExecServerError > {
300- self . inner
301- . client
302- . call ( FS_WRITE_FILE_METHOD , & params)
303- . await
304- . map_err ( Into :: into)
285+ self . call ( FS_WRITE_FILE_METHOD , & params) . await
305286 }
306287
307288 pub async fn fs_create_directory (
308289 & self ,
309290 params : FsCreateDirectoryParams ,
310291 ) -> Result < FsCreateDirectoryResponse , ExecServerError > {
311- self . inner
312- . client
313- . call ( FS_CREATE_DIRECTORY_METHOD , & params)
314- . await
315- . map_err ( Into :: into)
292+ self . call ( FS_CREATE_DIRECTORY_METHOD , & params) . await
316293 }
317294
318295 pub async fn fs_get_metadata (
319296 & self ,
320297 params : FsGetMetadataParams ,
321298 ) -> Result < FsGetMetadataResponse , ExecServerError > {
322- self . inner
323- . client
324- . call ( FS_GET_METADATA_METHOD , & params)
325- . await
326- . map_err ( Into :: into)
299+ self . call ( FS_GET_METADATA_METHOD , & params) . await
327300 }
328301
329302 pub async fn fs_read_directory (
330303 & self ,
331304 params : FsReadDirectoryParams ,
332305 ) -> Result < FsReadDirectoryResponse , ExecServerError > {
333- self . inner
334- . client
335- . call ( FS_READ_DIRECTORY_METHOD , & params)
336- . await
337- . map_err ( Into :: into)
306+ self . call ( FS_READ_DIRECTORY_METHOD , & params) . await
338307 }
339308
340309 pub async fn fs_remove (
341310 & self ,
342311 params : FsRemoveParams ,
343312 ) -> Result < FsRemoveResponse , ExecServerError > {
344- self . inner
345- . client
346- . call ( FS_REMOVE_METHOD , & params)
347- . await
348- . map_err ( Into :: into)
313+ self . call ( FS_REMOVE_METHOD , & params) . await
349314 }
350315
351316 pub async fn fs_copy ( & self , params : FsCopyParams ) -> Result < FsCopyResponse , ExecServerError > {
352- self . inner
353- . client
354- . call ( FS_COPY_METHOD , & params)
355- . await
356- . map_err ( Into :: into)
317+ self . call ( FS_COPY_METHOD , & params) . await
357318 }
358319
359320 pub ( crate ) async fn register_session (
@@ -398,7 +359,7 @@ impl ExecServerClient {
398359 && let Err ( err) =
399360 handle_server_notification ( & inner, notification) . await
400361 {
401- fail_all_sessions (
362+ mark_disconnected (
402363 & inner,
403364 format ! ( "exec-server notification handling failed: {err}" ) ,
404365 )
@@ -408,7 +369,7 @@ impl ExecServerClient {
408369 }
409370 RpcClientEvent :: Disconnected { reason } => {
410371 if let Some ( inner) = weak. upgrade ( ) {
411- fail_all_sessions ( & inner, disconnected_message ( reason. as_deref ( ) ) )
372+ mark_disconnected ( & inner, disconnected_message ( reason. as_deref ( ) ) )
412373 . await ;
413374 }
414375 return ;
@@ -421,6 +382,7 @@ impl ExecServerClient {
421382 client : rpc_client,
422383 sessions : ArcSwap :: from_pointee ( HashMap :: new ( ) ) ,
423384 sessions_write_lock : Mutex :: new ( ( ) ) ,
385+ disconnected : std:: sync:: RwLock :: new ( None ) ,
424386 session_id : std:: sync:: RwLock :: new ( None ) ,
425387 reader_task,
426388 }
@@ -438,6 +400,30 @@ impl ExecServerClient {
438400 . await
439401 . map_err ( ExecServerError :: Json )
440402 }
403+
404+ async fn call < P , T > ( & self , method : & str , params : & P ) -> Result < T , ExecServerError >
405+ where
406+ P : serde:: Serialize ,
407+ T : serde:: de:: DeserializeOwned ,
408+ {
409+ if let Some ( error) = self . inner . disconnected_error ( ) {
410+ return Err ( error) ;
411+ }
412+
413+ match self . inner . client . call ( method, params) . await {
414+ Ok ( response) => Ok ( response) ,
415+ Err ( error) => {
416+ let error = ExecServerError :: from ( error) ;
417+ if is_transport_closed_error ( & error) {
418+ let message = disconnected_message ( /*reason*/ None ) ;
419+ let message = mark_disconnected ( & self . inner , message) . await ;
420+ Err ( ExecServerError :: Disconnected ( message) )
421+ } else {
422+ Err ( error)
423+ }
424+ }
425+ }
426+ }
441427}
442428
443429impl From < RpcCallError > for ExecServerError {
@@ -573,6 +559,26 @@ impl Session {
573559}
574560
575561impl Inner {
562+ fn disconnected_error ( & self ) -> Option < ExecServerError > {
563+ self . disconnected
564+ . read ( )
565+ . unwrap_or_else ( std:: sync:: PoisonError :: into_inner)
566+ . clone ( )
567+ . map ( ExecServerError :: Disconnected )
568+ }
569+
570+ fn set_disconnected ( & self , message : String ) -> Option < String > {
571+ let mut disconnected = self
572+ . disconnected
573+ . write ( )
574+ . unwrap_or_else ( std:: sync:: PoisonError :: into_inner) ;
575+ if disconnected. is_some ( ) {
576+ return None ;
577+ }
578+ * disconnected = Some ( message. clone ( ) ) ;
579+ Some ( message)
580+ }
581+
576582 fn get_session ( & self , process_id : & ProcessId ) -> Option < Arc < SessionState > > {
577583 self . sessions . load ( ) . get ( process_id) . cloned ( )
578584 }
@@ -583,6 +589,9 @@ impl Inner {
583589 session : Arc < SessionState > ,
584590 ) -> Result < ( ) , ExecServerError > {
585591 let _sessions_write_guard = self . sessions_write_lock . lock ( ) . await ;
592+ if let Some ( error) = self . disconnected_error ( ) {
593+ return Err ( error) ;
594+ }
586595 let sessions = self . sessions . load ( ) ;
587596 if sessions. contains_key ( process_id) {
588597 return Err ( ExecServerError :: Protocol ( format ! (
@@ -623,14 +632,30 @@ fn disconnected_message(reason: Option<&str>) -> String {
623632}
624633
625634fn is_transport_closed_error ( error : & ExecServerError ) -> bool {
626- matches ! ( error, ExecServerError :: Closed )
627- || matches ! (
628- error,
629- ExecServerError :: Server {
630- code: -32000 ,
631- message,
632- } if message == "JSON-RPC transport closed"
633- )
635+ matches ! (
636+ error,
637+ ExecServerError :: Closed | ExecServerError :: Disconnected ( _)
638+ ) || matches ! (
639+ error,
640+ ExecServerError :: Server {
641+ code: -32000 ,
642+ message,
643+ } if message == "JSON-RPC transport closed"
644+ )
645+ }
646+
647+ async fn mark_disconnected ( inner : & Arc < Inner > , message : String ) -> String {
648+ if let Some ( message) = inner. set_disconnected ( message. clone ( ) ) {
649+ fail_all_sessions ( inner, message. clone ( ) ) . await ;
650+ message
651+ } else {
652+ inner
653+ . disconnected
654+ . read ( )
655+ . unwrap_or_else ( std:: sync:: PoisonError :: into_inner)
656+ . clone ( )
657+ . unwrap_or ( message)
658+ }
634659}
635660
636661async fn fail_all_sessions ( inner : & Arc < Inner > , message : String ) {
0 commit comments