168 changes: 159 additions & 9 deletions vllm/v1/worker/hpu_model_runner.py
@@ -46,13 +46,33 @@

if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput

from vllm_hpu_extension.bucketing.common import get_bucketing_context

logger = init_logger(__name__)

_TYPE_CACHE = {}


def setup_profiler(warmup, active):
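    """Build a torch.profiler instance for HPU profiling runs.

    The schedule skips `warmup` steps, records `active` steps once
    (repeat=1), and writes gzipped TensorBoard traces covering CPU and
    HPU activities, with stack capture, to the current directory.
    """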
schedule = torch.profiler.schedule(wait=0,
warmup=warmup,
active=active,
repeat=1)
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.HPU
]
profiler = torch.profiler.profile(
schedule=schedule,
activities=activities,
on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
use_gzip=True),
record_shapes=False,
with_stack=True)
return profiler


class PhaseType(Enum):
PREFILL = 'prefill'
PREFIX_PREFILL = 'prefix_prefill'
@@ -1469,6 +1489,7 @@ def _execute_model_generic(self,
logits_indices,
kv_caches,
warmup_mode=False):

# FORWARD.
batch_size = token_ids.size(0)
seq_len = self._seq_len(attn_metadata)
@@ -2028,18 +2049,147 @@ def warmup_graphs(self,

return total_mem, total_batch_seq, captured_all

def _add_dummy_request(self, requests, num_scheduled_tokens,
num_computed_tokens, total_tokens,
scheduled_tokens):
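        """Append a synthetic request used for profiling-only execution.

        The request gets `total_tokens` dummy prompt tokens, a block table
        padded to 128-token blocks (all entries pointing at block 0),
        greedy sampling, and `scheduled_tokens` tokens scheduled for the
        next step under the key `num_scheduled_tokens[req_id]`.
        """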
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import NewRequestData

num_blocks = round_up(total_tokens, 128) // 128
prompt_token_ids = list(range(total_tokens))

req_id = f'req-{len(requests)}'
block_ids = [0] * num_blocks
sampling_params = SamplingParams(temperature=0.0)

req = NewRequestData(
req_id=req_id,
prompt_token_ids=prompt_token_ids,
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
lora_request=None,
)
requests.append(req)
num_scheduled_tokens[req_id] = scheduled_tokens

@staticmethod
def _generate_seq_lengths(num_samples, num_blocks, block_size):
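        """Split `num_blocks` as evenly as possible across `num_samples` sequences.

        Each sequence length is one token short of a whole number of blocks,
        e.g. num_samples=3, num_blocks=8, block_size=128 gives block counts
        [3, 3, 2] and seq_lengths [383, 383, 255].
        """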
assert num_samples <= num_blocks
blocks = [num_blocks // num_samples] * num_samples
missing_blocks = num_blocks - sum(blocks)
for i in range(missing_blocks):
blocks[i] += 1
seq_lengths = [b * block_size - 1 for b in blocks]
return seq_lengths

def _execute_dummy_scenario(self, prompt_cfg, decode_cfg):
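        """Run one dummy scheduler step followed by a cleanup step.

        `prompt_cfg` is (batch_size, query_len, ctx_blocks) and `decode_cfg`
        is (batch_size, num_blocks); either may be None. The first
        execute_model call processes the synthetic requests, the second only
        marks them as finished so their state is released.
        """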
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
requests: list[NewRequestData] = []
scheduled_tokens: dict[str, int] = {}

if prompt_cfg:
prompt_bs, prompt_query_len, prompt_blocks = prompt_cfg
prompt_ctx_len = prompt_blocks * self.block_size
prompt_total_tokens = prompt_query_len + prompt_ctx_len
for _ in range(prompt_bs):
self._add_dummy_request(requests,
scheduled_tokens,
num_computed_tokens=prompt_ctx_len,
total_tokens=prompt_total_tokens,
scheduled_tokens=prompt_query_len)
if decode_cfg:
decode_bs, decode_blocks = decode_cfg
decode_seq_lengths = self._generate_seq_lengths(
decode_bs, decode_blocks, self.block_size)
for dsl in decode_seq_lengths:
self._add_dummy_request(requests,
scheduled_tokens,
num_computed_tokens=dsl,
total_tokens=dsl,
scheduled_tokens=1)
sched_output = SchedulerOutput(
scheduled_new_reqs=requests,
scheduled_cached_reqs=[],
num_scheduled_tokens=scheduled_tokens,
total_num_scheduled_tokens=sum(scheduled_tokens.values()),
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
cleanup = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(req.req_id for req in requests),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
self.execute_model(sched_output)
self.execute_model(cleanup)

def _generate_profiling(self, prompt_cfg, decode_cfg):
steps = 3
profiler = setup_profiler(warmup=steps - 1, active=1)
torch.hpu.synchronize()
profiler.start()
for _ in range(steps):
self._execute_dummy_scenario(prompt_cfg, decode_cfg)
torch.hpu.synchronize()
profiler.step()
profiler.stop()

@staticmethod
def _parse_profile_cfg(profile_cfg):
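        """Parse a comma-separated int list, e.g. '4,1024,0' -> (4, 1024, 0).

        Returns None when the value is unset or empty.
        """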
if profile_cfg:
return tuple(map(int, profile_cfg.split(',')))
return None

@staticmethod
def _parse_legacy_profile_cfg(profile_cfg):
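        """Parse the legacy '<phase>_<bs>_<len>_<graphs>' format.

        For example 'prompt_4_512_t' -> ('prompt', 4, 512, True).
        """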
if profile_cfg:
cfg = profile_cfg.split('_')
assert cfg[0] in ['prompt', 'decode']
return (cfg[0], int(cfg[1]), int(cfg[2]), cfg[3] == 't')
return None

def _read_profiling_cfg(self):
prompt_cfg = self._parse_profile_cfg(
os.environ.get('VLLM_PROFILE_PROMPT', None))
decode_cfg = self._parse_profile_cfg(
os.environ.get('VLLM_PROFILE_DECODE', None))
legacy_cfg = self._parse_legacy_profile_cfg(
os.environ.get('VLLM_PT_PROFILE', None))
if legacy_cfg and not (prompt_cfg or decode_cfg):
phase, bs, seq_or_blocks, use_graphs = legacy_cfg
assert use_graphs != self.model_config.enforce_eager, \
"'use_graphs' is out of sync with model config. " \
"Either change the flag or change vllm engine parameters"
if phase == 'prompt':
prompt_cfg = (bs, seq_or_blocks, 0)
else:
decode_cfg = (bs, seq_or_blocks)
return prompt_cfg, decode_cfg

@torch.inference_mode()
def warmup_model(self) -> None:
kv_caches = self.kv_caches
if profile := os.environ.get('VLLM_PT_PROFILE', None):
phase, bs, seq_len, graph = profile.split('_')
is_prompt = phase == 'prompt'
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
#self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
# True)
prompt_profile_cfg, decode_profile_cfg = self._read_profiling_cfg()
if prompt_profile_cfg or decode_profile_cfg:
self._generate_profiling(prompt_profile_cfg, decode_profile_cfg)
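            # Profiling-only run: abort startup once the trace has been written.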
raise AssertionError("Finished profiling")
kv_caches = self.kv_caches
max_blocks = kv_caches[0][0].size(0)
self.bucketing_ctx.generate_decode_buckets(max_blocks)
