
Commit 422415b

Merge pull request #3 from gmlwns2000/hip12-offload-add-hip

update

2 parents 68a3150 + 5335cf2 commit 422415b

13 files changed: +327 -61 lines changed

python/sglang/srt/layers/attention/hip_attention/hip_config.py

Lines changed: 27 additions & 0 deletions
@@ -43,17 +43,24 @@ def __post_init__(self, parsed_json: dict | None):
         if parsed_json is not None:
             if 'second_stage_k' in parsed_json:
                 self.second_stage_k = parsed_json['second_stage_k']
+                parsed_json.pop('second_stage_k')
             if 'sliding_window_size' in parsed_json:
                 self.sliding_window_size = parsed_json['sliding_window_size']
+                parsed_json.pop('sliding_window_size')
             if 'sink_token_size' in parsed_json:
                 self.sink_token_size = parsed_json['sink_token_size']
+                parsed_json.pop('sink_token_size')
             if 'sa_extend_backend' in parsed_json:
                 self.sa_extend_backend = parsed_json['sa_extend_backend']
+                parsed_json.pop('sa_extend_backend')
             if 'stages' in parsed_json:
                 self.stages = [
                     ScanStage(**stage)
                     for stage in parsed_json['stages']
                 ]
+                parsed_json.pop('stages')
+            if parsed_json:
+                raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
 
 
 @dataclass
@@ -65,6 +72,8 @@ class HiPAttentionConfig:
     force_dense: bool = False
     prefill_dense_threshold: int = 8192
     block_sparse_block_size_q: int = 64
+    metadata_cache_max_batch_size: int = 256
+    mask_refresh_interval: int = 4
     layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
         HiPAttentionPerLayerConfig(parsed_json={'second_stage_k': 4096, 'sliding_window_size': 8192, 'sink_token_size': 8192}),
         HiPAttentionPerLayerConfig(),
@@ -77,18 +86,36 @@ def __post_init__(self, parsed_json: dict | None):
         if parsed_json is not None:
             if 'apply_v_dot' in parsed_json:
                 self.apply_v_dot = parsed_json['apply_v_dot']
+                parsed_json.pop('apply_v_dot')
             if 'dense_layers' in parsed_json:
                 self.dense_layers = parsed_json['dense_layers']
+                parsed_json.pop('dense_layers')
             if 'prefill_always_dense' in parsed_json:
                 self.prefill_always_dense = parsed_json['prefill_always_dense']
+                parsed_json.pop('prefill_always_dense')
             if 'decode_always_dense' in parsed_json:
                 self.decode_always_dense = parsed_json['decode_always_dense']
+                parsed_json.pop('decode_always_dense')
             if 'force_dense' in parsed_json:
                 self.force_dense = parsed_json['force_dense']
+                parsed_json.pop('force_dense')
             if 'prefill_dense_threshold' in parsed_json:
                 self.prefill_dense_threshold = parsed_json['prefill_dense_threshold']
+                parsed_json.pop('prefill_dense_threshold')
+            if 'block_sparse_block_size_q' in parsed_json:
+                self.block_sparse_block_size_q = parsed_json['block_sparse_block_size_q']
+                parsed_json.pop('block_sparse_block_size_q')
+            if 'metadata_cache_max_batch_size' in parsed_json:
+                self.metadata_cache_max_batch_size = parsed_json['metadata_cache_max_batch_size']
+                parsed_json.pop('metadata_cache_max_batch_size')
+            if 'mask_refresh_interval' in parsed_json:
+                self.mask_refresh_interval = parsed_json['mask_refresh_interval']
+                parsed_json.pop('mask_refresh_interval')
             if 'layers' in parsed_json:
                 self.layers = [
                     HiPAttentionPerLayerConfig(parsed_json=layer)
                     for layer in parsed_json['layers']
                 ]
+                parsed_json.pop('layers')
+            if parsed_json:
+                raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
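
The pattern this diff applies is worth noting: each recognized key is consumed with pop() as it is applied, so whatever remains in parsed_json afterwards must be unknown, and the constructor fails loudly instead of silently ignoring typos. A minimal standalone sketch of the same pop-and-validate idea (illustrative names only, not code from this repo):

from dataclasses import dataclass

@dataclass
class ExampleConfig:
    second_stage_k: int = 2048
    sliding_window_size: int = 1024

    @classmethod
    def from_json(cls, parsed_json: dict) -> "ExampleConfig":
        cfg = cls()
        # Consume every key we understand ...
        for key in ("second_stage_k", "sliding_window_size"):
            if key in parsed_json:
                setattr(cfg, key, parsed_json.pop(key))
        # ... so anything left over is a typo or an unsupported option.
        if parsed_json:
            raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")
        return cfg

# ExampleConfig.from_json({"second_stage_k": 4096})   -> ok
# ExampleConfig.from_json({"second_stge_k": 4096})    -> ValueError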
Lines changed: 229 additions & 2 deletions
@@ -1,2 +1,229 @@
-class HiPCudaGraphRunner:
-    pass
+from __future__ import annotations
+
+import bisect
+from typing import TYPE_CHECKING, Callable
+
+import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.srt.layers.logits_processor import (
+    LogitsMetadata,
+    LogitsProcessor,
+    LogitsProcessorOutput,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model, clamp_position
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.hip_model_runner import HiPModelRunner
+
+
+class HiPCudaGraphRunner(CudaGraphRunner):
+
+    def __init__(self, model_runner: "HiPModelRunner"):
+        super().__init__(model_runner)
+
+    def can_run(self, forward_batch: ForwardBatch):
+        use_cached_mask = forward_batch.hip_use_cached_mask
+
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and (max_num_tokens, use_cached_mask) in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                (forward_batch.batch_size, use_cached_mask) in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )
+
+        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
+        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
+        # because the full_text_row_masked_out_mask tensor will always be ones
+        is_encoder_lens_supported = (
+            torch.all(forward_batch.encoder_lens > 0)
+            if self.is_encoder_decoder
+            else True
+        )
+        return is_bs_supported and is_encoder_lens_supported
+
+    def capture(self):
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    for use_cached_mask in [False, True]:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward, use_cached_mask)
+                        self.graphs[(bs, use_cached_mask)] = graph
+                        self.output_buffers[(bs, use_cached_mask)] = output_buffers
+
+    def capture_one_batch_size(self, bs: int, forward: Callable, hip_use_cached_mask: bool = False):
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+        if self.is_encoder_decoder:
+            encoder_lens = self.encoder_lens[:bs]
+        else:
+            encoder_lens = None
+
+        seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]
+
+        if self.enable_dp_attention:
+            global_num_tokens = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+        )
+
+        # Run and capture
+        def run_once():
+            forward_batch = ForwardBatch(
+                forward_mode=ForwardMode.DECODE,
+                batch_size=bs,
+                input_ids=input_ids,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
+                hip_metadata_cache_pool=self.model_runner.hip_metadata_cache_pool,
+                hip_use_cached_mask=hip_use_cached_mask,
+                out_cache_loc=out_cache_loc,
+                seq_lens_sum=seq_lens_sum,
+                encoder_lens=encoder_lens,
+                return_logprob=False,
+                top_logprobs_nums=[0] * bs,
+                positions=clamp_position(seq_lens),
+                mrope_positions=mrope_positions,
+                global_num_tokens=global_num_tokens,
+                gathered_buffer=gathered_buffer,
+            )
+            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
+            return logits_output.next_token_logits
+
+        for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
+            run_once()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        self.graph_memory_pool = graph.pool()
+        return graph, out
+
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
+
+        # Pad
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
+        if bs != raw_bs:
+            self.seq_lens.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        if self.is_encoder_decoder:
+            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs,
+            self.req_pool_indices,
+            self.seq_lens,
+            forward_batch.seq_lens_sum + (bs - raw_bs),
+            self.encoder_lens,
+        )
+
+        # Replay
+        key = (bs, forward_batch.hip_use_cached_mask)
+        self.graphs[key].replay()
+        next_token_logits = self.output_buffers[key][:raw_bs]
+
+        # Extract logprobs
+        if forward_batch.return_logprob:
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
+            )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
+            )
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+            if return_top_logprob:
+                (
+                    logits_output.output_top_logprobs_val,
+                    logits_output.output_top_logprobs_idx,
+                ) = LogitsProcessor.get_top_logprobs(
+                    next_token_logprobs, logits_metadata
+                )[
+                    2:4
+                ]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )
+
+        return logits_output
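
For orientation, a hedged standalone sketch of the capture pattern above: one CUDA graph is recorded per (batch_size, use_cached_mask) key, and replay selects the variant with the same key. Everything below is a hypothetical example stripped of sglang specifics; it only shows the per-flag graph bookkeeping and needs a CUDA-capable torch build.

import torch

def step(x: torch.Tensor, use_cached_mask: bool) -> torch.Tensor:
    # Stand-in for the decode forward; the flag changes the kernel
    # sequence, so each flag value needs its own recorded graph.
    return x * 2 if use_cached_mask else x + 1

bs = 8
static_x = torch.zeros(bs, 16, device="cuda")
graphs, outputs = {}, {}

# Warm up on a side stream before capture, as torch's CUDA graph docs advise.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for flag in (False, True):
        step(static_x, flag)
torch.cuda.current_stream().wait_stream(s)

for flag in (False, True):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        outputs[(bs, flag)] = step(static_x, flag)
    graphs[(bs, flag)] = g

# Replay: refill the static input buffer, then replay the keyed graph.
static_x.copy_(torch.randn(bs, 16, device="cuda"))
graphs[(bs, True)].replay()
result = outputs[(bs, True)]  # updated in place by the replay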

python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py

Lines changed: 12 additions & 25 deletions
@@ -35,8 +35,6 @@
     from sglang.srt.model_executor.hip_model_runner import HiPModelRunner
     from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig
 
-# from hip.models.hip_attention.attention2_draft_sampling_extend import dual_stage_quadratic_hip_attention
-# from hip import HiPAttentionArgs
 from hip.models.hip_attention.gen3.attention_extend import dual_stage_quadratic_hip_attention
 from hip.models.hip_attention.gen3.attention_metadata import HiPAttentionArgs
 from hip.models.hip_attention.gen3.uvm_gpu_cache import HiPOffloadCache
@@ -260,7 +258,7 @@ def forward_extend(
         )
 
         logger.debug(f'HiP attention is used in prompting (layer {layer.layer_id})!', stacklevel=0)
-
+
         is_offload_cache = isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool)
 
         if is_offload_cache:
@@ -269,10 +267,10 @@ def forward_extend(
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
                     async_copy=False
                 )
             k_cache = v_cache = None
@@ -337,7 +335,6 @@ def forward_decode(
             else forward_batch.encoder_out_cache_loc
         )
 
-
        require_dense = (
            layer.layer_id in self.hip_config.dense_layers or
            self.hip_config.decode_always_dense or
@@ -364,9 +361,10 @@ def forward_decode(
            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            offload_cache = None
 
-        metadata = forward_batch.hip_metadata_cache_pool.get_hip_metadata_cache(
-            layer.layer_id, q.shape[0], forward_batch.batch_size)
-        #metadata = None
+        metadata = None
+        if forward_batch.hip_use_cached_mask:
+            metadata = forward_batch.hip_metadata_cache_pool.get_hip_metadata_cache(
+                layer.layer_id, q.shape[0], forward_batch.batch_size)
 
        o, metadata = self.forward_paged_hip(
            query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -387,19 +385,8 @@ def forward_decode(
            is_dense=require_dense,
        )
 
-        #print("q shape", q.shape, layer.tp_q_head_num, layer.head_dim)
-        #print("k_cache shape", k_cache.shape)
-        #print("v_cache shape", v_cache.shape)
-        #print("positions", forward_batch.positions)
-        #print("seq_lens", forward_batch.seq_lens)
-        #print("metadata")
-        #print("indices", metadata.indices.shape)
-        #print("ks", metadata.ks.shape)
-        #print("ks_count", metadata.ks_count.shape)
-        #print("ks_start_end", metadata.ks_start_end.shape)
-
        forward_batch.hip_metadata_cache_pool.set_hip_metadata_cache(
-            layer.layer_id, q.shape[0], forward_batch.batch_size, cache_loc, metadata)
+            layer.layer_id, q.shape[0], forward_batch.batch_size, metadata)
 
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
@@ -436,7 +423,7 @@ def forward_paged_hip(
        dst_seq_len = N // batch_size
 
        query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
-
+
        if k_cache is not None:
            N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape
            assert v_cache.shape == k_cache.shape
@@ -471,7 +458,7 @@ def forward_paged_hip(
            rope_sin=layer.rope_sin,
 
            logit_softcap=layer.logit_cap,
-
+
            second_stage_k=layer_config.second_stage_k,
            stages=layer_config.stages,
            model_context_length=layer.orig_context_len,
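
The decode path above now recomputes the HiP mask only when forward_batch.hip_use_cached_mask is false, and otherwise reuses the metadata written back by set_hip_metadata_cache. Together with the new mask_refresh_interval config (default 4), the intended cadence is presumably: refresh the mask on one decode step, then replay the cached-mask graph for the next few. A hypothetical driver loop illustrating that cadence, not code from this diff:

def decode_schedule(num_steps: int, mask_refresh_interval: int = 4):
    # Step 0 refreshes the sparse mask; the following
    # (mask_refresh_interval - 1) steps reuse the cached metadata.
    for step in range(num_steps):
        use_cached_mask = (step % mask_refresh_interval) != 0
        yield step, use_cached_mask

for step, cached in decode_schedule(8):
    print(step, "reuse cached mask" if cached else "recompute mask")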
