Skip to content

Commit 561c851

Browse files
committed
Finish impl
1 parent 3481a76 commit 561c851

File tree

2 files changed

+72
-54
lines changed

2 files changed

+72
-54
lines changed

rust/index/src/spann/types.rs

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,8 +1437,8 @@ impl SpannIndexWriter {
14371437
doc_versions: &[version],
14381438
doc_embeddings: embeddings,
14391439
};
1440-
let write_guard = self.posting_list_writer.lock().await;
1441-
write_guard
1440+
let _write_guard = self.posting_list_partitioned_mutex.lock(&id).await;
1441+
self.posting_list_writer
14421442
.set("", next_id, &posting_list)
14431443
.await
14441444
.map_err(|e| {
@@ -1663,12 +1663,13 @@ impl SpannIndexWriter {
16631663
let mut target_embedding = vec![];
16641664
let mut target_head = 0;
16651665
{
1666-
let pl_guard = self.posting_list_writer.lock().await;
1666+
// TODO(Sanket): Add a lock on the head here if this is called concurrently.
16671667
// If head is concurrently deleted then skip.
16681668
if self.is_head_deleted(head_id).await? {
16691669
return Ok(());
16701670
}
1671-
(doc_offset_ids, doc_versions, doc_embeddings) = pl_guard
1671+
(doc_offset_ids, doc_versions, doc_embeddings) = self
1672+
.posting_list_writer
16721673
.get_owned::<u32, &SpannPostingList<'_>>("", head_id as u32)
16731674
.await
16741675
.map_err(|e| {
@@ -1687,7 +1688,7 @@ impl SpannIndexWriter {
16871688
doc_versions: &doc_versions,
16881689
doc_embeddings: &doc_embeddings,
16891690
};
1690-
pl_guard
1691+
self.posting_list_writer
16911692
.set("", head_id as u32, &posting_list)
16921693
.await
16931694
.map_err(|e| {
@@ -1725,14 +1726,15 @@ impl SpannIndexWriter {
17251726
if nearest_head_id == head_id {
17261727
continue;
17271728
}
1728-
// TODO(Sanket): If and when the lock is more fine grained, then
1729+
// TODO(Sanket): If and when GC is concurrent, then
17291730
// need to acquire a lock on the nearest_head_id here.
17301731
// TODO(Sanket): Also need to check if the head is deleted concurrently then.
17311732
let (
17321733
nearest_head_doc_offset_ids,
17331734
nearest_head_doc_versions,
17341735
nearest_head_doc_embeddings,
1735-
) = pl_guard
1736+
) = self
1737+
.posting_list_writer
17361738
.get_owned::<u32, &SpannPostingList<'_>>("", nearest_head_id as u32)
17371739
.await
17381740
.map_err(|e| {
@@ -1771,7 +1773,7 @@ impl SpannIndexWriter {
17711773
doc_embeddings: &doc_embeddings,
17721774
};
17731775
if target_cluster_len > source_cluster_len {
1774-
pl_guard
1776+
self.posting_list_writer
17751777
.set("", nearest_head_id as u32, &merged_posting_list)
17761778
.await
17771779
.map_err(|e| {
@@ -1795,7 +1797,7 @@ impl SpannIndexWriter {
17951797
.num_heads_deleted
17961798
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
17971799
} else {
1798-
pl_guard
1800+
self.posting_list_writer
17991801
.set("", head_id as u32, &merged_posting_list)
18001802
.await
18011803
.map_err(|e| {
@@ -2214,7 +2216,7 @@ impl SpannIndexWriter {
22142216
)];
22152217
let pl_flusher = {
22162218
let stopwatch = Stopwatch::new(&self.metrics.pl_commit_latency, attribute);
2217-
let pl_writer_clone = self.posting_list_writer.lock().await.clone();
2219+
let pl_writer_clone = self.posting_list_writer.clone();
22182220
let pl_flusher = pl_writer_clone
22192221
.commit::<u32, &SpannPostingList<'_>>()
22202222
.await
@@ -2889,8 +2891,8 @@ mod tests {
28892891
}
28902892
{
28912893
// Posting list should have 100 points.
2892-
let pl_read_guard = writer.posting_list_writer.lock().await;
2893-
let pl = pl_read_guard
2894+
let pl = writer
2895+
.posting_list_writer
28942896
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
28952897
.await
28962898
.expect("Error getting posting list")
@@ -2920,13 +2922,14 @@ mod tests {
29202922
}
29212923
{
29222924
// Posting list should have 100 points.
2923-
let pl_read_guard = writer.posting_list_writer.lock().await;
2924-
let pl1 = pl_read_guard
2925+
let pl1 = writer
2926+
.posting_list_writer
29252927
.get_owned::<u32, &SpannPostingList<'_>>("", emb_1_id)
29262928
.await
29272929
.expect("Error getting posting list")
29282930
.unwrap();
2929-
let pl2 = pl_read_guard
2931+
let pl2 = writer
2932+
.posting_list_writer
29302933
.get_owned::<u32, &SpannPostingList<'_>>("", emb_2_id)
29312934
.await
29322935
.expect("Error getting posting list")
@@ -2980,16 +2983,17 @@ mod tests {
29802983
}
29812984
{
29822985
// Posting list should have 100 points.
2983-
let pl_read_guard = writer.posting_list_writer.lock().await;
2984-
let pl = pl_read_guard
2986+
let pl = writer
2987+
.posting_list_writer
29852988
.get_owned::<u32, &SpannPostingList<'_>>("", emb_1_id)
29862989
.await
29872990
.expect("Error getting posting list")
29882991
.unwrap();
29892992
assert_eq!(pl.0.len(), 100);
29902993
assert_eq!(pl.1.len(), 100);
29912994
assert_eq!(pl.2.len(), 200);
2992-
let pl = pl_read_guard
2995+
let pl = writer
2996+
.posting_list_writer
29932997
.get_owned::<u32, &SpannPostingList<'_>>("", emb_2_id)
29942998
.await
29952999
.expect("Error getting posting list")
@@ -3078,7 +3082,6 @@ mod tests {
30783082
.expect("Error adding to hnsw index");
30793083
}
30803084
{
3081-
let pl_guard = writer.posting_list_writer.lock().await;
30823085
let mut doc_offset_ids = vec![0u32; 100];
30833086
let mut doc_versions = vec![0; 100];
30843087
let mut doc_embeddings = vec![0.0; 200];
@@ -3094,7 +3097,8 @@ mod tests {
30943097
doc_versions: &doc_versions,
30953098
doc_embeddings: &doc_embeddings,
30963099
};
3097-
pl_guard
3100+
writer
3101+
.posting_list_writer
30983102
.set("", 1, &pl)
30993103
.await
31003104
.expect("Error writing to posting list");
@@ -3109,7 +3113,8 @@ mod tests {
31093113
doc_versions: &doc_versions,
31103114
doc_embeddings: &doc_embeddings,
31113115
};
3112-
pl_guard
3116+
writer
3117+
.posting_list_writer
31133118
.set("", 2, &pl)
31143119
.await
31153120
.expect("Error writing to posting list");
@@ -3148,16 +3153,17 @@ mod tests {
31483153
}
31493154
{
31503155
// The posting lists should not be changed at all.
3151-
let pl_guard = writer.posting_list_writer.lock().await;
3152-
let pl = pl_guard
3156+
let pl = writer
3157+
.posting_list_writer
31533158
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
31543159
.await
31553160
.expect("Error getting posting list")
31563161
.unwrap();
31573162
assert_eq!(pl.0.len(), 100);
31583163
assert_eq!(pl.1.len(), 100);
31593164
assert_eq!(pl.2.len(), 200);
3160-
let pl = pl_guard
3165+
let pl = writer
3166+
.posting_list_writer
31613167
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
31623168
.await
31633169
.expect("Error getting posting list")
@@ -3174,8 +3180,8 @@ mod tests {
31743180
// Expect the posting lists to be 60. Also validate the ids, versions and embeddings
31753181
// individually.
31763182
{
3177-
let pl_guard = writer.posting_list_writer.lock().await;
3178-
let pl = pl_guard
3183+
let pl = writer
3184+
.posting_list_writer
31793185
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
31803186
.await
31813187
.expect("Error getting posting list")
@@ -3189,7 +3195,8 @@ mod tests {
31893195
assert_eq!(pl.2[(point - 41) * 2], point as f32);
31903196
assert_eq!(pl.2[(point - 41) * 2 + 1], point as f32);
31913197
}
3192-
let pl = pl_guard
3198+
let pl = writer
3199+
.posting_list_writer
31933200
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
31943201
.await
31953202
.expect("Error getting posting list")
@@ -3336,7 +3343,6 @@ mod tests {
33363343
.expect("Error adding to hnsw index");
33373344
}
33383345
{
3339-
let pl_guard = writer.posting_list_writer.lock().await;
33403346
let mut doc_offset_ids = vec![0u32; 100];
33413347
let mut doc_versions = vec![0; 100];
33423348
let mut doc_embeddings = vec![0.0; 200];
@@ -3352,7 +3358,8 @@ mod tests {
33523358
doc_versions: &doc_versions,
33533359
doc_embeddings: &doc_embeddings,
33543360
};
3355-
pl_guard
3361+
writer
3362+
.posting_list_writer
33563363
.set("", 1, &pl)
33573364
.await
33583365
.expect("Error writing to posting list");
@@ -3367,7 +3374,8 @@ mod tests {
33673374
doc_versions: &doc_versions,
33683375
doc_embeddings: &doc_embeddings,
33693376
};
3370-
pl_guard
3377+
writer
3378+
.posting_list_writer
33713379
.set("", 2, &pl)
33723380
.await
33733381
.expect("Error writing to posting list");
@@ -3416,16 +3424,17 @@ mod tests {
34163424
}
34173425
{
34183426
// The posting lists should not be changed at all.
3419-
let pl_guard = writer.posting_list_writer.lock().await;
3420-
let pl = pl_guard
3427+
let pl = writer
3428+
.posting_list_writer
34213429
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
34223430
.await
34233431
.expect("Error getting posting list")
34243432
.unwrap();
34253433
assert_eq!(pl.0.len(), 100);
34263434
assert_eq!(pl.1.len(), 100);
34273435
assert_eq!(pl.2.len(), 200);
3428-
let pl = pl_guard
3436+
let pl = writer
3437+
.posting_list_writer
34293438
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
34303439
.await
34313440
.expect("Error getting posting list")
@@ -3479,8 +3488,8 @@ mod tests {
34793488
}
34803489
// Expect the posting lists with id 1 to be 79.
34813490
{
3482-
let pl_guard = writer.posting_list_writer.lock().await;
3483-
let pl = pl_guard
3491+
let pl = writer
3492+
.posting_list_writer
34843493
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
34853494
.await
34863495
.expect("Error getting posting list")
@@ -3575,7 +3584,6 @@ mod tests {
35753584
let mut split_doc_embeddings3 = vec![0.0; 100];
35763585
{
35773586
let mut rng = rand::thread_rng();
3578-
let pl_guard = writer.posting_list_writer.lock().await;
35793587
for i in 1..=50 {
35803588
// Generate random radius between 0 and 1
35813589
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
@@ -3597,7 +3605,8 @@ mod tests {
35973605
doc_versions: &split_doc_versions1,
35983606
doc_embeddings: &split_doc_embeddings1,
35993607
};
3600-
pl_guard
3608+
writer
3609+
.posting_list_writer
36013610
.set("", 1, &posting_list)
36023611
.await
36033612
.expect("Error writing to posting list");
@@ -3624,7 +3633,8 @@ mod tests {
36243633
doc_versions: &split_doc_versions3,
36253634
doc_embeddings: &split_doc_embeddings3,
36263635
};
3627-
pl_guard
3636+
writer
3637+
.posting_list_writer
36283638
.set("", 3, &posting_list)
36293639
.await
36303640
.expect("Error writing to posting list");
@@ -3650,7 +3660,8 @@ mod tests {
36503660
doc_versions: &split_doc_versions2,
36513661
doc_embeddings: &split_doc_embeddings2,
36523662
};
3653-
pl_guard
3663+
writer
3664+
.posting_list_writer
36543665
.set("", 2, &posting_list)
36553666
.await
36563667
.expect("Error writing to posting list");
@@ -3678,9 +3689,9 @@ mod tests {
36783689
.expect("Expected reassign to succeed");
36793690
// See the reassigned points.
36803691
{
3681-
let pl_guard = writer.posting_list_writer.lock().await;
36823692
// Center 1 should remain unchanged.
3683-
let pl = pl_guard
3693+
let pl = writer
3694+
.posting_list_writer
36843695
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
36853696
.await
36863697
.expect("Error getting posting list")
@@ -3698,7 +3709,8 @@ mod tests {
36983709
);
36993710
}
37003711
// Center 2 should get 50 points, all with version 2 migrating from center 3.
3701-
let pl = pl_guard
3712+
let pl = writer
3713+
.posting_list_writer
37023714
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
37033715
.await
37043716
.expect("Error getting posting list")
@@ -3718,7 +3730,8 @@ mod tests {
37183730
// Center 3 should get 100 points. 50 points with version 1 which weere
37193731
// originally in center 3 and 50 points with version 2 which were originally
37203732
// in center 2.
3721-
let pl = pl_guard
3733+
let pl = writer
3734+
.posting_list_writer
37223735
.get_owned::<u32, &SpannPostingList<'_>>("", 3)
37233736
.await
37243737
.expect("Error getting posting list")
@@ -3838,7 +3851,6 @@ mod tests {
38383851
let mut doc_embeddings3 = vec![0.0; 140];
38393852
{
38403853
let mut rng = rand::thread_rng();
3841-
let pl_guard = writer.posting_list_writer.lock().await;
38423854
// Insert 70 points within a radius of 1 to center 1.
38433855
for i in 1..=70 {
38443856
// Generate random radius between 0 and 1
@@ -3895,7 +3907,8 @@ mod tests {
38953907
doc_versions: &doc_versions1,
38963908
doc_embeddings: &doc_embeddings1,
38973909
};
3898-
pl_guard
3910+
writer
3911+
.posting_list_writer
38993912
.set("", 1, &spann_posting_list)
39003913
.await
39013914
.expect("Error writing to posting list");
@@ -3904,7 +3917,8 @@ mod tests {
39043917
doc_versions: &doc_versions2,
39053918
doc_embeddings: &doc_embeddings2,
39063919
};
3907-
pl_guard
3920+
writer
3921+
.posting_list_writer
39083922
.set("", 2, &spann_posting_list)
39093923
.await
39103924
.expect("Error writing to posting list");
@@ -3913,7 +3927,8 @@ mod tests {
39133927
doc_versions: &doc_versions3,
39143928
doc_embeddings: &doc_embeddings3,
39153929
};
3916-
pl_guard
3930+
writer
3931+
.posting_list_writer
39173932
.set("", 3, &spann_posting_list)
39183933
.await
39193934
.expect("Error writing to posting list");
@@ -3937,8 +3952,8 @@ mod tests {
39373952
.expect("Error garbage collecting");
39383953
// check the posting lists.
39393954
{
3940-
let pl_guard = writer.posting_list_writer.lock().await;
3941-
let pl = pl_guard
3955+
let pl = writer
3956+
.posting_list_writer
39423957
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
39433958
.await
39443959
.expect("Error getting posting list")
@@ -3955,7 +3970,8 @@ mod tests {
39553970
doc_embeddings1[(point - 1) * 2 + 1]
39563971
);
39573972
}
3958-
let pl = pl_guard
3973+
let pl = writer
3974+
.posting_list_writer
39593975
.get_owned::<u32, &SpannPostingList<'_>>("", 3)
39603976
.await
39613977
.expect("Error getting posting list")

0 commit comments

Comments
 (0)