Add ScaNN index backend with dtype-aware storage and tests #195
NohTow merged 8 commits into lightonai:main
Conversation
This adds ScaNN index support with fp16/fp32-preserving flattening, shared index utilities, scann extras, and focused tests for verbosity, dtype handling, and embedding retrieval behavior. Co-authored-by: Cursor <cursoragent@cursor.com>
…s test
- Run ruff format/lint on scann.py and test_scann.py
- Replace per-token Python loop in __call__() result construction with np.split for vectorized reshaping
- Add test for numpy document embedding input (float16/float32)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
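The np.split change mentioned in the commit message can be sketched roughly as follows (names and shapes here are illustrative, not the PR's actual code):

```python
import numpy as np

# Hypothetical flat store: all token embeddings concatenated row-wise,
# plus per-document token counts (illustrative, not the PR's layout).
flat = np.arange(18, dtype=np.float32).reshape(9, 2)  # 9 tokens, dim 2
doc_lengths = [4, 3, 2]                               # tokens per document

# Vectorized replacement for a per-token Python loop:
# split the flat matrix at cumulative document boundaries.
offsets = np.cumsum(doc_lengths)[:-1]                 # [4, 7]
docs = np.split(flat, offsets, axis=0)                # one array per doc

assert [d.shape for d in docs] == [(4, 2), (3, 2), (2, 2)]
```

Each returned array is a view into the flat matrix, so no per-token copying happens during result construction.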
Hi @robro612, this is a cool PR, LGTM. Would it be possible to have a small benchmark on the smallest datasets of the BEIR benchmark, like SciFact and a few other small ones, to assert that the nDCG@10 is relevant? As well as the QPS?
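For reference, the requested nDCG@10 metric can be computed for a single ranked list in a few lines (a generic sketch with binary relevance grades, not pylate's evaluation code):

```python
import math

def dcg_at_k(rels, k=10):
    """Discounted cumulative gain over the top-k relevance grades."""
    return sum(r / math.log2(i + 2) for i, r in enumerate(rels[:k]))

def ndcg_at_k(ranked_rels, k=10):
    """nDCG@k: DCG of the ranking divided by DCG of the ideal ranking."""
    ideal_dcg = dcg_at_k(sorted(ranked_rels, reverse=True), k)
    return dcg_at_k(ranked_rels, k) / ideal_dcg if ideal_dcg > 0 else 0.0

# A perfect ranking scores 1.0; pushing relevant docs down hurts the score.
assert ndcg_at_k([1, 1, 0, 0]) == 1.0
assert ndcg_at_k([0, 1, 1, 0]) < 1.0
```

QPS, by contrast, is just total queries divided by wall-clock retrieval time.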
Pull request overview
This pull request adds a new ScaNN (Scalable Nearest Neighbors) index backend to pylate, providing an alternative approximate nearest neighbor search option for ColBERT retrieval. The implementation includes dtype-aware storage supporting both fp16 and fp32 embeddings, optional memory logging utilities, and comprehensive test coverage.
Changes:
- Adds ScaNN index implementation with auto-tuning parameters and optional autopilot mode
- Introduces shared utility functions (reshape_embeddings, log_memory) in pylate/indexes/utils.py
- Adds ScaNN optional dependencies with psutil for memory tracking
- Includes comprehensive unit tests covering dtype handling, verbosity configuration, and error paths
- Updates evaluation example script to support ScaNN index type
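The dtype-aware storage described above can be pictured with a minimal sketch (hypothetical helper name; the PR's actual utility is reshape_embeddings in pylate/indexes/utils.py):

```python
import numpy as np

def flatten_preserving_dtype(doc_embeddings):
    """Concatenate per-document token embeddings into one flat matrix,
    keeping fp16 input as fp16 and fp32 as fp32 (illustrative helper)."""
    np_dtype = doc_embeddings[0].dtype
    if any(e.dtype != np_dtype for e in doc_embeddings):
        raise ValueError("All document embeddings must share one dtype")
    lengths = [e.shape[0] for e in doc_embeddings]
    flat = np.concatenate(doc_embeddings, axis=0)  # dtype is preserved
    return flat, lengths

docs = [np.zeros((4, 8), dtype=np.float16), np.ones((2, 8), dtype=np.float16)]
flat, lengths = flatten_preserving_dtype(docs)
assert flat.dtype == np.float16 and flat.shape == (6, 8) and lengths == [4, 2]
```

The recorded lengths are what allow per-document embeddings to be recovered from the flat matrix at query time.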
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| pylate/indexes/scann.py | Complete ScaNN index implementation with build, save, load, query, and embedding retrieval functionality |
| pylate/indexes/utils.py | Shared utility functions for embedding reshaping and memory logging |
| pylate/indexes/__init__.py | Exports ScaNN class from indexes module |
| pyproject.toml | Adds scann and psutil dependencies to optional dependencies |
| tests/test_scann.py | Comprehensive test suite for ScaNN functionality including dtype handling and error cases |
| examples/evaluation/beir_dataset.py | Adds ScaNN as an index type option with fp16 model support |
pylate/indexes/scann.py (outdated)

    def __init__(
        self,
        name: str | None = "ScaNN_index",
Parameter naming inconsistency: The ScaNN class uses name as a parameter, but all other index classes in this codebase (PLAID, Voyager, FastPlaid, StanfordPlaid) use index_name. This breaks API consistency. The parameter should be renamed to index_name to match the established pattern.
pylate/indexes/scann.py (outdated)

        override: bool = False,
        verbose_level: str | None = None,
    ) -> None:
        self.name = name
The name attribute is stored but barely used. Looking at line 104, self.name = name is set, but it is only used in _get_index_path() (line 260). However, the docstring on lines 43-44 states it is "The name of the index collection", suggesting it should be consistently used and documented.
pylate/indexes/scann.py (outdated)

            emb_np = emb.to(
                "cpu",
                dtype=torch.float16 if np_dtype == np.float16 else torch.float32,
            ).numpy()
The to() conversion specifies a dtype that will always match the input tensor's dtype due to the check on lines 562-566. The dtype parameter in the to() call is redundant since lines 562-566 ensure all embeddings already have the same dtype as np_dtype. Consider simplifying to emb.to("cpu").numpy() or document why explicit dtype conversion is needed here.
Suggested change:

    emb_np = emb.to("cpu").numpy()
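The suggestion relies on the fact that a conversion without an explicit dtype keeps the source dtype; torch's Tensor.numpy() behaves this way, and the same property is easy to check on the numpy side:

```python
import numpy as np

emb_fp16 = np.random.rand(5, 8).astype(np.float16)

# Casting to the dtype the array already has is a no-op copy...
explicit = emb_fp16.astype(np.float16)
# ...so dropping the redundant dtype argument changes nothing:
implicit = np.array(emb_fp16)

assert explicit.dtype == implicit.dtype == np.float16
assert np.array_equal(explicit, implicit)
```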
@raphaelsty certainly, I'll have those numbers as a consequence of testing the training PRs soon. Keep in mind ScaNN is CPU-only, so it's rather slow compared to FastPlaid, but it was easier for me to use than Voyager, even with the huge memory overhead of storing the flattened embeddings. I'll also take a look at the fixes the bot suggested above.
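The memory overhead mentioned here is easy to estimate: a flat token-embedding store needs roughly n_docs × tokens_per_doc × dim × bytes_per_value. A back-of-the-envelope sketch (the corpus sizes are illustrative, not from the PR):

```python
def flat_index_bytes(n_docs, tokens_per_doc, dim, bytes_per_value):
    """Rough memory footprint of a flattened token-embedding store."""
    return n_docs * tokens_per_doc * dim * bytes_per_value

# E.g. 1M documents, ~100 tokens each, dim 128:
fp16 = flat_index_bytes(1_000_000, 100, 128, 2)
fp32 = flat_index_bytes(1_000_000, 100, 128, 4)

assert fp16 == 25_600_000_000  # ~25.6 GB at fp16
assert fp32 == 2 * fp16        # fp32 doubles the footprint
```

This is why preserving fp16 inputs as fp16, rather than silently upcasting, matters at index-build time.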
Make batch_size a no-op default arg in ScaNN (it requires all docs at once; keep the argument to maintain parity with other index classes)
Type/docstring annotation fixes in retriever class
Update: Copilot fixes + BEIR benchmarks

Changes in this push

Benchmark results

Model: nDCG@10
QPS (queries per second)
Notes
LGTM, the PR is very clean @robro612, thank you for the evaluation results. I'll run the CI and then merge :)
NohTow left a comment
Hey!
Thanks for the amazing work and sorry for the delay in the review, busy days!
I've added a few comments; most of them are nits, but I figured it could help make things a bit cleaner (also some probably stupid questions, but I'd rather ask and find out I'm dumb than merge errors because I did not ask!)
Besides those, I think my main "comment" is that I wonder whether we should merge the part about time/memory profiling. It's very nice of you to have added all of that and gone the extra mile to benchmark things, but I wonder if it's something we expect in the merged indexes.
On a related note, I wonder if we should merge examples/evaluation/benchmark_index_beir.py and if so, I do not think it should be in this folder imho
examples/evaluation/beir_dataset.py (outdated)

        document_length=300,
        query_length=query_len.get(dataset_name),
    )
    ).to(torch.float16)
Probably a nit because fp16 is almost == to fp32, but I wonder if this should be an option.
It should be noted that I am using fp16 for most of my benches these days to save some memory for large datasets (I need to fix bf16 models that output fp32 because of numpy and thus have to be recast).
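The bf16 issue mentioned here stems from numpy having no native bfloat16 dtype, so bf16 tensors get upcast (typically to fp32) on the way to numpy and must be recast afterwards. A quick check against stock numpy:

```python
import numpy as np

# Stock numpy does not recognize bfloat16 as a dtype name, which is why
# bf16 model outputs have to round-trip through another dtype.
try:
    np.dtype("bfloat16")
    has_bf16 = True
except TypeError:
    has_bf16 = False

assert not has_bf16
```

(Extension packages can add a bfloat16 scalar type, but plain numpy arrays, and hence this index's storage, are limited to float16/float32.)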
pylate/indexes/scann.py (outdated)

    def _build_searcher(self, embeddings: np.ndarray) -> None:
        """Build the ScaNN searcher from embeddings (in-memory only)."""
        build_start = time.time()
Do we still need those, though?
They're there to run the bench, right? I wonder whether we should leave bench params in the final PR.
pylate/indexes/scann.py (outdated)

        )

        # Build ScaNN searcher
        log_memory("Before scann.build()", self.verbose)
A bit of a broad comment around the whole PR, but I wonder if we should leave time/memory profiling in the final merged thing
        logger.warning(
            f"[ScaNN] WARNING: Manual parameters provided but will be ignored: num_leaves={self.num_leaves}, num_leaves_to_search={self.num_leaves_to_search}, training_sample_size={self.training_sample_size}"
        )
    else:
Sorry if this is a dumb question, but from my understanding we were supposed to use autopilot when params are not set.
It seems like you are defining some defaults but not setting self.autopilot to True.
Were you referring to "auto tune" as in using defaults, or am I missing something?
autopilot is a config setting directly in the ScaNN library. AFAICT it sets reasonable settings, but I don't know exactly how. Recalling my experiments, it's slower/more accurate than the defaults I set, which come directly from the XTR inference notebook.
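The distinction being discussed can be sketched as a plain branch (attribute names follow the diff above; the default values shown are placeholders, not the PR's actual numbers):

```python
class ScannConfigSketch:
    """Illustrative only: autopilot is ScaNN's own auto-configuration
    mode; separately, unset manual params fall back to hand-picked
    defaults without switching autopilot on."""

    def __init__(self, autopilot=False, num_leaves=None,
                 num_leaves_to_search=None, training_sample_size=None):
        self.autopilot = autopilot
        manual = (num_leaves, num_leaves_to_search, training_sample_size)
        if autopilot and any(p is not None for p in manual):
            # Manual params are ignored when autopilot drives the config.
            print("[ScaNN] WARNING: manual parameters will be ignored")
        # Unset params get defaults; this does NOT imply autopilot=True.
        self.num_leaves = num_leaves if num_leaves is not None else 2000
        self.num_leaves_to_search = num_leaves_to_search or 100
        self.training_sample_size = training_sample_size or 250_000

cfg = ScannConfigSketch()
assert cfg.autopilot is False and cfg.num_leaves == 2000
```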
pylate/indexes/scann.py (outdated)

        )

        metadata_path = index_path / "metadata.json"
        doc_id_mapping_path = index_path / "doc_id_to_embedding_range.tsv"
Am I a pain for asking whether we could use the same type of processing as for the other indexes?
We went from sqlitedict to pickle, which should be pretty easy to implement (tsv -> pickled dict).
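The requested switch is straightforward with the standard library; a minimal sketch of pickling the doc-id mapping (the file name mirrors the diff above, the ranges are illustrative):

```python
import pickle
import tempfile
from pathlib import Path

# Map each document id to its (start, end) row range in the flat matrix.
doc_id_to_embedding_range = {"doc0": (0, 4), "doc1": (4, 7)}

index_path = Path(tempfile.mkdtemp())
mapping_path = index_path / "doc_id_to_embedding_range.pkl"

with mapping_path.open("wb") as f:
    pickle.dump(doc_id_to_embedding_range, f)

with mapping_path.open("rb") as f:
    loaded = pickle.load(f)

assert loaded == doc_id_to_embedding_range
```

Unlike TSV, this round-trips the tuple values without any string parsing.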
pylate/indexes/scann.py (outdated)

    def __call__(
        self,
        queries_embeddings: list[list[int | float]],
        k: int = 5,
Biggest nit of my life, but the default k for other indexes is 10.
pylate/indexes/scann.py (outdated)

            If subset is provided (not yet implemented).

        """
        if subset is not None:
I wonder if we should expose the param at all then; it's not exposed for Stanford PLAID, only FastPlaid.
pyproject.toml (outdated)

    "pytest-xdist >=3.6.0",
    "pytest-rerunfailures >= 15.0.0",
    "pytest >= 8.2.1",
    "psutil >= 7.2.2",
Why do we need it in the main deps?
- Remove profiling/timing code (time.time, log_memory) from ScaNN index
- Remove psutil dependency from scann optional and dev deps
- Remove benchmark script from PR
- Switch doc_id_to_embedding_range storage from TSV to pickle
- Move _to_np_dtype helper to utils module as np_dtype_for
- Change default k from 5 to 10 to match other indexes
- Remove subset param from ScaNN.__call__ (not implemented)
- Add comment explaining NaN distances from ScaNN
- Add save/load round-trip test for pickle serialization
Addressing review comments

Thanks for the thorough review @NohTow! Here's what was addressed:
Also added a save/load round-trip test covering the pickle serialization.
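A save/load round-trip test of this kind typically has the following shape (a generic sketch using the metadata.json file from the diff above, not the PR's actual test):

```python
import json
import tempfile
from pathlib import Path

def save_metadata(path: Path, metadata: dict) -> None:
    """Persist index metadata next to the serialized searcher."""
    (path / "metadata.json").write_text(json.dumps(metadata))

def load_metadata(path: Path) -> dict:
    """Read the metadata back for index loading."""
    return json.loads((path / "metadata.json").read_text())

tmp = Path(tempfile.mkdtemp())
meta = {"dtype": "float16", "num_docs": 2, "dim": 8}
save_metadata(tmp, meta)
assert load_metadata(tmp) == meta  # every field survives the round trip
```

The real test would additionally reload the pickled doc-id mapping and assert that queries return the same results before and after the save/load cycle.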
Better comment to not conflate our defaults with .autopilot() ScaNN feature Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
Thanks for the PR!
Summary

- Add ScaNN index backend and export it from pylate.indexes
- Add scann optional dependencies to pyproject.toml

Test plan

- python -m compileall pylate/indexes/scann.py tests/test_scann.py
- python -m pytest tests/test_scann.py -q
- python examples/evaluation/beir_dataset.py --index_type scann --dataset_name nfcorpus