Merged

54 commits
e92d432  Fix gemma3 workload execution failure (shepark, Apr 24, 2025)
bd75109  add run scripts (ssarkar2, May 21, 2025)
90a3ae1  minor (ssarkar2, May 23, 2025)
9434197  Update sliding_window attention (jiminha, May 28, 2025)
0364786  Update run scripts (ssarkar2, May 28, 2025)
8caf102  Update sliding_window mask logic for lazy mode (jiminha, May 28, 2025)
bff7983  Fix long prompt accuracy issue (jiminha, May 30, 2025)
dccd67e  Change back to Eager mode for Vision prompt (jiminha, May 30, 2025)
45aaede  Remove unnecessary files (jiminha, May 30, 2025)
7df1811  Remove test file (jiminha, May 30, 2025)
e4b0397  Merge branch 'habana_main' into jha/sliding_window_gemma3 (ssarkar2, May 30, 2025)
6088039  Enable bs>1 (ssarkar2, Jun 3, 2025)
8b13980  enable hpu graph model (maktukmak, Jun 4, 2025)
a9e5a7d  Add temporary test scripts (ssarkar2, Jun 4, 2025)
f783955  Fix for missing image (ssarkar2, Jun 5, 2025)
1297154  Bring back +1 (ssarkar2, Jun 5, 2025)
be41114  Switch to lazy+hpugraphs, add v0 mode (ssarkar2, Jun 5, 2025)
74e4cfb  Fix masks. Remove cross attn between images (ssarkar2, Jun 6, 2025)
347e965  Script for variable batches (ssarkar2, Jun 6, 2025)
a29d537  Do vision+combining before text mdoel fwd (ssarkar2, Jun 10, 2025)
c9c5757  wrap vision and projector in hpu graphs (ssarkar2, Jun 10, 2025)
5af6870  vectorized mask generation (maktukmak, Jun 10, 2025)
000b4e0  Revert "wrap vision and projector in hpu graphs" (ssarkar2, Jun 11, 2025)
658442d  Revert "Do vision+combining before text mdoel fwd" (ssarkar2, Jun 11, 2025)
61a3e2f  Fixing the earlier commit which was reverted (ssarkar2, Jun 11, 2025)
39e0f52  bring back reverted commit (ssarkar2, Jun 11, 2025)
661f59a  Fix accuracy issue with repeat words for long prompts (jiminha, Jun 21, 2025)
affc7a7  Change parameter check for intereleaved sliding_window (jiminha, Jun 23, 2025)
994de89  Remove all test files (jiminha, Jun 23, 2025)
ad492f8  Merge remote-tracking branch 'origin/habana_main' into jha/sliding_wi… (jiminha, Jun 23, 2025)
1ceca57  Merge branch 'habana_main' into jha/sliding_window_gemma3 (jiminha, Jun 23, 2025)
805df55  Fix error from merge (jiminha, Jun 23, 2025)
d531412  Fix pre-commit errors (jiminha, Jun 23, 2025)
f99d76a  Pre-commit fix for the list warning (jiminha, Jun 24, 2025)
9e61d32  pre-commit error fix (jiminha, Jun 24, 2025)
7dd6513  Pre-commit error fix (jiminha, Jun 24, 2025)
f102fac  Move prompt attn_mask generation to model runner (jiminha, Jun 24, 2025)
275ec1e  Only build sliding_window_mask for text only input for gemma (jiminha, Jun 25, 2025)
023524d  pre-commit fix (jiminha, Jun 25, 2025)
a6f7647  Combine mask_update for both image/text input for gemma3 (jiminha, Jun 26, 2025)
8a41b21  Update gemma model files for Attn (jiminha, Jun 30, 2025)
05f219d  Merge remote-tracking branch 'remotes/origin/habana_main' into jha/sl… (jiminha, Jun 30, 2025)
b7d28f8  Add Unittest for Gemma3 (jiminha, Jul 1, 2025)
5aac21e  Remove duplicated gemma test case (jiminha, Jul 1, 2025)
25f6c99  Updated based on the review comment (jiminha, Jul 1, 2025)
6ced16f  Merge remote-tracking branch 'remotes/origin/habana_main' into jha/sl… (jiminha, Jul 1, 2025)
ad99447  Precommit error fix (jiminha, Jul 1, 2025)
d1a6468  Fix test failures (jiminha, Jul 1, 2025)
5d0ff5b  add split qkv to gemma3 (skaulintel, Jul 2, 2025)
e693b6a  fix some precommit (skaulintel, Jul 3, 2025)
7661c2a  add to readme (skaulintel, Jul 3, 2025)
f9b2937  Merge branch 'habana_main' into dev/skaul_gemma_splitqkv (skaulintel, Jul 7, 2025)
e6fa5ed  remove print statements (skaulintel, Jul 7, 2025)
aeb0b09  Merge branch 'habana_main' into dev/skaul_gemma_splitqkv (libinta, Jul 8, 2025)
2 changes: 1 addition & 1 deletion README_GAUDI.md
@@ -599,7 +599,7 @@ Please refer to this [collection](https://github.com/HabanaAI/Gaudi-tutorials/tr

 ## Split QKV projection

-This is an experimental performance optimization implemented for selected models: LLama, Mixtral, Granite and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode.
+This is an experimental performance optimization implemented for selected models: LLama, Mixtral, Granite, Gemma3 and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode.

 > [!NOTE]
 > Splitting QKV projection can also degrade the performance for cases with low compute, i.e. low batch size, short sequence lengths or using tensor parallelism. It should always be verified in a particular scenario using a profiling tool such as [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer) or by analyzing execution traces to ensure optimal performance.
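For concreteness, a minimal offline-mode sketch of the switch described above. This assumes a HabanaAI vllm-fork build where `split_qkv=True` is forwarded to the engine arguments as the README states; the model name and generation settings are placeholders, not a tested configuration.

```python
# Minimal sketch: enabling split QKV projection in offline mode.
# Assumes the fork accepts split_qkv=True as an engine argument;
# the model name below is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-3-4b-it", split_qkv=True)
out = llm.generate(["The capital of France is"],
                   SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```

Per the paragraph above, the online-mode equivalent would be passing `--split-qkv` when launching the API server.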
@@ -479,7 +479,7 @@ Enabling of Multi-Step Scheduling is recommended for better decode performance.

 ## Split QKV projection

-This is an experimental performance optimization implemented for selected models: LLama, Mixtral, Granite and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode.
+This is an experimental performance optimization implemented for selected models: LLama, Mixtral, Granite, Gemma3 and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode.

 > [!NOTE]
 > Splitting QKV projection can also degrade the performance for cases with low compute, i.e. low batch size, short sequence lengths or using tensor parallelism. It should always be verified in a particular scenario using a profiling tool such as [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer) or by analyzing execution traces to ensure optimal performance.
77 changes: 55 additions & 22 deletions vllm/model_executor/models/gemma3.py
@@ -32,7 +32,8 @@
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               RowParallelLinear)
+                                               RowParallelLinear,
+                                               SplitQKVParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -126,16 +127,28 @@ def __init__(self,
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = config.query_pre_attn_scalar**-0.5
-
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=config.attention_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
+        self.split_qkv = cache_config.split_qkv
+
+        if self.split_qkv:
+            self.qkv_proj = SplitQKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=config.attention_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=config.attention_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
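Conceptually, the split path above swaps one fused matmul for three independent projections. A minimal standalone sketch of that idea follows; it is not the actual `SplitQKVParallelLinear` (which also handles tensor parallelism and quantization in vllm/model_executor/layers/linear.py), only the interface shape that `forward` relies on.

```python
# Conceptual sketch only, not the vLLM implementation. The trailing None
# mirrors the (output, bias) return convention behind "q, k, v, _ = ...".
import torch
import torch.nn as nn

class SplitQKVSketch(nn.Module):
    def __init__(self, hidden_size: int, q_size: int, kv_size: int):
        super().__init__()
        # Three separate matmuls instead of one fused QKV projection,
        # so the scheduler can pipeline them against other work.
        self.q_proj = nn.Linear(hidden_size, q_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, kv_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, kv_size, bias=False)

    def forward(self, x: torch.Tensor):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x), None
```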
@@ -189,8 +202,12 @@ def forward(
         hidden_states: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.split_qkv:
+            q, k, v, _ = self.qkv_proj(hidden_states)
+        else:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)

         q = q.unflatten(-1, (self.num_heads, self.head_dim))
         q = self.q_norm(q)
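To make the fused branch concrete, here is a small shape sketch of the `split` call; the dimensions are made up for illustration, since Gemma3's real sizes come from its config.

```python
# Shape sketch of the fused-QKV split above; all dimensions are made up.
import torch

num_heads, num_kv_heads, head_dim = 8, 4, 64
q_size = num_heads * head_dim        # 512
kv_size = num_kv_heads * head_dim    # 256

# A fused projection emits one tensor holding Q, K and V side by side.
qkv = torch.randn(2, 10, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 10, 512]) torch.Size([2, 10, 256]) torch.Size([2, 10, 256])
```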
@@ -379,6 +396,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        self.split_qkv = cache_config.split_qkv

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         # NOTE(woosuk): Only apply the normalizer to the output of
@@ -420,14 +438,25 @@ def forward(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
+        if not self.split_qkv:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
+        else:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj.q_proj", "q_proj", "q"),
+                ("qkv_proj.k_proj", "k_proj", "k"),
+                ("qkv_proj.v_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
@@ -452,7 +481,11 @@ def load_weights(self, weights: Iterable[tuple[str,
                 continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                if self.split_qkv and (shard_id == "q" or shard_id == "v"
+                                       or shard_id == "k"):
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
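The two mappings differ only in where a checkpoint's q/k/v tensors land, which is why the hunk above drops `shard_id` for them when `split_qkv` is set. A small sketch with a hypothetical checkpoint name shows the rewrite that happens before the `params_dict` lookup:

```python
# Hypothetical checkpoint name; shows how each mapping rewrites it.
name = "model.layers.0.self_attn.q_proj.weight"

# Fused: q/k/v all fill slices of one merged qkv_proj parameter, so the
# loader must be told which shard ("q", "k" or "v") it is writing.
fused = name.replace("q_proj", "qkv_proj")
# -> "model.layers.0.self_attn.qkv_proj.weight"

# Split: each projection keeps its own sub-parameter under qkv_proj, so
# the weight is loaded whole and no shard_id is passed.
split = name.replace("q_proj", "qkv_proj.q_proj")
# -> "model.layers.0.self_attn.qkv_proj.q_proj.weight"
```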