|
1 | | -class HiPCudaGraphRunner: |
2 | | - pass |
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import bisect |
| 4 | +from typing import TYPE_CHECKING, Callable |
| 5 | + |
| 6 | +import torch |
| 7 | +import tqdm |
| 8 | +from vllm.distributed import get_tensor_model_parallel_rank |
| 9 | +from vllm.distributed.parallel_state import graph_capture |
| 10 | + |
| 11 | +from sglang.srt.layers.logits_processor import ( |
| 12 | + LogitsMetadata, |
| 13 | + LogitsProcessor, |
| 14 | + LogitsProcessorOutput, |
| 15 | +) |
| 16 | +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode |
| 17 | +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model, clamp_position |
| 18 | + |
| 19 | +if TYPE_CHECKING: |
| 20 | + from sglang.srt.model_executor.hip_model_runner import HiPModelRunner |
| 21 | + |
| 22 | + |
| 23 | +class HiPCudaGraphRunner(CudaGraphRunner): |
| 24 | + |
| 25 | + def __init__(self, model_runner: "HiPModelRunner"): |
| 26 | + super().__init__(model_runner) |
| 27 | + |
| 28 | + def can_run(self, forward_batch: ForwardBatch): |
| 29 | + use_cached_mask = forward_batch.hip_use_cached_mask |
| 30 | + |
| 31 | + if self.enable_dp_attention: |
| 32 | + min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max( |
| 33 | + forward_batch.global_num_tokens |
| 34 | + ) |
| 35 | + is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( |
| 36 | + (min_num_tokens == max_num_tokens and (max_num_tokens, use_cached_mask) in self.graphs) |
| 37 | + if self.disable_padding |
| 38 | + else max_num_tokens <= self.max_bs |
| 39 | + ) |
| 40 | + else: |
| 41 | + is_bs_supported = ( |
| 42 | + (forward_batch.batch_size, use_cached_mask) in self.graphs |
| 43 | + if self.disable_padding |
| 44 | + else forward_batch.batch_size <= self.max_bs |
| 45 | + ) |
| 46 | + |
| 47 | + # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) |
| 48 | + # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph |
| 49 | + # because the full_text_row_masked_out_mask tensor will always be ones |
| 50 | + is_encoder_lens_supported = ( |
| 51 | + torch.all(forward_batch.encoder_lens > 0) |
| 52 | + if self.is_encoder_decoder |
| 53 | + else True |
| 54 | + ) |
| 55 | + return is_bs_supported and is_encoder_lens_supported |
| 56 | + |
| 57 | + def capture(self): |
| 58 | + with graph_capture() as graph_capture_context: |
| 59 | + self.stream = graph_capture_context.stream |
| 60 | + capture_bs = ( |
| 61 | + tqdm.tqdm(self.capture_bs) |
| 62 | + if get_tensor_model_parallel_rank() == 0 |
| 63 | + else self.capture_bs |
| 64 | + ) |
| 65 | + for bs in capture_bs: |
| 66 | + with patch_model( |
| 67 | + self.model_runner.model, |
| 68 | + bs in self.compile_bs, |
| 69 | + bs, |
| 70 | + self.model_runner.tp_group, |
| 71 | + ) as forward: |
| 72 | + for use_cached_mask in [False, True]: |
| 73 | + ( |
| 74 | + graph, |
| 75 | + output_buffers, |
| 76 | + ) = self.capture_one_batch_size(bs, forward, use_cached_mask) |
| 77 | + self.graphs[(bs, use_cached_mask)] = graph |
| 78 | + self.output_buffers[(bs, use_cached_mask)] = output_buffers |
| 79 | + |
| 80 | + def capture_one_batch_size(self, bs: int, forward: Callable, hip_use_cached_mask: bool = False): |
| 81 | + graph = torch.cuda.CUDAGraph() |
| 82 | + stream = self.stream |
| 83 | + |
| 84 | + # Common inputs |
| 85 | + input_ids = self.input_ids[:bs] |
| 86 | + req_pool_indices = self.req_pool_indices[:bs] |
| 87 | + seq_lens = self.seq_lens[:bs] |
| 88 | + out_cache_loc = self.out_cache_loc[:bs] |
| 89 | + if self.is_encoder_decoder: |
| 90 | + encoder_lens = self.encoder_lens[:bs] |
| 91 | + else: |
| 92 | + encoder_lens = None |
| 93 | + |
| 94 | + seq_lens_sum = seq_lens.sum().item() |
| 95 | + mrope_positions = self.mrope_positions[:, :bs] |
| 96 | + |
| 97 | + if self.enable_dp_attention: |
| 98 | + global_num_tokens = [bs] * self.tp_size |
| 99 | + gathered_buffer = self.gathered_buffer[: bs * self.tp_size] |
| 100 | + else: |
| 101 | + global_num_tokens = None |
| 102 | + gathered_buffer = None |
| 103 | + |
| 104 | + # Attention backend |
| 105 | + self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( |
| 106 | + bs, |
| 107 | + req_pool_indices, |
| 108 | + seq_lens, |
| 109 | + encoder_lens, |
| 110 | + ) |
| 111 | + |
| 112 | + # Run and capture |
| 113 | + def run_once(): |
| 114 | + forward_batch = ForwardBatch( |
| 115 | + forward_mode=ForwardMode.DECODE, |
| 116 | + batch_size=bs, |
| 117 | + input_ids=input_ids, |
| 118 | + req_pool_indices=req_pool_indices, |
| 119 | + seq_lens=seq_lens, |
| 120 | + req_to_token_pool=self.model_runner.req_to_token_pool, |
| 121 | + token_to_kv_pool=self.model_runner.token_to_kv_pool, |
| 122 | + attn_backend=self.model_runner.attn_backend, |
| 123 | + hip_metadata_cache_pool=self.model_runner.hip_metadata_cache_pool, |
| 124 | + hip_use_cached_mask=hip_use_cached_mask, |
| 125 | + out_cache_loc=out_cache_loc, |
| 126 | + seq_lens_sum=seq_lens_sum, |
| 127 | + encoder_lens=encoder_lens, |
| 128 | + return_logprob=False, |
| 129 | + top_logprobs_nums=[0] * bs, |
| 130 | + positions=clamp_position(seq_lens), |
| 131 | + mrope_positions=mrope_positions, |
| 132 | + global_num_tokens=global_num_tokens, |
| 133 | + gathered_buffer=gathered_buffer, |
| 134 | + ) |
| 135 | + logits_output = forward(input_ids, forward_batch.positions, forward_batch) |
| 136 | + return logits_output.next_token_logits |
| 137 | + |
| 138 | + for _ in range(2): |
| 139 | + torch.cuda.synchronize() |
| 140 | + self.model_runner.tp_group.barrier() |
| 141 | + |
| 142 | + run_once() |
| 143 | + |
| 144 | + torch.cuda.synchronize() |
| 145 | + self.model_runner.tp_group.barrier() |
| 146 | + |
| 147 | + torch.cuda.synchronize() |
| 148 | + self.model_runner.tp_group.barrier() |
| 149 | + |
| 150 | + with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): |
| 151 | + out = run_once() |
| 152 | + |
| 153 | + torch.cuda.synchronize() |
| 154 | + self.model_runner.tp_group.barrier() |
| 155 | + |
| 156 | + self.graph_memory_pool = graph.pool() |
| 157 | + return graph, out |
| 158 | + |
| 159 | + def replay(self, forward_batch: ForwardBatch): |
| 160 | + assert forward_batch.out_cache_loc is not None |
| 161 | + raw_bs = forward_batch.batch_size |
| 162 | + |
| 163 | + # Pad |
| 164 | + if self.enable_dp_attention: |
| 165 | + index = bisect.bisect_left( |
| 166 | + self.capture_bs, max(forward_batch.global_num_tokens) |
| 167 | + ) |
| 168 | + else: |
| 169 | + index = bisect.bisect_left(self.capture_bs, raw_bs) |
| 170 | + bs = self.capture_bs[index] |
| 171 | + if bs != raw_bs: |
| 172 | + self.seq_lens.fill_(1) |
| 173 | + self.out_cache_loc.zero_() |
| 174 | + |
| 175 | + # Common inputs |
| 176 | + self.input_ids[:raw_bs].copy_(forward_batch.input_ids) |
| 177 | + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) |
| 178 | + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) |
| 179 | + self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) |
| 180 | + if self.is_encoder_decoder: |
| 181 | + self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) |
| 182 | + if forward_batch.mrope_positions is not None: |
| 183 | + self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) |
| 184 | + |
| 185 | + # Attention backend |
| 186 | + self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( |
| 187 | + bs, |
| 188 | + self.req_pool_indices, |
| 189 | + self.seq_lens, |
| 190 | + forward_batch.seq_lens_sum + (bs - raw_bs), |
| 191 | + self.encoder_lens, |
| 192 | + ) |
| 193 | + |
| 194 | + # Replay |
| 195 | + key = (bs, forward_batch.hip_use_cached_mask) |
| 196 | + self.graphs[key].replay() |
| 197 | + next_token_logits = self.output_buffers[key][:raw_bs] |
| 198 | + |
| 199 | + # Extract logprobs |
| 200 | + if forward_batch.return_logprob: |
| 201 | + logits_metadata = LogitsMetadata( |
| 202 | + forward_mode=ForwardMode.DECODE, |
| 203 | + top_logprobs_nums=forward_batch.top_logprobs_nums, |
| 204 | + ) |
| 205 | + next_token_logprobs = ( |
| 206 | + LogitsProcessor.compute_temp_top_p_normalized_logprobs( |
| 207 | + next_token_logits, logits_metadata |
| 208 | + ) |
| 209 | + ) |
| 210 | + logits_output = LogitsProcessorOutput( |
| 211 | + next_token_logits=next_token_logits, |
| 212 | + next_token_logprobs=next_token_logprobs, |
| 213 | + ) |
| 214 | + return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) |
| 215 | + if return_top_logprob: |
| 216 | + ( |
| 217 | + logits_output.output_top_logprobs_val, |
| 218 | + logits_output.output_top_logprobs_idx, |
| 219 | + ) = LogitsProcessor.get_top_logprobs( |
| 220 | + next_token_logprobs, logits_metadata |
| 221 | + )[ |
| 222 | + 2:4 |
| 223 | + ] |
| 224 | + else: |
| 225 | + logits_output = LogitsProcessorOutput( |
| 226 | + next_token_logits=next_token_logits, |
| 227 | + ) |
| 228 | + |
| 229 | + return logits_output |
0 commit comments