diff --git a/pylate/indexes/__init__.py b/pylate/indexes/__init__.py index 97501bf8..2ae4cd9b 100644 --- a/pylate/indexes/__init__.py +++ b/pylate/indexes/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from .plaid import PLAID +from .scann import ScaNN from .voyager import Voyager -__all__ = ["Voyager", "PLAID"] +__all__ = ["Voyager", "PLAID", "ScaNN"] diff --git a/pylate/indexes/plaid.py b/pylate/indexes/plaid.py index 8530adf1..a55c773c 100644 --- a/pylate/indexes/plaid.py +++ b/pylate/indexes/plaid.py @@ -216,11 +216,12 @@ def add_documents( **kwargs, ) -> "PLAID": """Add documents to the index.""" - return self._index.add_documents( + self._index.add_documents( documents_ids=documents_ids, documents_embeddings=documents_embeddings, **kwargs, ) + return self def remove_documents(self, documents_ids: list[str]) -> "PLAID": """Remove documents from the index. diff --git a/pylate/indexes/scann.py b/pylate/indexes/scann.py new file mode 100644 index 00000000..3faab12a --- /dev/null +++ b/pylate/indexes/scann.py @@ -0,0 +1,681 @@ +from __future__ import annotations + +import gc +import itertools +import json +import logging +import pickle +from pathlib import Path + +import numpy as np +import torch +from tqdm.auto import tqdm + +from .base import Base +from .utils import np_dtype_for, reshape_embeddings + +logger = logging.getLogger(__name__) + + +class ScaNN(Base): + """ScaNN index. The ScaNN index is a fast and efficient index for approximate nearest neighbor search. 
+ + **Important Notes:** + - ScaNN is an **approximate** nearest neighbor search (not exact), designed for large-scale datasets + - For ColBERT retrieval, PLAID is typically faster and more accurate as it's optimized for ColBERT scoring + - ScaNN is CPU-only (no GPU acceleration) + - Parameters are auto-tuned based on dataset size if not specified + + To use this index, you need to install the `scann` extra: + + ```bash + pip install "pylate[scann]" + ``` + + or install scann directly: + + ```bash + pip install scann + ``` + + Parameters + ---------- + index_name + The name of the index collection. + embedding_size + The number of dimensions of the embeddings. + num_neighbors + The number of neighbors to use for the ScaNN searcher. + num_leaves + The number of leaves in the ScaNN tree. If None, auto-tuned based on dataset size. + For small datasets (<100K vectors), fewer leaves are used for speed. + num_leaves_to_search + The number of leaves to search during query time. If None, auto-tuned based on dataset size. + Higher values improve recall but slow down search. + dimensions_per_block + The number of dimensions to use for each block. + Defaults to 2 (passed directly to the searcher; not auto-tuned). + anisotropic_quantization_threshold + The threshold for anisotropic quantization. + Defaults to 0.2 (passed directly to the searcher; not auto-tuned). + training_sample_size + The number of samples to use for training the ScaNN index. + verbose + Verbosity configuration: + - ``False`` or ``"none"``: disable logs + - ``True`` or ``"init"`` or ``"all"``: log build/load/indexing + verbose_level + Backward-compatible alias for verbosity scope (``"none"``, ``"init"``, + ``"all"``). If set, it overrides ``verbose``. + use_autopilot + Whether to use ScaNN's autopilot() method for automatic parameter tuning. + If True, overrides num_leaves, num_leaves_to_search, and training_sample_size. + Defaults to False. + store_embeddings + Whether to store the embeddings in the index. 
If True, the embeddings will be stored in the index. + Defaults to True. This is required to use the get_documents_embeddings method. + index_folder + The folder where the index will be saved/loaded. If None, indices are not persisted to disk. + Defaults to None. + override + Whether to override the index if it already exists. If False and index exists, it will be loaded. + Defaults to False. + + """ + + def __init__( + self, + index_name: str | None = "ScaNN_index", + embedding_size: int = 128, + num_neighbors: int | None = 10, + num_leaves: int | None = None, + num_leaves_to_search: int | None = None, + dimensions_per_block: int | None = 2, + anisotropic_quantization_threshold: float | None = 0.2, + training_sample_size: int | None = None, + verbose: bool | str = "none", + use_autopilot: bool = False, + store_embeddings: bool = True, + index_folder: str | None = None, + override: bool = False, + verbose_level: str | None = None, + ) -> None: + self.index_name = index_name + self.embedding_size = embedding_size + self.num_neighbors = num_neighbors + self.verbose_level = ( + verbose_level + if verbose_level is not None + else self._normalize_verbose(verbose=verbose) + ) + if self.verbose_level not in {"none", "init", "all"}: + raise ValueError( + f"Invalid verbosity level: {self.verbose_level}. " + "Expected one of: 'none', 'init', 'all'." 
+ ) + self.verbose = self.verbose_level in ("init", "all") + self.num_leaves = num_leaves + self.num_leaves_to_search = num_leaves_to_search + self.dimensions_per_block = dimensions_per_block + self.anisotropic_quantization_threshold = anisotropic_quantization_threshold + self.training_sample_size = training_sample_size + self.use_autopilot = use_autopilot + self.store_embeddings = store_embeddings + self.index_folder = index_folder + self.override = override + + # In-memory data structures + self.searcher = None + # Note: embedding_id == position (sequential IDs), so no need for separate mappings + # Store (start, length) tuples instead of lists for memory efficiency + self.doc_id_to_embedding_range = {} # doc_id -> (start_position, length) tuple + self.position_to_doc_id = None # Direct mapping: position -> document ID (numpy array for vectorized indexing) + self.flattened_embeddings = ( + None # Flattened embeddings array (only if store_embeddings=True) + ) + self._documents_added = False # Track if documents have been added + + # Load existing index if index_folder is provided, override is False, and index exists + if self.index_folder is not None and not self.override: + index_path = self._get_index_path() + if index_path is not None: + scann_config_path = index_path / "scann_config.pb" + metadata_path = index_path / "metadata.json" + if scann_config_path.exists() and metadata_path.exists(): + self._load_index() + + @staticmethod + def _normalize_verbose(verbose: bool | str) -> str: + """Normalize `verbose` input to internal string levels.""" + if isinstance(verbose, bool): + return "init" if verbose else "none" + return verbose + + def _build_searcher(self, embeddings: np.ndarray) -> None: + """Build the ScaNN searcher from embeddings (in-memory only).""" + try: + import scann + except ImportError: + raise ImportError( + 'ScaNN is not installed. Please install it with: `pip install "pylate[scann]"` or `pip install scann`.' 
+ ) + + # Auto-tune parameters if not set (only if not using autopilot) + num_vectors = embeddings.shape[0] + self.num_neighbors = ( + self.num_neighbors + if self.num_neighbors is not None + else min(10, num_vectors) + ) + + if self.use_autopilot: + if self.verbose: + logger.info( + f"[ScaNN] Building ScaNN searcher with {embeddings.shape[0]} vectors using autopilot()..." + ) + if ( + self.num_leaves is not None + or self.num_leaves_to_search is not None + or self.training_sample_size is not None + ): + logger.warning( + "[ScaNN] autopilot() overrides manual configuration (num_leaves, num_leaves_to_search, training_sample_size)" + ) + else: + # Use default if not set + self.num_leaves = ( + self.num_leaves + if self.num_leaves is not None + else min(2_000, num_vectors) + ) + self.num_leaves_to_search = ( + self.num_leaves_to_search + if self.num_leaves_to_search is not None + else 200 + ) + self.training_sample_size = ( + self.training_sample_size + if self.training_sample_size is not None + else min(250000, num_vectors) + ) + + if self.verbose: + logger.info( + f"[ScaNN] Building ScaNN searcher with {embeddings.shape[0]} vectors..." 
+ ) + logger.info( + f"[ScaNN] Parameters: num_leaves={self.num_leaves}, num_leaves_to_search={self.num_leaves_to_search}, training_sample_size={self.training_sample_size}, num_neighbors={self.num_neighbors}" + ) + + # Build ScaNN searcher + if self.use_autopilot: + searcher = ( + scann.scann_ops_pybind.builder( + embeddings, self.num_neighbors, "dot_product" + ) + .autopilot() + .build() + ) + else: + searcher = ( + scann.scann_ops_pybind.builder( + embeddings, self.num_neighbors, "dot_product" + ) + .tree( + num_leaves=self.num_leaves, + num_leaves_to_search=self.num_leaves_to_search, + training_sample_size=self.training_sample_size, + spherical=True, + ) + .score_ah( + dimensions_per_block=self.dimensions_per_block, + anisotropic_quantization_threshold=self.anisotropic_quantization_threshold, + ) + .build() + ) + + self.searcher = searcher + + def _get_index_path(self) -> Path | None: + """Get the path where the index should be saved/loaded.""" + if self.index_folder is None or self.index_name is None: + return None + index_path = Path(self.index_folder) / self.index_name + return index_path + + def _load_index(self) -> None: + """Load an existing index from disk. Raises an error if loading fails.""" + index_path = self._get_index_path() + if index_path is None: + raise ValueError( + f"Cannot load index: index_folder or index_name not set. " + f"index_folder={self.index_folder}, index_name={self.index_name}" + ) + + metadata_path = index_path / "metadata.json" + doc_id_mapping_path = index_path / "doc_id_to_embedding_range.pkl" + flattened_embeddings_path = index_path / "flattened_embeddings.npy" + + try: + import scann + except ImportError: + raise ImportError( + "ScaNN is not installed. Cannot load index. " + 'Please install it with: `pip install "pylate[scann]"` or `pip install scann`.' 
+ ) + + try: + if self.verbose: + logger.info(f"[ScaNN] Loading existing index from {index_path}...") + + # Load searcher - use absolute path to avoid path resolution issues + index_path_abs = index_path.resolve() + self.searcher = scann.scann_ops_pybind.load_searcher(str(index_path_abs)) + + # Load metadata (JSON) + with open(metadata_path, "r") as f: + metadata = json.load(f) + # Restore configuration from metadata + self.embedding_size = metadata.get( + "embedding_size", self.embedding_size + ) + self.num_neighbors = metadata.get("num_neighbors", self.num_neighbors) + self.num_leaves = metadata.get("num_leaves", self.num_leaves) + self.num_leaves_to_search = metadata.get( + "num_leaves_to_search", self.num_leaves_to_search + ) + self.training_sample_size = metadata.get( + "training_sample_size", self.training_sample_size + ) + self.use_autopilot = metadata.get("use_autopilot", self.use_autopilot) + self.store_embeddings = metadata.get( + "store_embeddings", self.store_embeddings + ) + + # Load doc_id_to_embedding_range (pickle) + if doc_id_mapping_path.exists(): + with open(doc_id_mapping_path, "rb") as f: + self.doc_id_to_embedding_range = pickle.load(f) + else: + raise FileNotFoundError( + f"Document ID mapping not found at {doc_id_mapping_path}" + ) + + # Reconstruct position_to_doc_id from doc_id_to_embedding_range + if self.doc_id_to_embedding_range: + # Calculate total embeddings from the max end position + max_end = max( + start + length + for start, length in self.doc_id_to_embedding_range.values() + ) + self.position_to_doc_id = np.empty(max_end, dtype=object) + for doc_id, (start, length) in tqdm( + self.doc_id_to_embedding_range.items(), + desc="Reconstructing position_to_doc_id", + disable=not self.verbose, + ): + self.position_to_doc_id[start : start + length] = doc_id + else: + self.position_to_doc_id = np.empty(0, dtype=object) + + # Load flattened_embeddings if it exists (only if store_embeddings=True) + if self.store_embeddings and 
flattened_embeddings_path.exists(): + if self.verbose: + logger.info( + "[ScaNN] Loading flattened_embeddings from %s...", + flattened_embeddings_path, + ) + self.flattened_embeddings = np.load(flattened_embeddings_path) + if self.verbose: + logger.info( + "[ScaNN] Loaded flattened_embeddings with shape %s", + self.flattened_embeddings.shape, + ) + else: + self.flattened_embeddings = None + + self._documents_added = True + + if self.verbose: + logger.info(f"[ScaNN] Successfully loaded index from {index_path}") + logger.info( + f"[ScaNN] Documents: {len(self.doc_id_to_embedding_range)}" + ) + logger.info( + f"[ScaNN] Total embeddings: {len(self.position_to_doc_id) if self.position_to_doc_id is not None else 0}" + ) + except ImportError: + # Preserve import errors (e.g. optional dependencies) as-is. + raise + except Exception as e: + raise RuntimeError( + f"Failed to load ScaNN index from {index_path}: {e}. " + f"This may indicate a corrupted index or version mismatch. " + f"Set override=True to rebuild the index." + ) from e + + def save(self) -> None: + """Save the index to disk.""" + if self.searcher is None: + raise ValueError( + "Cannot save index: no searcher has been built. Add documents first." 
+ ) + + index_path = self._get_index_path() + if index_path is None: + if self.verbose: + logger.warning( + "[ScaNN] Cannot save index: index_folder or index_name not set" + ) + return + + # Create directory if it doesn't exist + index_path.mkdir(parents=True, exist_ok=True) + + metadata_path = index_path / "metadata.json" + doc_id_mapping_path = index_path / "doc_id_to_embedding_range.pkl" + flattened_embeddings_path = index_path / "flattened_embeddings.npy" + + try: + if self.verbose: + logger.info(f"[ScaNN] Saving index to {index_path}...") + + # Save searcher - serialize() expects a directory path and will create files inside it + # Use absolute path to avoid path resolution issues when loading + # Serialize directly to index_path (not a subdirectory) to avoid path issues + index_path_abs = index_path.resolve() + self.searcher.serialize(str(index_path_abs)) + + # Save metadata as JSON (only simple, serializable values) + metadata = { + "embedding_size": self.embedding_size, + "num_neighbors": self.num_neighbors, + "num_leaves": self.num_leaves, + "num_leaves_to_search": self.num_leaves_to_search, + "training_sample_size": self.training_sample_size, + "use_autopilot": self.use_autopilot, + "store_embeddings": self.store_embeddings, + } + + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + # Save doc_id_to_embedding_range as pickle + # position_to_doc_id can be reconstructed from this, so we don't save it separately + with open(doc_id_mapping_path, "wb") as f: + pickle.dump(self.doc_id_to_embedding_range, f) + + # Save flattened_embeddings if store_embeddings=True + if self.store_embeddings and self.flattened_embeddings is not None: + np.save(flattened_embeddings_path, self.flattened_embeddings) + + if self.verbose: + logger.info(f"[ScaNN] Index saved successfully to {index_path}") + except Exception as e: + logger.error(f"[ScaNN] Failed to save index to {index_path}: {e}") + raise + + def add_documents( + self, + documents_ids: 
list[str], + documents_embeddings: list[torch.Tensor | np.ndarray], + batch_size: int = 128, + ) -> "ScaNN": + """Add documents to the index. + + Note: This method only supports adding all documents at once. + Subsequent calls will raise an error. + batch_size is kept for API compatibility but not used. + """ + # Enforce single add - check if documents already exist + if self._documents_added: + raise ValueError( + "ScaNN index only supports adding all documents at once. " + "Documents have already been added." + ) + + if self.verbose: + logger.info(f"[ScaNN] Adding {len(documents_ids)} documents to index...") + + if len(documents_embeddings) == 0: + raise ValueError("Cannot add documents: documents_embeddings is empty.") + + # Get doc lengths and total count in one pass + doc_lengths = [emb.shape[0] for emb in documents_embeddings] + total_embeddings = sum(doc_lengths) + embedding_dim = documents_embeddings[0].shape[1] + + # Preserve incoming dtype (fp16/fp32) to avoid holding two full flattened + # copies (e.g., one fp16 and one fp32) in memory at the same time. + first_emb = documents_embeddings[0] + + if not isinstance(first_emb, (torch.Tensor, np.ndarray)): + raise ValueError( + "ScaNN expects document embeddings to be torch.Tensor or np.ndarray. " + f"Got type={type(first_emb)}." + ) + np_dtype = np_dtype_for(first_emb.dtype) + if np_dtype is None: + raise ValueError( + "ScaNN expects document embeddings to be float16 or float32. " + f"Got dtype={first_emb.dtype}." + ) + + if self.verbose: + size_gb = ( + total_embeddings * embedding_dim * np.dtype(np_dtype).itemsize / 1e9 + ) + logger.info( + f"[ScaNN] Pre-allocating array for {total_embeddings} embeddings x " + f"{embedding_dim} dims ({size_gb:.2f} GB) with dtype={np_dtype}" + ) + + # Pre-allocate flattened array in the same dtype as incoming embeddings. 
+ flattened_embeddings = np.empty( + (total_embeddings, embedding_dim), dtype=np_dtype + ) + + # Fill array in-place, deleting each tensor after copying to free memory + offset = 0 + num_docs = len(documents_embeddings) + + iterator = tqdm( + enumerate(documents_embeddings), + desc="Flattening documents and adding to pre-allocated array", + total=num_docs, + disable=not self.verbose, + ) + for i, emb in iterator: + if emb.shape[1] != embedding_dim: + raise ValueError( + "All document embeddings must have the same embedding dimension. " + f"Expected {embedding_dim}, got {emb.shape[1]}." + ) + + if not isinstance(emb, (torch.Tensor, np.ndarray)): + raise ValueError( + "ScaNN expects document embeddings to be torch.Tensor or np.ndarray. " + f"Got type={type(emb)}." + ) + emb_np_dtype = np_dtype_for(emb.dtype) + if emb_np_dtype is None: + raise ValueError( + "ScaNN expects document embeddings to be float16 or float32. " + f"Got dtype={emb.dtype}." + ) + + if isinstance(emb, torch.Tensor): + emb_np = emb.to("cpu").numpy() + else: + emb_np = emb + + if emb_np_dtype != np_dtype: + raise ValueError( + "All document embeddings must have the same dtype. " + f"Expected {np_dtype}, got {emb_np_dtype}." 
+ ) + n = emb_np.shape[0] + flattened_embeddings[offset : offset + n] = emb_np + offset += n + + # Clear the list and run gc + del documents_embeddings + gc.collect() + + # Build position->doc_id array and doc_id->embedding_range mapping + self.position_to_doc_id = np.empty(total_embeddings, dtype=object) + offset = 0 + for doc_id, num_tokens in zip(documents_ids, doc_lengths): + # Store (start, length) tuple instead of list for memory efficiency + self.doc_id_to_embedding_range[doc_id] = (offset, num_tokens) + # Broadcast doc_id to fill the slice (no temp list needed) + self.position_to_doc_id[offset : offset + num_tokens] = doc_id + offset += num_tokens + + # Build the ScaNN index with all embeddings + if len(flattened_embeddings) > 0: + if self.verbose: + logger.info( + f"[ScaNN] Building index with {len(flattened_embeddings)} embeddings..." + ) + + self._build_searcher(flattened_embeddings) + + # Store flattened embeddings if requested, otherwise free the array + if self.store_embeddings: + self.flattened_embeddings = flattened_embeddings + else: + del flattened_embeddings + gc.collect() + + # Mark that documents have been added + self._documents_added = True + + # Save index to disk if index_folder is set + if self.index_folder is not None: + self.save() + + return self + + def remove_documents(self, documents_ids: list[str]) -> None: + """Remove documents from the index. + + Not supported for ScaNN index. + + Parameters + ---------- + documents_ids + The documents IDs to remove. + + Raises + ------ + NotImplementedError + Document removal is not supported for ScaNN index. + + """ + raise NotImplementedError("Document removal is not supported for ScaNN index.") + + def __call__( + self, + queries_embeddings: list[list[int | float]], + k: int = 10, + ) -> dict: + """Query the index for the nearest neighbors of the queries embeddings. + + Parameters + ---------- + queries_embeddings + The queries embeddings. + k + The number of nearest neighbors to return. 
+ + """ + if self.searcher is None: + raise ValueError("Index is empty, add documents before querying.") + + # Reshape queries + queries_embeddings = reshape_embeddings(embeddings=queries_embeddings) + + # Flatten query embeddings (assume they are already normalized) + flattened_queries = np.array(list(itertools.chain(*queries_embeddings))) + + # Query the index + neighbors, distances = self.searcher.search_batched_parallel( + flattened_queries, final_num_neighbors=k + ) + # ScaNN may return NaN distances when it cannot complete the full top-k + # (e.g. k exceeds the number of indexed documents). Replace with 0. + if np.isnan(distances).any(): + logger.warning( + "[ScaNN] distances has %d NaN values out of %d total; replacing with 0", + np.isnan(distances).sum(), + distances.size, + ) + distances = np.nan_to_num(distances, nan=0.0) + + # Map embedding indices back to document IDs using fully vectorized numpy operations + n_tokens_per_query = [len(q) for q in queries_embeddings] + + # Vectorized lookup: process all tokens at once using numpy advanced indexing + # neighbors shape: (n_tokens_total, k), distances shape: (n_tokens_total, k) + all_doc_ids = self.position_to_doc_id[neighbors] + all_distances = distances + + # Reshape back into nested structure (queries -> tokens -> neighbors) + # using np.split to avoid a per-token Python loop. + splits = np.cumsum(n_tokens_per_query[:-1]) + documents = [list(chunk) for chunk in np.split(all_doc_ids, splits)] + distances_list = [list(chunk) for chunk in np.split(all_distances, splits)] + + return { + "documents_ids": documents, + "distances": distances_list, # Keep as list to handle variable-length query tokens (ragged) + } + + def get_documents_embeddings( + self, documents_ids: list[list[str]] + ) -> list[list[np.ndarray]]: + """Get document embeddings by their IDs. + + Parameters + ---------- + documents_ids + Nested list of document IDs. Each inner list represents a group of documents. 
+ + Returns + ------- + list[list[np.ndarray]] + Nested list of embeddings. Each embedding is a numpy array with shape (seq_len, dim). + + Raises + ------ + NotImplementedError + If store_embeddings=False (embeddings are not stored). + ValueError + If index is empty or document ID not found. + """ + if not self.store_embeddings: + raise NotImplementedError( + "Retrieving document embeddings requires store_embeddings=True. " + "Set store_embeddings=True when creating the index." + ) + + if self.flattened_embeddings is None: + raise ValueError( + "Index is empty, add documents before retrieving embeddings." + ) + + reconstructed_embeddings = [] + for doc_group in documents_ids: + group_embeddings = [] + for doc_id in doc_group: + if doc_id not in self.doc_id_to_embedding_range: + raise ValueError(f"Document ID '{doc_id}' not found in index.") + + start, length = self.doc_id_to_embedding_range[doc_id] + # Slice the flattened array to get document embeddings + doc_emb = self.flattened_embeddings[start : start + length] + group_embeddings.append(doc_emb) + reconstructed_embeddings.append(group_embeddings) + + return reconstructed_embeddings diff --git a/pylate/indexes/utils.py b/pylate/indexes/utils.py new file mode 100644 index 00000000..e3e239f1 --- /dev/null +++ b/pylate/indexes/utils.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import numpy as np +import torch + + +def reshape_embeddings( + embeddings: np.ndarray | torch.Tensor | list, +) -> np.ndarray | list: + """Reshape embeddings to the expected format (batch_size, n_tokens, embedding_size).""" + if isinstance(embeddings, np.ndarray): + if len(embeddings.shape) == 2: + return np.expand_dims(a=embeddings, axis=0) + + if isinstance(embeddings, torch.Tensor): + return reshape_embeddings(embeddings=embeddings.cpu().detach().numpy()) + + if isinstance(embeddings, list) and isinstance(embeddings[0], torch.Tensor): + return [embedding.cpu().detach().numpy() for embedding in embeddings] + + return 
embeddings + + +def np_dtype_for( + dtype: object, +) -> type[np.float16] | type[np.float32] | None: + """Map a torch or numpy dtype to the corresponding numpy float type. + + Returns ``np.float16`` or ``np.float32`` for recognised dtypes, + ``None`` otherwise. + """ + if dtype in (torch.float16, np.float16): + return np.float16 + if dtype in (torch.float32, np.float32): + return np.float32 + return None diff --git a/pylate/retrieve/colbert.py b/pylate/retrieve/colbert.py index 07f3dc14..daac96d0 100644 --- a/pylate/retrieve/colbert.py +++ b/pylate/retrieve/colbert.py @@ -5,7 +5,8 @@ import numpy as np import torch -from ..indexes import PLAID, Voyager +from ..indexes.base import Base +from ..indexes.plaid import PLAID from ..rank import RerankResult, rerank from ..utils import iter_batch @@ -17,8 +18,9 @@ class ColBERT: Parameters ---------- - index: - The index to use for retrieval. + index + The index to use for retrieval. Any index that extends ``Base`` + (e.g. PLAID, Voyager, ScaNN). Examples -------- @@ -92,7 +94,7 @@ class ColBERT: """ - def __init__(self, index: Voyager | PLAID) -> None: + def __init__(self, index: Base) -> None: self.index = index def retrieve( @@ -113,21 +115,22 @@ def retrieve( k The number of documents to retrieve. k_token - The number of documents to retrieve from the index. Defaults to `k`. + The number of token-level candidates to retrieve from the index + before reranking. Only used for non-PLAID indexes. Defaults to 100. device - The device to use for the embeddings. Defaults to queries_embeddings device. + The device to use for reranking. Defaults to queries_embeddings device. batch_size - The batch size to use for retrieval. + The batch size to use for retrieval. Only used for non-PLAID indexes. subset Optional subset of document IDs to restrict search to. Can be a single list (same filter for all queries) or list of lists (different filter per query). Document IDs should match the IDs used when adding documents. 
- Only supported with PLAID index. + Currently only supported with PLAID index. """ - # PLAID index directly retrieves the documents - if isinstance(self.index, PLAID) or not isinstance(self.index, Voyager): + # PLAID handles reranking internally and returns RerankResult directly + if isinstance(self.index, PLAID): return self.index( queries_embeddings=queries_embeddings, k=k, diff --git a/pyproject.toml b/pyproject.toml index c2bd894c..f8d9181b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,8 @@ dev = [ eval = ["ranx >= 0.3.16", "beir >= 2.0.0"] api = ["fastapi >= 0.114.1", "uvicorn >= 0.30.6", "batched >= 0.1.2"] voyager = ["voyager >= 2.0.9"] +scann = ["scann >= 1.4.2"] + [dependency-groups] docs = [ diff --git a/tests/test_scann.py b/tests/test_scann.py new file mode 100644 index 00000000..16af73d7 --- /dev/null +++ b/tests/test_scann.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from pylate import indexes + +pytest.importorskip("scann") + + +def test_scann_is_exported() -> None: + """ScaNN should be importable from pylate.indexes.""" + assert hasattr(indexes, "ScaNN") + + +@pytest.mark.parametrize( + ("verbose_input", "expected_level"), + [ + (False, "none"), + (True, "init"), + ("none", "none"), + ("init", "init"), + ("all", "all"), + ], +) +def test_scann_verbose_normalization( + verbose_input: bool | str, expected_level: str +) -> None: + """ScaNN should normalize bool/string verbosity to internal levels.""" + index = indexes.ScaNN(verbose=verbose_input) + assert index.verbose_level == expected_level + assert index.verbose == (expected_level in ("init", "all")) + + +def test_scann_verbose_level_alias_overrides_verbose() -> None: + """Backward-compatible verbose_level should override verbose.""" + index = indexes.ScaNN(verbose=False, verbose_level="all") + assert index.verbose_level == "all" + assert index.verbose is True + + +def test_scann_invalid_verbose_level_raises() -> 
None: + """Unsupported verbose values should raise a clear ValueError.""" + with pytest.raises(ValueError, match="Invalid verbosity level"): + indexes.ScaNN(verbose="loud") + + +def test_scann_add_documents_returns_self_and_preserves_fp16_storage() -> None: + """Stored flattened embeddings should preserve incoming fp16 dtype.""" + index = indexes.ScaNN(store_embeddings=True) + + documents_ids = ["d1", "d2"] + documents_embeddings = [ + torch.randn(12, 8, dtype=torch.float16), + torch.randn(10, 8, dtype=torch.float16), + ] + + returned = index.add_documents( + documents_ids=documents_ids, + documents_embeddings=documents_embeddings, + batch_size=2, + ) + + assert returned is index + assert index.flattened_embeddings is not None + assert index.flattened_embeddings.dtype == np.float16 + assert index.flattened_embeddings.shape == (22, 8) + + +@pytest.mark.parametrize("docs_dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("queries_dtype", [torch.float16, torch.float32]) +def test_scann_accepts_fp16_and_fp32_documents_and_queries( + docs_dtype: torch.dtype, queries_dtype: torch.dtype +) -> None: + """ScaNN should accept both fp16/fp32 docs and fp16/fp32 queries.""" + index = indexes.ScaNN(store_embeddings=True) + index.add_documents( + documents_ids=["d1", "d2", "d3"], + documents_embeddings=[ + torch.randn(8, 8, dtype=docs_dtype), + torch.randn(7, 8, dtype=docs_dtype), + torch.randn(9, 8, dtype=docs_dtype), + ], + batch_size=3, + ) + + expected_np_dtype = np.float16 if docs_dtype == torch.float16 else np.float32 + assert index.flattened_embeddings is not None + assert index.flattened_embeddings.dtype == expected_np_dtype + + results = index( + queries_embeddings=[torch.randn(2, 8, dtype=queries_dtype)], + k=2, + ) + + assert set(results.keys()) == {"documents_ids", "distances"} + assert len(results["documents_ids"]) == 1 + assert len(results["documents_ids"][0]) == 2 + assert len(results["documents_ids"][0][0]) == 2 + + +def 
test_scann_errors_on_mixed_embedding_dtypes() -> None: + """All document tensors must share dtype (fp16 or fp32).""" + index = indexes.ScaNN(store_embeddings=False) + + with pytest.raises(ValueError, match="same dtype"): + index.add_documents( + documents_ids=["d1", "d2"], + documents_embeddings=[ + torch.randn(2, 8, dtype=torch.float16), + torch.randn(2, 8, dtype=torch.float32), + ], + batch_size=2, + ) + + +@pytest.mark.parametrize("docs_dtype", [np.float16, np.float32]) +def test_scann_accepts_numpy_document_embeddings(docs_dtype: np.dtype) -> None: + """ScaNN should accept numpy doc embeddings and preserve their dtype.""" + index = indexes.ScaNN(store_embeddings=True) + docs = [ + np.random.randn(12, 8).astype(docs_dtype), + np.random.randn(10, 8).astype(docs_dtype), + ] + index.add_documents( + documents_ids=["d1", "d2"], + documents_embeddings=docs, + batch_size=2, + ) + + assert index.flattened_embeddings is not None + assert index.flattened_embeddings.dtype == docs_dtype + + +@pytest.mark.parametrize("docs_dtype", [torch.float16, torch.float32]) +def test_scann_get_documents_embeddings_by_docid(docs_dtype: torch.dtype) -> None: + """Stored embeddings should be retrievable by document ID.""" + index = indexes.ScaNN(store_embeddings=True) + + d1 = torch.arange(0, 96, dtype=torch.float32).reshape(12, 8).to(docs_dtype) + d2 = torch.arange(96, 176, dtype=torch.float32).reshape(10, 8).to(docs_dtype) + index.add_documents( + documents_ids=["d1", "d2"], + documents_embeddings=[d1, d2], + batch_size=2, + ) + + retrieved = index.get_documents_embeddings([["d2", "d1"]]) + assert len(retrieved) == 1 + assert len(retrieved[0]) == 2 + assert retrieved[0][0].dtype == ( + np.float16 if docs_dtype == torch.float16 else np.float32 + ) + assert np.array_equal(retrieved[0][0], d2.cpu().numpy()) + assert np.array_equal(retrieved[0][1], d1.cpu().numpy()) + + +def test_scann_get_documents_embeddings_requires_store_embeddings() -> None: + """Accessing embeddings without 
store_embeddings should raise clearly.""" + index = indexes.ScaNN(store_embeddings=False) + + with pytest.raises(NotImplementedError, match="store_embeddings=True"): + index.get_documents_embeddings([["d1"]]) + + +def test_scann_get_documents_embeddings_missing_docid_raises() -> None: + """Unknown document IDs should raise a ValueError.""" + index = indexes.ScaNN(store_embeddings=True) + index.add_documents( + documents_ids=["d1", "d2"], + documents_embeddings=[ + torch.randn(12, 8, dtype=torch.float32), + torch.randn(10, 8, dtype=torch.float32), + ], + batch_size=2, + ) + + with pytest.raises(ValueError, match="not found in index"): + index.get_documents_embeddings([["d3"]]) + + +def test_scann_save_and_load_roundtrip(tmp_path: object) -> None: + """Index saved to disk should be loadable and return identical results.""" + docs = [ + torch.randn(12, 8, dtype=torch.float32), + torch.randn(10, 8, dtype=torch.float32), + ] + doc_ids = ["d1", "d2"] + + # Build and save + index = indexes.ScaNN( + store_embeddings=True, + index_folder=str(tmp_path), + index_name="test_roundtrip", + ) + index.add_documents(documents_ids=doc_ids, documents_embeddings=docs, batch_size=2) + + query = [torch.randn(2, 8, dtype=torch.float32)] + results_before = index(queries_embeddings=query, k=2) + + # Load from disk + loaded = indexes.ScaNN( + store_embeddings=True, + index_folder=str(tmp_path), + index_name="test_roundtrip", + ) + + assert loaded._documents_added + assert set(loaded.doc_id_to_embedding_range.keys()) == {"d1", "d2"} + + results_after = loaded(queries_embeddings=query, k=2) + for before_q, after_q in zip( + results_before["documents_ids"], results_after["documents_ids"] + ): + for before_tok, after_tok in zip(before_q, after_q): + assert np.array_equal(before_tok, after_tok)