def setup_profiler(warmup, active, trace_dir='.'):
    """Build a torch profiler that records CPU and HPU activity.

    The profiler runs a single cycle (``repeat=1``) with no initial wait:
    ``warmup`` warm-up steps followed by ``active`` recorded steps. When the
    cycle completes, the trace is written as a gzip-compressed TensorBoard
    trace.

    Args:
        warmup: Number of warm-up steps before recording starts.
        active: Number of steps that are actually recorded.
        trace_dir: Directory that receives the trace files. Defaults to the
            current working directory, which was previously hard-coded.

    Returns:
        A ``torch.profiler.profile`` instance that has not been started.
    """
    schedule = torch.profiler.schedule(wait=0,
                                       warmup=warmup,
                                       active=active,
                                       repeat=1)
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.HPU,
    ]
    trace_handler = torch.profiler.tensorboard_trace_handler(trace_dir,
                                                             use_gzip=True)
    return torch.profiler.profile(schedule=schedule,
                                  activities=activities,
                                  on_trace_ready=trace_handler,
                                  record_shapes=False,
                                  with_stack=True)
def _add_dummy_request(self, requests, num_scheduled_tokens,
                       num_computed_tokens, total_tokens,
                       scheduled_tokens):
    """Append one synthetic request to a dummy scheduling batch.

    Args:
        requests: List that receives the new ``NewRequestData``.
        num_scheduled_tokens: Dict mapping request id to the number of
            tokens scheduled for it; updated in place.
        num_computed_tokens: Number of tokens of the request that are
            treated as already computed.
        total_tokens: Length of the synthetic prompt.
        scheduled_tokens: Number of tokens scheduled for this request.
    """
    from vllm.sampling_params import SamplingParams
    from vllm.v1.core.sched.output import NewRequestData

    # Size the block table with the runner's block size (ceil division) so
    # it agrees with the callers, which derive lengths from self.block_size.
    block_size = self.block_size
    num_blocks = (total_tokens + block_size - 1) // block_size

    # Request ids are positional, so they are unique within one batch.
    req_id = f'req-{len(requests)}'
    sampling_params = SamplingParams(temperature=0.0)

    req = NewRequestData(
        req_id=req_id,
        prompt_token_ids=list(range(total_tokens)),
        mm_inputs=[],
        mm_hashes=[],
        mm_positions=[],
        sampling_params=sampling_params,
        # Every dummy block points at block 0.
        block_ids=[0] * num_blocks,
        num_computed_tokens=num_computed_tokens,
        lora_request=None,
    )
    requests.append(req)
    num_scheduled_tokens[req_id] = scheduled_tokens


@staticmethod
def _generate_seq_lengths(num_samples, num_blocks, block_size):
    """Split ``num_blocks`` as evenly as possible across ``num_samples``.

    The first ``num_blocks % num_samples`` samples receive one extra block.
    Each returned length is one token short of filling its blocks
    (``blocks * block_size - 1``).

    Raises:
        ValueError: If ``num_samples`` is not in ``1..num_blocks``.
    """
    # Explicit check instead of ``assert``: it survives ``python -O`` and
    # also rejects zero samples, which used to raise ZeroDivisionError.
    if not 0 < num_samples <= num_blocks:
        raise ValueError(
            f"num_samples must be in 1..num_blocks, got "
            f"num_samples={num_samples}, num_blocks={num_blocks}")
    base, extra = divmod(num_blocks, num_samples)
    blocks = [base + 1] * extra + [base] * (num_samples - extra)
    return [b * block_size - 1 for b in blocks]


def _execute_dummy_scenario(self, prompt_cfg, decode_cfg):
    """Run one synthetic scheduling step, then release its requests.

    Args:
        prompt_cfg: Falsy, or ``(batch size, query length, context blocks)``
            describing the dummy prompt requests.
        decode_cfg: Falsy, or ``(batch size, total blocks)`` describing the
            dummy decode requests.
    """
    from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
    requests: list[NewRequestData] = []
    scheduled_tokens: dict[str, int] = {}

    if prompt_cfg:
        prompt_bs, prompt_query_len, prompt_blocks = prompt_cfg
        # The context part is marked as already computed; only the query
        # tokens are scheduled.
        prompt_ctx_len = prompt_blocks * self.block_size
        prompt_total_tokens = prompt_query_len + prompt_ctx_len
        for _ in range(prompt_bs):
            self._add_dummy_request(requests,
                                    scheduled_tokens,
                                    num_computed_tokens=prompt_ctx_len,
                                    total_tokens=prompt_total_tokens,
                                    scheduled_tokens=prompt_query_len)
    if decode_cfg:
        decode_bs, decode_blocks = decode_cfg
        decode_seq_lengths = self._generate_seq_lengths(
            decode_bs, decode_blocks, self.block_size)
        # Each decode request schedules exactly one token.
        for dsl in decode_seq_lengths:
            self._add_dummy_request(requests,
                                    scheduled_tokens,
                                    num_computed_tokens=dsl,
                                    total_tokens=dsl,
                                    scheduled_tokens=1)
    sched_output = SchedulerOutput(
        scheduled_new_reqs=requests,
        scheduled_cached_reqs=[],
        num_scheduled_tokens=scheduled_tokens,
        total_num_scheduled_tokens=sum(scheduled_tokens.values()),
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    # A second, empty step that reports every dummy request as finished.
    cleanup = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req.req_id for req in requests},
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    self.execute_model(sched_output)
    self.execute_model(cleanup)


