TL;DR: Capture makes LLM inference faster by caching intermediate activations instead of just KV pairs. Think of it as a smarter cache that reduces recomputation and memory bandwidth bottlenecks.
LLM inference is memory-bound. Traditional systems only cache Key-Value (KV) pairs, but this:
- ❌ Still requires recomputing all other layers
- ❌ Wastes memory bandwidth fetching KV cache repeatedly
- ❌ Can't leverage host memory effectively
Capture solves this by:
- ✅ Caching intermediate activations (not just KV)
- ✅ Smart mixed KV/activation caching strategy
- ✅ Efficient host memory offloading
- ✅ Up to 2.5x throughput improvement over vLLM
git clone https://github.com/casys-kaist/Capture.git
cd capture
pip install -e .
python scripts/capture_runner.py \
--model opt-13b \
--dataset sharegpt \
--num-prompts 100 \
--gen-len 128 \
--enable-capture \
--flag-turn-on-icache \
--host-mem-size-GB 128 \
--kvc-ratio 0.5

Supported Models: OPT (6.7B, 13B, 30B, 66B)
Datasets:
- `sharegpt` - Real-world ShareGPT conversations
- `dummy` - Fixed-length synthetic prompts
- `dummy-random-length` - Variable-length synthetic prompts
Note: Docker setup and detailed usage instructions will be added later.
Traditional KV Cache (vLLM)
For each token at each layer:
┌────────────────────────┬──────────┐
│ Key Tensor (K)         │ Stored   │
│ Value Tensor (V)       │ Stored   │
└────────────────────────┴──────────┘
K and V are directly used during decode
Capture's Activation Cache
For each token at each layer:
┌────────────────────────┬──────────┐
│ Input Activation (Ac)  │ Stored   │
└────────────────────────┴──────────┘
K and V regenerated: [K V] = Ac × [W_K W_V]
Traditional Approach:
- Stores: K + V for all layers
- Block size: S_KV (full size)
Capture Approach:
- Stores: Activation checkpoints only
- Block size: S_ACT = ½ × S_KV (50% memory savings!)
- K and V regenerated on-demand during decode
Instead of storing the outputs of the K/V projections (the K and V tensors), Capture stores their input (activation checkpoints). Since K and V can be regenerated via a simple linear transformation (sketched below), this achieves:
- 50% memory reduction per cached block
- Small recomputation cost (overlapped with weight loading from host)
- 2.19× throughput improvement over prior work
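As a rough illustration of this regeneration (a minimal sketch only; the tensor names and shapes below are assumptions, not Capture's actual code), recovering K and V from a cached activation block is just a pair of matrix multiplies with the layer's existing projection weights:

```python
import torch

# Illustrative shapes: one cached block of activations (Ac).
block_size, hidden = 16, 5120                 # 5120 = OPT-13B hidden size
act_block = torch.randn(block_size, hidden)   # cached input activation block

# The layer's K/V projection weights (already needed for the forward pass).
w_k = torch.randn(hidden, hidden)
w_v = torch.randn(hidden, hidden)

# [K V] = Ac x [W_K W_V]: regenerate both tensors for the block on demand.
k = act_block @ w_k
v = act_block @ w_v

# Memory intuition: a KV block holds K and V (2 * hidden values per token),
# while an ACT block holds only Ac (hidden values per token), so S_ACT = 1/2 * S_KV.
```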
Capture uses a mixed strategy:
- Some tokens cached as KV blocks (no recomputation)
- Some tokens cached as ACT blocks (recompute K, V from activations)
- Optimal ratio balances PCIe bandwidth vs GPU computation
- Unified block table tracks both types (see the sketch below)
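For intuition, here is a minimal sketch of what a unified block table with a `kvc_ratio` split could look like (hypothetical names, not Capture's actual data structures):

```python
from dataclasses import dataclass

KV_BLOCK, ACT_BLOCK = 0, 1  # mirrors the 0/1 values the attention kernel sees

@dataclass
class BlockEntry:
    physical_id: int  # slot in the GPU KV cache or activation cache
    kind: int         # KV_BLOCK or ACT_BLOCK

def assign_block_kinds(num_blocks: int, kvc_ratio: float) -> list:
    """Split a sequence's blocks between KV and ACT caching.

    kvc_ratio = 1.0 -> all KV blocks (no recomputation),
    kvc_ratio = 0.0 -> all ACT blocks (maximum memory savings).
    """
    num_kv = round(num_blocks * kvc_ratio)
    return [
        BlockEntry(physical_id=i, kind=KV_BLOCK if i < num_kv else ACT_BLOCK)
        for i in range(num_blocks)
    ]

# Example: 8 blocks at kvc_ratio=0.5 -> 4 KV blocks followed by 4 ACT blocks.
table = assign_block_kinds(8, kvc_ratio=0.5)
```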
Capture extends vLLM's PagedAttention kernel to read from both KV cache AND recomputed activation buffers during decode:
// Original vLLM: Only reads from KV cache
const cache_t* k_cache = ...;
const cache_t* v_cache = ...;
// Capture: Selectively reads from KV or activation cache
const cache_t* tg_k_cache;
const cache_t* tg_v_cache;
if (buf_mapping[block_idx] == 0) {
  // Read from KV cache (traditional)
  tg_k_cache = k_cache;
  tg_v_cache = v_cache;
} else if (buf_mapping[block_idx] == 1) {
  // Read from recomputed activation buffer (Capture's innovation!)
  tg_k_cache = recompute_key_cache;
  tg_v_cache = recompute_value_cache;
}

Why this matters:
- The attention kernel can now use partial recomputations stored in activation cache
- No need to recompute entire layers when activations are already cached
- The `buf_mapping` tensor dynamically routes each block to the appropriate cache (a host-side sketch follows below)
Implementation: See capture/csrc/attention/attention_kernels.cu:209-377
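On the host side, the routing information reduces to one integer per block. The helper below is a hypothetical sketch of how such a `buf_mapping` tensor could be assembled before being handed to the kernel; it is not the actual host-side code:

```python
import torch

KV_BLOCK, ACT_BLOCK = 0, 1

def build_buf_mapping(block_kinds, device="cpu"):
    """One int32 per logical block: 0 -> read the KV cache,
    1 -> read the recomputed activation buffers. The modified kernel
    indexes this tensor with block_idx to pick its source pointers.
    (In practice the tensor would live on the GPU.)"""
    return torch.tensor(block_kinds, dtype=torch.int32, device=device)

# Example: first half of a sequence cached as KV, second half as activations.
buf_mapping = build_buf_mapping([KV_BLOCK] * 4 + [ACT_BLOCK] * 4)
```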
Capture includes a comprehensive benchmarking tool for evaluating performance.
cd capture
python scripts/capture_runner.py \
--model opt-13b \
--dataset sharegpt \
--num-prompts 100 \
--gen-len 128 \
--enable-capture \
--flag-turn-on-icache \
--host-mem-size-GB 128 \
--kvc-ratio 0.5

Supported Datasets:
- `sharegpt` - Real-world ShareGPT conversations (recommended for realistic benchmarks)
- `dummy` - Fixed-length synthetic prompts (for reproducibility)
- `dummy-random-length` - Variable-length synthetic prompts
Key Parameters:
- `--model`: Model name (opt-6.7b, opt-13b, opt-30b, opt-66b, llama-7b, etc.)
- `--dataset`: Dataset to use (sharegpt, dummy, dummy-random-length)
- `--num-prompts`: Number of requests to generate
- `--gen-len`: Generation length per request
- `--enable-capture`: Enable Capture system
- `--flag-turn-on-icache`: Enable activation caching
- `--host-mem-size-GB`: Host memory budget in GB
- `--kvc-ratio`: KV cache ratio (0.0 = all activation, 1.0 = all KV)
For the complete list of parameters, run `python scripts/capture_runner.py --help`.
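If you want to find a good `--kvc-ratio` for your workload empirically, one simple option is to script a sweep around the runner. The snippet below is only a convenience sketch that reuses the exact flags from the example command above:

```python
import subprocess

# Sweep the KV/activation split and compare the reported throughput of each run.
for ratio in (0.0, 0.25, 0.5, 0.75, 1.0):
    subprocess.run(
        [
            "python", "scripts/capture_runner.py",
            "--model", "opt-13b",
            "--dataset", "sharegpt",
            "--num-prompts", "100",
            "--gen-len", "128",
            "--enable-capture",
            "--flag-turn-on-icache",
            "--host-mem-size-GB", "128",
            "--kvc-ratio", str(ratio),
        ],
        check=True,
    )
```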
| Parameter | What it does | When to use |
|---|---|---|
| `enable_capture` | Turn on Capture system | Always set to `True` |
| `flag_turn_on_icache` | Enable activation caching | Your main performance knob |
| `host_mem_size_GB` | CPU memory budget | More = better performance |
| `kvc_ratio` | KV vs activation ratio | Tune based on sequence length |
| `flag_overlap` | Overlap data transfers | For weight offloading |
| `max_num_micro_batches` | Micro-batching degree | Higher = better overlap |
See all 40+ configuration options
- `num_gpu_kv_blocks` - GPU KV cache blocks
- `num_gpu_act_blocks` - GPU activation cache blocks
- `num_host_kv_blocks` - Host KV cache blocks
- `num_host_act_blocks` - Host activation cache blocks
- `max_num_load_tokens` - Max tokens for load operations

- `flag_allocate_by_balance` - Balance-aware block allocation
- `flag_batching_by_balance` - Balance-aware request batching
- `flag_async_load_merging` - Async load merging
- `flag_equal_cache_size` - Equal cache size allocation

- `flag_all_weight_on_host` - All weights on host
- `num_gpu_attn_wgt_rows` - GPU attention weight rows
- `num_gpu_gate_up_wgt_rows` - GPU gate/up weight rows

- `flag_profile` - Enable profiling
- `flag_info` - Print debug info
- `flag_validation` - Validate against baseline
Full list: See Configuration Guide
           ┌───────────────────────────────┐
           │        CaptureScheduler       │
           │   ┌───────────────────────┐   │
           │   │  Micro-batch Manager  │   │
           │   │  Balance-aware Batch  │   │
           │   └───────────────────────┘   │
           └───────────────┬───────────────┘
                           │
        ┌──────────────────┼──────────────────┐
        │                  │                  │
┌───────▼───────┐  ┌───────▼───────┐  ┌───────▼───────┐
│    GPU KV     │  │    GPU Act    │  │  GPU Weight   │
│    Cache      │  │    Cache      │  │    Buffer     │
└───────┬───────┘  └───────┬───────┘  └───────┬───────┘
        │                  │                  │
        │ async transfer   │ async transfer   │
        │                  │                  │
┌───────▼───────┐  ┌───────▼───────┐  ┌───────▼───────┐
│   Host KV     │  │   Host Act    │  │  Host Weight  │
│   Cache       │  │   Cache       │  │   Storage     │
└───────────────┘  └───────────────┘  └───────────────┘
- CaptureScheduler (`core/capture_scheduler.py`) - Micro-batch scheduling with balance-aware batching
- CaptureMemory (`capture_memory.py`) - Unified GPU/host memory management (see the async-transfer sketch below)
- CaptureWorker (`worker/capture_worker.py`) - Worker with cache engine
- CaptureBlockManager (`core/capture_block_manager.py`) - Block-level allocation
- Modified PagedAttention Kernel (`csrc/attention/attention_kernels.cu`) - CUDA kernel that reads from both KV and activation caches
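The "async transfer" arrows in the diagram above correspond to overlapping host-GPU copies with GPU computation. The snippet below is a generic PyTorch sketch of that pattern (pinned host memory plus a separate CUDA stream); it is not CaptureMemory's actual implementation and requires a CUDA-capable GPU:

```python
import torch

# A host-side cache block in pinned memory so the copy can be truly asynchronous.
host_block = torch.empty(16, 5120, pin_memory=True)
gpu_block = torch.empty(16, 5120, device="cuda")

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    # Non-blocking host-to-GPU copy; it can overlap with kernels
    # running on the default stream (e.g. decode for other blocks).
    gpu_block.copy_(host_block, non_blocking=True)

# ... launch decode work on the default stream here ...

# Before consuming gpu_block, make the default stream wait for the copy.
torch.cuda.current_stream().wait_stream(copy_stream)
```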
During decode, the modified PagedAttention kernel uses `buf_mapping` to decide where to read each block's data:
For each block in sequence:
    if buf_mapping[block] == 0:
        └──► Read from KV cache (traditional path)
        └──► Full Key/Value pairs
    else if buf_mapping[block] == 1:
        └──► Read from recomputed activation buffer (Capture path!)
        └──► Cached intermediate activations
This dual-path design is the key innovation:
- Flexibility: Mix KV and activation caching per-block
- Efficiency: Use the best cache strategy for each part of the sequence
Apache 2.0 License - see LICENSE for details.
Based on vLLM (also Apache 2.0).
If you use Capture in your research, please cite:
@inproceedings{lee2025throughput,
title={Throughput-Oriented LLM Inference via KV-Activation Hybrid Caching with a Single GPU},
author={Lee, Sanghyeon and Kim, Hongbeen and Hwang, Soojin and Heo, Guseul and Noh, Minwoo and Huh, Jaehyuk},
booktitle={The 43rd IEEE International Conference on Computer Design (ICCD 2025)},
year={2025},
organization={IEEE}
}

This work was supported by the Institute of Information & Communications Technology Planning & Evaluation (IITP), Ministry of Science and ICT, Korea (RS2021-II211817, RS-2024-00402898).
We also acknowledge the contributions of the following projects: