
Add XTR retrieval and training support#207

Open
robro612 wants to merge 16 commits into lightonai:main from robro612:xtr-training

Conversation

@robro612
Contributor

Summary

This PR adds XTR (ConteXtualized Token Retriever) support to PyLate, covering both retrieval and training.

1. XTR Retrieval

XTR retrieval performs per-token approximate nearest neighbor search followed by document-level scoring with min-imputation, as an alternative to ColBERT's full late-interaction reranking.

  • pylate/retrieve/xtr.py — new XTR retriever class for XTR-style retrieval on Voyager/ScaNN indexes
  • pylate/rank/rank.py — adds score_xtr() for document-level XTR scoring with min-imputation
  • pylate/rank/__init__.py — exports score_xtr
  • pylate/retrieve/__init__.py — exports XTR retriever
  • pylate/retrieve/colbert.py — generalize type hint from Voyager | PLAID to Base index
  • Tests: tests/test_xtr_retriever.py, tests/test_xtr_scoring.py
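The min-imputation idea above can be sketched as follows. This is an illustrative stand-in, not PyLate's actual `score_xtr()` API: the function name and the `token_hits` structure are hypothetical. Each query token retrieves a top-k set of (doc, score) pairs; a candidate document that is missing from some token's top-k gets that token's lowest retrieved score imputed instead.

```python
def score_xtr_sketch(token_hits, n_query_tokens):
    """Document-level XTR scoring with min-imputation (illustrative).

    token_hits: list with one entry per query token, each a dict
        mapping doc_id -> retrieved similarity for that token.
    Returns a dict mapping doc_id -> averaged XTR score.
    """
    candidates = set()
    for hits in token_hits:
        candidates.update(hits)

    scores = {}
    for doc_id in candidates:
        total = 0.0
        for hits in token_hits:
            if doc_id in hits:
                total += hits[doc_id]
            else:
                # Document absent from this token's top-k: impute the
                # lowest score retrieved for that query token.
                total += min(hits.values())
        scores[doc_id] = total / n_query_tokens
    return scores
```

Min-imputation is a lower bound on the true per-token similarity (had the ANN search gone deeper), which is why it is the recommended default.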

2. XTR Training

XTR training uses a token-level top-k scoring function with z-normalization instead of ColBERT's MaxSim. This requires seeing all documents at once (for global top-k), so both loss classes are updated to support a requires_full_batch score function.

  • pylate/scores/scores.py — new xtr_scores and xtr_kd_scores functions, plus XTRScores/XTRKDScores callable classes for convenient default-k configuration. These classes set requires_full_batch = True as a class attribute, which the loss functions check to determine whether to pass the full document batch at once or chunk by group.
  • pylate/scores/__init__.py — exports new scoring functions/classes
  • pylate/losses/contrastive.py — detect requires_full_batch score metrics and pass all documents at once instead of chunking by group
  • pylate/losses/cached_contrastive.py — same requires_full_batch support for cached contrastive training
  • pylate/losses/distillation.py — add temperature parameter for XTR KD training
  • Tests: tests/test_xtr_scores.py
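A minimal sketch of the token-level top-k scoring pattern described above, and of the `requires_full_batch` class-attribute convention the losses check. The class name and shapes are assumptions for illustration (not PyLate's actual `XTRScores`), and the z-normalization step is omitted for brevity: similarities below the batch-global k-th best are simply zeroed before the per-document MaxSim.

```python
import torch


class XTRScoresSketch:
    """Illustrative stand-in for an XTR-style score function.

    requires_full_batch signals the loss to pass every document at
    once, because top-k must run over the whole batch globally.
    """

    requires_full_batch = True

    def __init__(self, k: int = 32):
        self.k = k

    def __call__(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        # q: (B, Tq, H) query token embeddings
        # d: (B, N, Td, H) document token embeddings, N docs per query
        B, N, Td, H = d.shape
        Tq = q.shape[1]
        docs = d.reshape(B, N * Td, H)                 # all doc tokens
        sim = torch.einsum("bqh,bth->bqt", q, docs)    # (B, Tq, N*Td)
        # Keep only each query token's top-k similarities over the
        # full batch; everything below the k-th best is zeroed.
        k = min(self.k, sim.shape[-1])
        thresh = sim.topk(k, dim=-1).values[..., -1:]  # k-th best sim
        kept = sim * (sim >= thresh)
        # MaxSim over each document's kept tokens, summed over query tokens.
        per_doc = kept.reshape(B, Tq, N, Td).max(dim=-1).values  # (B, Tq, N)
        return per_doc.sum(dim=1)                      # (B, N)
```

With `k` large enough to keep every document token, this reduces to plain ColBERT MaxSim; smaller `k` is what makes the score depend on the whole batch and hence requires the full-batch path in the losses.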

robro612 added 16 commits March 11, 2026 14:17
…(non-plaid).

This supports multiple imputation functions: min (default), zero, mean, percentile, or power_law; min is the default and is strongly recommended.
- Fix broken doctest in xtr.py (missing add_documents call)
- Fix progress bar total using ceil division
- Use torch.isinf for robust missing-score detection in score_xtr
- Rename misleading variable neg_b -> slope in power-law imputation
- Export Base from indexes.__init__ and use public API in colbert.py
- Add unit tests for _compute_imputation_scores (rectangular, ragged, edge cases)
- Add beir_dataset_xtr.py eval script supporting both ColBERT and XTR retrieval
- Fix test_xtr_retriever.py: name -> index_name kwarg for ScaNN
…efix*

Some late-interaction models (e.g. XTR) don't use Q/D prefix tokens. This changes model initialization so that it no longer strongly defaults to adding [Q] and [D] prefix tokens.

This is probably an opinionated and *breaking* change for some use cases (e.g. if training initialization is assumed to add the prefix tokens). An alternative would be to route this through the currently inert `add_special_tokens` argument, which today does not affect the logic at all.
Add requires_full_batch score handling in Contrastive for XTR scoring, which requires all documents in the batch to be passed at once to properly perform the "in-batch retrieval". Contrastive handles this automatically by checking the score function for an attribute set in scores.py.
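The loss-side dispatch described here can be sketched like this. The function name `dispatch_scores` and the shapes are hypothetical, not PyLate's actual code; the point is the `getattr` check on the score function.

```python
import torch


def dispatch_scores(score_fn, q, doc_groups):
    """Loss-side dispatch on requires_full_batch (illustrative sketch).

    q: (B, Tq, H) query embeddings; doc_groups: list of N tensors,
    each (B, Td, H), one per document group (positives, negatives, ...).
    """
    if getattr(score_fn, "requires_full_batch", False):
        # XTR-style scoring: pass every document at once so its
        # top-k runs over the whole batch ("in-batch retrieval").
        docs = torch.stack(doc_groups, dim=1)  # (B, N, Td, H)
        return score_fn(q, docs)               # (B, N)
    # MaxSim-style scoring is independent per document and can be
    # computed one group at a time.
    return torch.stack([score_fn(q, d) for d in doc_groups], dim=1)
```

A class attribute (rather than, say, an isinstance check) keeps the losses decoupled from the concrete score implementations: any callable can opt in by setting `requires_full_batch = True`.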
- CachedContrastive now checks score_metric.requires_full_batch and
  passes all documents at once (stacked as (batch, N, Dt, H)) instead
  of chunking over document groups. Query mini-batching still controls
  memory. Labels adjusted to i*N for interleaved doc layout.
- Fix xtr_scores to derive mask expansion from query batch size (Qb)
  rather than document batch size (Dq), which caused IndexError when
  Qb != Dq (i.e. CachedContrastive mini-batching).
- Add regression tests for mismatched query/doc batch sizes.
- Add sweep/validation test script and SLURM array job for benchmarking
  memory and correctness across scoring/loss/batch-size configurations.
- Add design doc noting streaming top-k as future optimization.
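The "labels adjusted to i*N" point above amounts to the following. This is an illustrative helper (the name is not PyLate's), assuming the flattened score matrix is query-major with the positive as the first document of each group:

```python
import torch


def interleaved_labels(batch_size: int, n_docs: int) -> torch.Tensor:
    """Cross-entropy targets when documents are stacked (B, N, Td, H)
    and the score matrix is flattened query-major: query i's positive
    sits at column i * N."""
    return torch.arange(batch_size) * n_docs
```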
…ce citation

- Added a temperature param to Distillation, found to be necessary for XTR training.
- xtr_kd_scores is a thin wrapper around the existing xtr_scores function that returns only the in-example scores, fitting the Distillation interface. Token scores are already required, so there is no memory overhead and minimal computation to be saved by rewriting.
- remove print statements
- contrastive labels multigpu fix
- requires_full_batch decorator for xtr_scores
- restore hpool logic in colbert.py (from overeager cherrypick)
- consolidated XTR test files
- removed non-min imputation options.
…for setting/using default k_train values easily.