python/sglang/multimodal_gen/runtime/models/dits/flux.py (17 additions, 48 deletions)
@@ -43,6 +43,10 @@
     _apply_rotary_emb,
 )
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.models.dits.utils import (
+    delete_projection_layers,
+    fuse_linear_projections,
+)
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
     current_platform,
@@ -169,51 +173,20 @@ def fuse_projections(self):
         if self.fused_projections:
             return
 
-        device = self.to_q.weight.data.device
-        dtype = self.to_q.weight.data.dtype
-
-        concatenated_weights = torch.cat(
-            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+        self.to_qkv = fuse_linear_projections(
+            self.to_q, self.to_k, self.to_v, self.use_bias, ReplicatedLinear
         )
-        in_features = concatenated_weights.shape[1]
-        out_features = concatenated_weights.shape[0]
-
-        self.to_qkv = ReplicatedLinear(in_features, out_features, bias=self.use_bias)
-        self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype)
-        if self.use_bias:
-            concatenated_bias = torch.cat(
-                [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]
-            )
-            self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
 
         if self.added_kv_proj_dim is not None:
-            concatenated_weights = torch.cat(
-                [
-                    self.add_q_proj.weight.data,
-                    self.add_k_proj.weight.data,
-                    self.add_v_proj.weight.data,
-                ]
+            self.to_added_qkv = fuse_linear_projections(
+                self.add_q_proj,
+                self.add_k_proj,
+                self.add_v_proj,
+                self.added_proj_bias,
+                ReplicatedLinear,
             )
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_added_qkv = ReplicatedLinear(
-                in_features, out_features, bias=self.added_proj_bias
-            )
-            self.to_added_qkv.weight.data = concatenated_weights.to(
-                device=device, dtype=dtype
-            )
-            if self.added_proj_bias:
-                concatenated_bias = torch.cat(
-                    [
-                        self.add_q_proj.bias.data,
-                        self.add_k_proj.bias.data,
-                        self.add_v_proj.bias.data,
-                    ]
-                )
-                self.to_added_qkv.bias.data = concatenated_bias.to(
-                    device=device, dtype=dtype
-                )
+            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
 
         self.fused_projections = True
 
@@ -530,13 +503,9 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
         )
 
     def fuse_qkv_projections(self):
-        for block in self.transformer_blocks:
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
-
-        for block in self.single_transformer_blocks:
+        for block in list(self.transformer_blocks) + list(
+            self.single_transformer_blocks
+        ):
             if hasattr(block.attn, "fuse_projections") and getattr(
                 block.attn, "_supports_qkv_fusion", True
             ):
python/sglang/multimodal_gen/runtime/models/dits/flux_2.py (17 additions, 48 deletions)
@@ -29,6 +29,10 @@
 from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import _apply_rotary_emb
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.models.dits.utils import (
+    delete_projection_layers,
+    fuse_linear_projections,
+)
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
@@ -190,51 +194,20 @@ def fuse_projections(self):
         if self.fused_projections:
             return
 
-        device = self.to_q.weight.data.device
-        dtype = self.to_q.weight.data.dtype
-
-        concatenated_weights = torch.cat(
-            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+        self.to_qkv = fuse_linear_projections(
+            self.to_q, self.to_k, self.to_v, self.use_bias, torch.nn.Linear
         )
-        in_features = concatenated_weights.shape[1]
-        out_features = concatenated_weights.shape[0]
-
-        self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias)
-        self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype)
-        if self.use_bias:
-            concatenated_bias = torch.cat(
-                [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]
-            )
-            self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
 
         if self.added_kv_proj_dim is not None:
-            concatenated_weights = torch.cat(
-                [
-                    self.add_q_proj.weight.data,
-                    self.add_k_proj.weight.data,
-                    self.add_v_proj.weight.data,
-                ]
-            )
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_added_qkv = torch.nn.Linear(
-                in_features, out_features, bias=self.added_proj_bias
-            )
-            self.to_added_qkv.weight.data = concatenated_weights.to(
-                device=device, dtype=dtype
+            self.to_added_qkv = fuse_linear_projections(
+                self.add_q_proj,
+                self.add_k_proj,
+                self.add_v_proj,
+                self.added_proj_bias,
+                torch.nn.Linear,
             )
-            if self.added_proj_bias:
-                concatenated_bias = torch.cat(
-                    [
-                        self.add_q_proj.bias.data,
-                        self.add_k_proj.bias.data,
-                        self.add_v_proj.bias.data,
-                    ]
-                )
-                self.to_added_qkv.bias.data = concatenated_bias.to(
-                    device=device, dtype=dtype
-                )
+            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
 
         self.fused_projections = True
 
@@ -785,13 +758,9 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
         self.gradient_checkpointing = False
 
     def fuse_qkv_projections(self):
-        for block in self.transformer_blocks:
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
-
-        for block in self.single_transformer_blocks:
+        for block in list(self.transformer_blocks) + list(
+            self.single_transformer_blocks
+        ):
             if hasattr(block.attn, "fuse_projections") and getattr(
                 block.attn, "_supports_qkv_fusion", True
             ):
python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py (78 additions, 9 deletions)
@@ -22,12 +22,54 @@
     fuse_scale_shift_kernel,
 )
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.models.dits.utils import (
+    delete_projection_layers,
+    fuse_linear_projections,
+)
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
 logger = init_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _get_projections(
+    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
+):
+    img_query, _ = attn.to_q(hidden_states)
+    img_key, _ = attn.to_k(hidden_states)
+    img_value, _ = attn.to_v(hidden_states)
+
+    txt_query = txt_key = txt_value = None
+    if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+        txt_query, _ = attn.add_q_proj(encoder_hidden_states)
+        txt_key, _ = attn.add_k_proj(encoder_hidden_states)
+        txt_value, _ = attn.add_v_proj(encoder_hidden_states)
+
+    return img_query, img_key, img_value, txt_query, txt_key, txt_value
+
+
+def _get_fused_projections(
+    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
+):
+    img_qkv, _ = attn.to_qkv(hidden_states)
+    img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
+
+    txt_query = txt_key = txt_value = None
+    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+        txt_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
+        txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
+
+    return img_query, img_key, img_value, txt_query, txt_key, txt_value
+
+
+def _get_qkv_projections(
+    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
+):
+    if attn.fused_projections:
+        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim):
         super().__init__()
@@ -218,6 +260,7 @@ def _compute_video_freqs(
 
 
 class QwenImageCrossAttention(nn.Module):
+    _supports_qkv_fusion = True
 
     def __init__(
         self,
@@ -294,6 +337,31 @@ def __init__(
             },
         )
 
+        self.fused_projections = False
+        self.added_kv_proj_dim_val = added_kv_proj_dim
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        if self.fused_projections:
+            return
+
+        self.to_qkv = fuse_linear_projections(
+            self.to_q, self.to_k, self.to_v, use_bias=False, linear_cls=ReplicatedLinear
+        )
+        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
+
+        if self.added_kv_proj_dim_val is not None and hasattr(self, "add_q_proj"):
+            self.to_added_qkv = fuse_linear_projections(
+                self.add_q_proj,
+                self.add_k_proj,
+                self.add_v_proj,
+                use_bias=True,
+                linear_cls=ReplicatedLinear,
+            )
+            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
+
+        self.fused_projections = True
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -303,15 +371,9 @@ def forward(
     ):
         seq_len_txt = encoder_hidden_states.shape[1]
 
-        # Compute QKV for image stream (sample projections)
-        img_query, _ = self.to_q(hidden_states)
-        img_key, _ = self.to_k(hidden_states)
-        img_value, _ = self.to_v(hidden_states)
-
-        # Compute QKV for text stream (context projections)
-        txt_query, _ = self.add_q_proj(encoder_hidden_states)
-        txt_key, _ = self.add_k_proj(encoder_hidden_states)
-        txt_value, _ = self.add_v_proj(encoder_hidden_states)
+        img_query, img_key, img_value, txt_query, txt_key, txt_value = (
+            _get_qkv_projections(self, hidden_states, encoder_hidden_states)
+        )
 
         # Reshape for multi-head attention
         img_query = img_query.unflatten(-1, (self.num_heads, -1))
@@ -562,6 +624,13 @@ def __init__(
             self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
         )
 
+    def fuse_qkv_projections(self):
+        for block in self.transformer_blocks:
+            if hasattr(block.attn, "fuse_projections") and getattr(
+                block.attn, "_supports_qkv_fusion", True
+            ):
+                block.attn.fuse_projections()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
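
For context on the dispatch added to qwen_image.py, the following standalone sketch mirrors the `_get_qkv_projections` pattern with a hypothetical toy module (not part of this PR). It uses plain `torch.nn.Linear`, which returns a single tensor, instead of sglang's `ReplicatedLinear`, which returns an `(output, bias)` tuple and is why the diff unpacks `img_qkv, _ = attn.to_qkv(...)`; layer sizes and tolerance are arbitrary.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Toy stand-in for the q/k/v handling in QwenImageCrossAttention."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.fused_projections = False

    @torch.no_grad()
    def fuse_projections(self):
        # Row-wise concatenation of the three weight matrices (and biases)
        # yields one projection whose output splits back into q, k, v.
        w = torch.cat([self.to_q.weight, self.to_k.weight, self.to_v.weight])
        b = torch.cat([self.to_q.bias, self.to_k.bias, self.to_v.bias])
        self.to_qkv = nn.Linear(w.shape[1], w.shape[0])
        self.to_qkv.weight.copy_(w)
        self.to_qkv.bias.copy_(b)
        del self.to_q, self.to_k, self.to_v
        self.fused_projections = True

    def get_qkv(self, x):
        # Same branch structure as _get_qkv_projections: one matmul when fused.
        if self.fused_projections:
            return self.to_qkv(x).chunk(3, dim=-1)
        return self.to_q(x), self.to_k(x), self.to_v(x)


attn = ToyAttention()
x = torch.randn(2, 16, 64)
q0, k0, v0 = attn.get_qkv(x)
attn.fuse_projections()
q1, k1, v1 = attn.get_qkv(x)
assert torch.allclose(q0, q1, atol=1e-6) and torch.allclose(v0, v1, atol=1e-6)
```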
python/sglang/multimodal_gen/runtime/models/dits/utils.py (43 additions, 0 deletions, new file)
@@ -0,0 +1,43 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
+
+
+def fuse_linear_projections(
+    q_proj: Union[nn.Linear, ReplicatedLinear],
+    k_proj: Union[nn.Linear, ReplicatedLinear],
+    v_proj: Union[nn.Linear, ReplicatedLinear],
+    use_bias: bool,
+    linear_cls: type = None,
+) -> Union[nn.Linear, ReplicatedLinear]:
+    device = q_proj.weight.data.device
+    dtype = q_proj.weight.data.dtype
+
+    concatenated_weights = torch.cat(
+        [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
+    )
+    in_features = concatenated_weights.shape[1]
+    out_features = concatenated_weights.shape[0]
+
+    if linear_cls is None:
+        linear_cls = type(q_proj)
+
+    fused_layer = linear_cls(in_features, out_features, bias=use_bias)
+    fused_layer.weight.data = concatenated_weights.to(device=device, dtype=dtype)
+
+    if use_bias:
+        concatenated_bias = torch.cat(
+            [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
+        )
+        fused_layer.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+
+    return fused_layer
+
+
+def delete_projection_layers(module: nn.Module, layer_names: list[str]) -> None:
+    for name in layer_names:
+        if hasattr(module, name):
+            delattr(module, name)
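
As a quick sanity check of the new helper, the snippet below fuses three standalone `nn.Linear` projections with `fuse_linear_projections` (the same call flux_2.py now makes) and verifies that chunking the fused output reproduces the separate projections. This is a minimal sketch, assuming the sglang package from this branch is importable; the dimensions and tolerance are arbitrary.

```python
import torch
import torch.nn as nn

from sglang.multimodal_gen.runtime.models.dits.utils import fuse_linear_projections

dim = 64
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# Fuse with plain nn.Linear as the target class, as flux_2.py does.
to_qkv = fuse_linear_projections(to_q, to_k, to_v, use_bias=True, linear_cls=nn.Linear)

x = torch.randn(2, 16, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)

# The fused projection must reproduce the per-layer outputs.
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```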