Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def set_forward_context(attn_metadata: Any,
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)

assert current_platform is not None, "current_platform is None" # noqa
if current_platform.is_hpu(): # noqa
num_experts_per_tok = 0
num_experts_per_tok = getattr(
Expand Down
29 changes: 13 additions & 16 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,22 +1064,19 @@ def _attention_with_mask_hpu(
# Skip writing kv-cache for the initial profiling run.
if kv_cache is not None and isinstance(kv_cache, tuple):
assert self.attn.backend == _Backend.HPU_ATTN
- # During cross-attention decode, key & value will be None,
- # we don't need to cache them.
- if (k is not None) and (v is not None):
- from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
- key_cache, value_cache = HPUPagedAttention.split_kv_cache(
- kv_cache, self.num_local_key_value_heads, self.head_dim)
- cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
- cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
- slot_mapping = torch.cat([
- attn_metadata.cross_slot_mapping[s:e]
- for s, e in kv_range_for_decode
- ])
- key_cache = self.attn.impl.k_cache(cached_k, key_cache,
- slot_mapping)
- value_cache = self.attn.impl.v_cache(cached_v, value_cache,
- slot_mapping)
+ from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
+ key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+ kv_cache, self.num_local_key_value_heads, self.head_dim)
+ cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+ cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+ slot_mapping = torch.cat([
+ attn_metadata.cross_slot_mapping[s:e]
+ for s, e in kv_range_for_decode
+ ])
+ key_cache = self.attn.impl.k_cache(cached_k, key_cache,
+ slot_mapping)
+ value_cache = self.attn.impl.v_cache(cached_v, value_cache,
+ slot_mapping)

q_len = q.shape[0]
kv_len = k.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2747,7 +2747,7 @@ def create_dummy_seq_group_metadata(self,
sampling_params = None
else:
sampling_params = SamplingParams(temperature=temperature)
- num_blocks = math.ceil(seq_len / self.block_size)
+ num_blocks = math.ceil(seq_len / self.block_size)
seq_len = max(seq_len, 1)
computed_block_nums = None
if is_prompt:
Expand Down