@@ -840,6 +840,10 @@ impl SpannIndexWriter {
840840 ) ;
841841 return Ok ( ( ) ) ;
842842 }
843+ // Don't reassign if concurrently deleted.
844+ if self . is_head_deleted ( prev_head_id as usize ) . await ? {
845+ return Ok ( ( ) ) ;
846+ }
843847 // RNG query to find the nearest heads.
844848 let ( nearest_head_ids, _, nearest_head_embeddings) = self . rng_query ( doc_embedding) . await ?;
845849 // Don't reassign if empty.
@@ -934,23 +938,20 @@ impl SpannIndexWriter {
934938 new_head_embeddings : & [ Option < & Vec < f32 > > ] ,
935939 old_head_embedding : & [ f32 ] ,
936940 ) -> Result < ( ) , SpannIndexWriterError > {
937- // Get posting list of each neighbour and check for reassignment criteria.
938- let doc_offset_ids;
939- let doc_versions;
940- let doc_embeddings;
941- {
942- // TODO(Sanket): Check if head is deleted, can happen if another concurrent thread
943- // deletes it.
944- ( doc_offset_ids, doc_versions, doc_embeddings) = self
945- . posting_list_writer
946- . get_owned :: < u32 , & SpannPostingList < ' _ > > ( "" , head_id as u32 )
947- . await
948- . map_err ( |e| {
949- tracing:: error!( "Error getting posting list for head {}: {}" , head_id, e) ;
950- SpannIndexWriterError :: PostingListGetError ( e)
951- } ) ?
952- . ok_or ( SpannIndexWriterError :: PostingListNotFound ) ?;
941+ // Head got concurrrently deleted so abort reassignment.
942+ if self . is_head_deleted ( head_id) . await ? {
943+ return Ok ( ( ) ) ;
953944 }
945+ // Get posting list of each neighbour and check for reassignment criteria.
946+ let ( doc_offset_ids, doc_versions, doc_embeddings) = self
947+ . posting_list_writer
948+ . get_owned :: < u32 , & SpannPostingList < ' _ > > ( "" , head_id as u32 )
949+ . await
950+ . map_err ( |e| {
951+ tracing:: error!( "Error getting posting list for head {}: {}" , head_id, e) ;
952+ SpannIndexWriterError :: PostingListGetError ( e)
953+ } ) ?
954+ . ok_or ( SpannIndexWriterError :: PostingListNotFound ) ?;
954955 for ( index, doc_offset_id) in doc_offset_ids. iter ( ) . enumerate ( ) {
955956 if assigned_ids. contains ( doc_offset_id)
956957 || self
@@ -959,6 +960,10 @@ impl SpannIndexWriter {
959960 {
960961 continue ;
961962 }
963+ if self . is_head_deleted ( head_id) . await ? {
964+ // Head got concurrently deleted so abort reassignment.
965+ return Ok ( ( ) ) ;
966+ }
962967 let distance_function: DistanceFunction = self . params . space . clone ( ) . into ( ) ;
963968 let distance_from_curr_center = distance_function. distance (
964969 & doc_embeddings[ index * self . dimensionality ..( index + 1 ) * self . dimensionality ] ,
@@ -1437,7 +1442,6 @@ impl SpannIndexWriter {
14371442 doc_versions : & [ version] ,
14381443 doc_embeddings : embeddings,
14391444 } ;
1440- let _write_guard = self . posting_list_partitioned_mutex . lock ( & id) . await ;
14411445 self . posting_list_writer
14421446 . set ( "" , next_id, & posting_list)
14431447 . await
0 commit comments