def _generate_profiling(self, prompt_cfg, decode_cfg):
    """Profile the dummy scenario: two warm-up steps, one recorded step."""
    steps = 3
    profiler = setup_profiler(warmup=steps - 1, active=1)
    torch.hpu.synchronize()
    # The context manager starts the profiler and guarantees it is stopped
    # even when a step raises.
    with profiler:
        for _ in range(steps):
            self._execute_dummy_scenario(prompt_cfg, decode_cfg)
            # Synchronize before advancing the profiler schedule.
            torch.hpu.synchronize()
            profiler.step()


@staticmethod
def _parse_profile_cfg(profile_cfg):
    """Parse a comma-separated list of integers into a tuple.

    Returns ``None`` when ``profile_cfg`` is ``None`` or empty.
    """
    if not profile_cfg:
        return None
    return tuple(int(field) for field in profile_cfg.split(','))


@staticmethod
def _parse_legacy_profile_cfg(profile_cfg):
    """Parse a legacy ``<phase>_<bs>_<size>_<graphs>`` profiling string.

    ``phase`` must be ``prompt`` or ``decode``; the graphs flag is true only
    when the fourth field is exactly ``t``. Fields after the fourth are
    ignored.

    Returns:
        ``(phase, bs, size, use_graphs)``, or ``None`` when ``profile_cfg``
        is ``None`` or empty.

    Raises:
        ValueError: If fewer than four fields are present, the phase is
            unknown, or a numeric field is not an integer.
    """
    if not profile_cfg:
        return None
    fields = profile_cfg.split('_')
    # Explicit validation instead of ``assert`` (stripped under ``-O``) and
    # instead of a bare IndexError on short input.
    if len(fields) < 4 or fields[0] not in ('prompt', 'decode'):
        raise ValueError(
            f"Invalid legacy profile config {profile_cfg!r}: expected "
            "'<prompt|decode>_<bs>_<size>_<t|f>'")
    phase, bs, size, graph_flag = fields[:4]
    return (phase, int(bs), int(size), graph_flag == 't')


def _read_profiling_cfg(self):
    """Read the profiling scenario from environment variables.

    ``VLLM_PROFILE_PROMPT`` and ``VLLM_PROFILE_DECODE`` take precedence;
    the legacy ``VLLM_PT_PROFILE`` is used only when neither yields a
    configuration.

    Returns:
        ``(prompt_cfg, decode_cfg)``; each is a tuple of ints or ``None``.

    Raises:
        ValueError: If the legacy graphs flag contradicts
            ``self.model_config.enforce_eager``.
    """
    prompt_cfg = self._parse_profile_cfg(
        os.environ.get('VLLM_PROFILE_PROMPT'))
    decode_cfg = self._parse_profile_cfg(
        os.environ.get('VLLM_PROFILE_DECODE'))
    legacy_cfg = self._parse_legacy_profile_cfg(
        os.environ.get('VLLM_PT_PROFILE'))
    if legacy_cfg and not (prompt_cfg or decode_cfg):
        phase, bs, seq_or_blocks, use_graphs = legacy_cfg
        # Raised explicitly so the check is not stripped under ``-O``.
        if use_graphs == self.model_config.enforce_eager:
            raise ValueError(
                "'use_graphs' is out of sync with model config. "
                "Either change the flag or change vllm engine parameters")
        if phase == 'prompt':
            # Legacy prompt profiling has no prefix context blocks.
            prompt_cfg = (bs, seq_or_blocks, 0)
        else:
            decode_cfg = (bs, seq_or_blocks)
    return prompt_cfg, decode_cfg