diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index d2a2f0304e0..3efb8623b78 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -43,6 +43,10 @@ _apply_rotary_emb,
 )
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.models.dits.utils import (
+    delete_projection_layers,
+    fuse_linear_projections,
+)
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
     current_platform,
 )
@@ -169,51 +173,20 @@ def fuse_projections(self):
         if self.fused_projections:
             return
 
-        device = self.to_q.weight.data.device
-        dtype = self.to_q.weight.data.dtype
-
-        concatenated_weights = torch.cat(
-            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+        self.to_qkv = fuse_linear_projections(
+            self.to_q, self.to_k, self.to_v, self.use_bias, ReplicatedLinear
         )
-        in_features = concatenated_weights.shape[1]
-        out_features = concatenated_weights.shape[0]
-
-        self.to_qkv = ReplicatedLinear(in_features, out_features, bias=self.use_bias)
-        self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype)
-        if self.use_bias:
-            concatenated_bias = torch.cat(
-                [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]
-            )
-            self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
 
         if self.added_kv_proj_dim is not None:
-            concatenated_weights = torch.cat(
-                [
-                    self.add_q_proj.weight.data,
-                    self.add_k_proj.weight.data,
-                    self.add_v_proj.weight.data,
-                ]
+            self.to_added_qkv = fuse_linear_projections(
+                self.add_q_proj,
+                self.add_k_proj,
+                self.add_v_proj,
+                self.added_proj_bias,
+                ReplicatedLinear,
             )
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_added_qkv = ReplicatedLinear(
-                in_features, out_features, bias=self.added_proj_bias
-            )
-            self.to_added_qkv.weight.data = concatenated_weights.to(
-                device=device, dtype=dtype
-            )
-            if self.added_proj_bias:
-                concatenated_bias = torch.cat(
-                    [
-                        self.add_q_proj.bias.data,
-                        self.add_k_proj.bias.data,
-                        self.add_v_proj.bias.data,
-                    ]
-                )
-                self.to_added_qkv.bias.data = concatenated_bias.to(
-                    device=device, dtype=dtype
-                )
+            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
 
         self.fused_projections = True
 
@@ -530,13 +503,9 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
         )
 
     def fuse_qkv_projections(self):
-        for block in self.transformer_blocks:
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
-
-        for block in self.single_transformer_blocks:
+        for block in list(self.transformer_blocks) + list(
+            self.single_transformer_blocks
+        ):
             if hasattr(block.attn, "fuse_projections") and getattr(
                 block.attn, "_supports_qkv_fusion", True
             ):
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index 290765c9314..bbb9d9cc970 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -29,6 +29,10 @@
 from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import _apply_rotary_emb
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.models.dits.utils import (
+    delete_projection_layers,
+    fuse_linear_projections,
+)
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
@@ -190,51 +194,20 @@ def fuse_projections(self):
         if self.fused_projections:
             return
 
-        device = self.to_q.weight.data.device
-        dtype = self.to_q.weight.data.dtype
-
-        concatenated_weights = torch.cat(
-            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+        self.to_qkv = fuse_linear_projections(
+            self.to_q, self.to_k, self.to_v, self.use_bias, torch.nn.Linear
         )
-        in_features = concatenated_weights.shape[1]
-        out_features = concatenated_weights.shape[0]
-
-        self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias)
-        self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype)
-        if self.use_bias:
-            concatenated_bias = torch.cat(
-                [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]
-            )
-            self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
 
         if self.added_kv_proj_dim is not None:
-            concatenated_weights = torch.cat(
-                [
-                    self.add_q_proj.weight.data,
-                    self.add_k_proj.weight.data,
-                    self.add_v_proj.weight.data,
-                ]
-            )
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_added_qkv = torch.nn.Linear(
-                in_features, out_features, bias=self.added_proj_bias
-            )
-            self.to_added_qkv.weight.data = concatenated_weights.to(
-                device=device, dtype=dtype
+            self.to_added_qkv = fuse_linear_projections(
+                self.add_q_proj,
+                self.add_k_proj,
+                self.add_v_proj,
+                self.added_proj_bias,
+                torch.nn.Linear,
             )
-            if self.added_proj_bias:
-                concatenated_bias = torch.cat(
-                    [
-                        self.add_q_proj.bias.data,
-                        self.add_k_proj.bias.data,
-                        self.add_v_proj.bias.data,
-                    ]
-                )
-                self.to_added_qkv.bias.data = concatenated_bias.to(
-                    device=device, dtype=dtype
-                )
+            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
 
         self.fused_projections = True
 
@@ -785,13 +758,9 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
         self.gradient_checkpointing = False
 
     def fuse_qkv_projections(self):
-        for block in self.transformer_blocks:
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
-
-        for block in self.single_transformer_blocks:
+        for block in list(self.transformer_blocks) + list(
+            self.single_transformer_blocks
+        ):
             if hasattr(block.attn, "fuse_projections") and getattr(
                 block.attn, "_supports_qkv_fusion", True
             ):
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 989d6d5286b..1dd0781a846 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -22,12 +22,54 @@
     fuse_scale_shift_kernel,
 )
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.models.dits.utils import (
+    delete_projection_layers,
+    fuse_linear_projections,
+)
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
 logger = init_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _get_projections(
+    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
+):
+    img_query, _ = attn.to_q(hidden_states)
+    img_key, _ = attn.to_k(hidden_states)
+    img_value, _ = attn.to_v(hidden_states)
+
+    txt_query = txt_key = txt_value = None
+    if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+        txt_query, _ = attn.add_q_proj(encoder_hidden_states)
+        txt_key, _ = attn.add_k_proj(encoder_hidden_states)
+        txt_value, _ = attn.add_v_proj(encoder_hidden_states)
+
+    return img_query, img_key, img_value, txt_query, txt_key, txt_value
+
+
+def _get_fused_projections(
+    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
+):
+    img_qkv, _ = attn.to_qkv(hidden_states)
+    img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
+
+    txt_query = txt_key = txt_value = None
+    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+        txt_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
+        txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
+
+    return img_query, img_key, img_value, txt_query, txt_key, txt_value
+
+
+def _get_qkv_projections(
+    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
+):
+    if attn.fused_projections:
+        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim):
         super().__init__()
@@ -218,6 +260,7 @@ def _compute_video_freqs(
 
 
 class QwenImageCrossAttention(nn.Module):
+    _supports_qkv_fusion = True
 
     def __init__(
         self,
@@ -294,6 +337,31 @@ def __init__(
             },
         )
 
+        self.fused_projections = False
+        self.added_kv_proj_dim_val = added_kv_proj_dim
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        if self.fused_projections:
+            return
+
+        self.to_qkv = fuse_linear_projections(
+            self.to_q, self.to_k, self.to_v, use_bias=False, linear_cls=ReplicatedLinear
+        )
+        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
+
+        if self.added_kv_proj_dim_val is not None and hasattr(self, "add_q_proj"):
+            self.to_added_qkv = fuse_linear_projections(
+                self.add_q_proj,
+                self.add_k_proj,
+                self.add_v_proj,
+                use_bias=True,
+                linear_cls=ReplicatedLinear,
+            )
+            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
+
+        self.fused_projections = True
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -303,15 +371,9 @@ def forward(
     ):
         seq_len_txt = encoder_hidden_states.shape[1]
 
-        # Compute QKV for image stream (sample projections)
-        img_query, _ = self.to_q(hidden_states)
-        img_key, _ = self.to_k(hidden_states)
-        img_value, _ = self.to_v(hidden_states)
-
-        # Compute QKV for text stream (context projections)
-        txt_query, _ = self.add_q_proj(encoder_hidden_states)
-        txt_key, _ = self.add_k_proj(encoder_hidden_states)
-        txt_value, _ = self.add_v_proj(encoder_hidden_states)
+        img_query, img_key, img_value, txt_query, txt_key, txt_value = (
+            _get_qkv_projections(self, hidden_states, encoder_hidden_states)
+        )
 
         # Reshape for multi-head attention
         img_query = img_query.unflatten(-1, (self.num_heads, -1))
@@ -562,6 +624,13 @@ def __init__(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
         )
 
+    def fuse_qkv_projections(self):
+        for block in self.transformer_blocks:
+            if hasattr(block.attn, "fuse_projections") and getattr(
+                block.attn, "_supports_qkv_fusion", True
+            ):
+                block.attn.fuse_projections()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/utils.py b/python/sglang/multimodal_gen/runtime/models/dits/utils.py
new file mode 100644
index 00000000000..74b2d4d6729
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/models/dits/utils.py
@@ -0,0 +1,43 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
+
+
+def fuse_linear_projections(
+    q_proj: Union[nn.Linear, ReplicatedLinear],
+    k_proj: Union[nn.Linear, ReplicatedLinear],
+    v_proj: Union[nn.Linear, ReplicatedLinear],
+    use_bias: bool,
+    linear_cls: type = None,
+) -> Union[nn.Linear, ReplicatedLinear]:
+    device = q_proj.weight.data.device
+    dtype = q_proj.weight.data.dtype
+
+    concatenated_weights = torch.cat(
+        [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
+    )
+    in_features = concatenated_weights.shape[1]
+    out_features = concatenated_weights.shape[0]
+
+    if linear_cls is None:
+        linear_cls = type(q_proj)
+
+    fused_layer = linear_cls(in_features, out_features, bias=use_bias)
+    fused_layer.weight.data = concatenated_weights.to(device=device, dtype=dtype)
+
+    if use_bias:
+        concatenated_bias = torch.cat(
+            [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
+        )
+        fused_layer.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+
+    return fused_layer
+
+
+def delete_projection_layers(module: nn.Module, layer_names: list[str]) -> None:
+    for name in layer_names:
+        if hasattr(module, name):
+            delattr(module, name)
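
A minimal standalone sketch (not repository code) of the invariant the new fuse_linear_projections helper relies on: nn.Linear stores weights as (out_features, in_features), so a row-wise torch.cat of the q/k/v weights produces a fused layer whose output is [q | k | v], and chunk(3, dim=-1) on that output recovers the separate projections, which is exactly how _get_fused_projections reads it back. Plain torch.nn.Linear stands in for the repo's ReplicatedLinear here (whose forward returns an (output, bias) tuple, hence the ", _ =" unpacking in the patch); the dimensions and variable names below are illustrative only.

    import torch
    import torch.nn as nn

    # Illustrative sizes; plain nn.Linear stands in for ReplicatedLinear.
    dim = 64
    to_q = nn.Linear(dim, dim, bias=True)
    to_k = nn.Linear(dim, dim, bias=True)
    to_v = nn.Linear(dim, dim, bias=True)

    # Fuse the same way fuse_linear_projections does: concat weights/biases along dim 0.
    to_qkv = nn.Linear(dim, 3 * dim, bias=True)
    to_qkv.weight.data = torch.cat(
        [to_q.weight.data, to_k.weight.data, to_v.weight.data]
    )
    to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

    x = torch.randn(2, 16, dim)
    # Fused read path, mirroring _get_fused_projections.
    q, k, v = to_qkv(x).chunk(3, dim=-1)

    # The fused path must reproduce the original per-projection outputs.
    assert torch.allclose(q, to_q(x), atol=1e-6)
    assert torch.allclose(k, to_k(x), atol=1e-6)
    assert torch.allclose(v, to_v(x), atol=1e-6)

Because the fusion only reorders existing weights into one GEMM, deleting the per-projection layers afterwards (delete_projection_layers) changes memory layout and call count but not numerics, up to floating-point accumulation order.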