diff --git a/README_GAUDI.md b/README_GAUDI.md index 9a2c9c2190a9..a8d29b3843aa 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -599,7 +599,7 @@ Please refer to this [collection](https://github.com/HabanaAI/Gaudi-tutorials/tr ## Split QKV projection -This is an experimental performance optimization implemented for selected models: LLama, Mixtral, Granite and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode. +This is an experimental performance optimization implemented for selected models: Llama, Mixtral, Granite, Gemma3, and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode. > [!NOTE] > Splitting QKV projection can also degrade the performance for cases with low compute, i.e. 
low batch size, short sequence lengths or using tensor parallelism. It should always be verified in a particular scenario using a profiling tool such as [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer) or by analyzing execution traces to ensure optimal performance. diff --git a/docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md b/docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md index 6cab0b80e6ad..313c743b9074 100644 --- a/docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md +++ b/docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md @@ -479,7 +479,7 @@ Enabling of Multi-Step Scheduling is recommended for better decode performance. ## Split QKV projection -This is an experimental performance optimization implemented for selected models: LLama, Mixtral, Granite and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode. +This is an experimental performance optimization implemented for selected models: Llama, Mixtral, Granite, Gemma3, and GPTBigCode. It allows splitting the QKV projection into three separate operations - Q, K, and V projections. This approach is particularly beneficial in scenarios where models have high compute requirements, as it enables better pipelining of workloads between MME's and TPC's engines. 
For example, models with large batch sizes or long sequence lengths can see improved throughput due to reduced contention on compute resources. More information can be found in the [Gaudi Architecture](https://docs.habana.ai/en/v1.20.1/Gaudi_Overview/Gaudi_Architecture.html) page. To apply this optimization, use the `--split-qkv` argument for online mode or set `split_qkv=True` in offline mode. > [!NOTE] > Splitting QKV projection can also degrade the performance for cases with low compute, i.e. low batch size, short sequence lengths or using tensor parallelism. It should always be verified in a particular scenario using a profiling tool such as [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer) or by analyzing execution traces to ensure optimal performance. diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index f3ccf58d64ee..9edacbabf364 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -32,7 +32,8 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - RowParallelLinear) + RowParallelLinear, + SplitQKVParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -126,16 +127,28 @@ def __init__(self, self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = config.query_pre_attn_scalar**-0.5 - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=config.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) + self.split_qkv = cache_config.split_qkv + + if self.split_qkv: + self.qkv_proj = SplitQKVParallelLinear( + hidden_size, + self.head_dim, + 
self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, @@ -189,8 +202,12 @@ def forward( hidden_states: torch.Tensor, **kwargs, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.split_qkv: + q, k, v, _ = self.qkv_proj(hidden_states) + else: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) q = q.unflatten(-1, (self.num_heads, self.head_dim)) q = self.q_norm(q) @@ -379,6 +396,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + self.split_qkv = cache_config.split_qkv def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: # NOTE(woosuk): Only apply the normalizer to the output of @@ -420,14 +438,25 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] + if not self.split_qkv: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + else: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + 
("qkv_proj.q_proj", "q_proj", "q"), + ("qkv_proj.k_proj", "k_proj", "k"), + ("qkv_proj.v_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -452,7 +481,11 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + if self.split_qkv and (shard_id == "q" or shard_id == "v" + or shard_id == "k"): + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models.