1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
@@ -579,6 +579,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
34 changes: 34 additions & 0 deletions python/sglang/srt/conversation.py
@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
    METAMATH = auto()
    DeepSeekVL2 = auto()
    QWEN2_VL_EMBED = auto()
    QWEN2_AUDIO = auto()
    GEMMA3 = auto()
    MPT = auto()

@@ -350,6 +351,23 @@ def get_prompt(self) -> str:
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
            ret = "" if system_prompt == "" else system_prompt + self.sep

            counter = 1
            for role, message in self.messages:
                if message:
                    # Number each audio placeholder in order: "Audio 1:", "Audio 2:", ...
                    while self.audio_token in message:
                        message = message.replace(
                            self.audio_token, self.audio_token.format(idx=counter), 1
                        )
                        counter += 1

                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"

            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

@@ -904,6 +922,20 @@ def generate_chat_conv(
)


register_conv_template(
    Conversation(
        name="qwen2-audio",
        system_template="<|im_start|>system\n{system_message}",
        system_message="You are a helpful assistant.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.QWEN2_AUDIO,
        stop_str=["<|im_end|>"],
        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
    )
)


@register_conv_template_matching_function
def match_internvl(model_path: str):
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
@@ -956,6 +988,8 @@ def match_qwen_chat_ml(model_path: str):
        return "gme-qwen2-vl"
    if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
        return "qwen2-vl"
    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
        return "qwen2-audio"
    if re.search(
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
        model_path,
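For reference, here is a minimal sketch of how the new `qwen2-audio` template renders a prompt. It reimplements the `QWEN2_AUDIO` branch of `get_prompt()` above with plain variables (the names `audio_token`, `sep`, and `messages` are illustrative, not the sglang API), showing how `{idx}` in the audio token gets numbered sequentially:

```python
# Illustrative sketch of the QWEN2_AUDIO rendering logic above.
audio_token = "Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
sep = "<|im_end|>\n"
system = "<|im_start|>system\nYou are a helpful assistant." + sep

messages = [
    ("<|im_start|>user", audio_token + "What is said in this clip?"),
    ("<|im_start|>assistant", None),  # None -> generation prompt
]

ret, counter = system, 1
for role, message in messages:
    if message:
        # Replace one placeholder at a time so each gets its own index.
        while audio_token in message:
            message = message.replace(audio_token, audio_token.format(idx=counter), 1)
            counter += 1
        ret += role + "\n" + message + sep
    else:
        ret += role + "\n"

print(ret)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
# What is said in this clip?<|im_end|>
# <|im_start|>assistant
```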
94 changes: 94 additions & 0 deletions python/sglang/srt/managers/multimodal_processors/qwen_audio.py
@@ -0,0 +1,94 @@
import re
from typing import List, Union

import torch

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration


class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
    models = [Qwen2AudioForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
        self.AUDIO_TOKEN_REGEX = re.compile(
            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        audio_data = request_obj.audio_data
        if not isinstance(audio_data, list):
            audio_data = [audio_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,
            audio_data=audio_data,
            multimodal_tokens=MultimodalSpecialTokens(
                audio_token=self.AUDIO_TOKEN,
                audio_token_regex=self.AUDIO_TOKEN_REGEX,
            ),
        )
        if base_output is None:
            return None

        res = self.process_mm_data(
            input_text=base_output.input_text,
            audio=base_output.audios,
        )

        # Collect special token ids
        tokenizer = self._processor.tokenizer
        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")

        items = []
        input_ids = res["input_ids"].flatten()

        if (
            "input_features" in res
            and res["input_features"] is not None
            and len(res["input_features"]) != 0
        ):
            if audio_start_id is not None and audio_end_id is not None:
                audio_offsets = self.get_mm_items_offset_by_pair(
                    input_ids=input_ids,
                    mm_start_id=audio_start_id,
                    mm_end_id=audio_end_id,
                )
            else:
                audio_offsets = None

            # Whisper-style length math: the stride-2 conv roughly halves the
            # valid mel-frame count, and the stride-2 average pool halves it again.
            input_lengths = res["feature_attention_mask"].sum(dim=-1)
            input_lengths = (input_lengths - 1) // 2 + 1
            output_lengths = (input_lengths - 2) // 2 + 1

            item = MultimodalDataItem(
                audio_features=res["input_features"],
                audio_feature_lens=output_lengths,
                audio_offsets=audio_offsets,
                modality=Modality.AUDIO,
            )
            items += [item]

        return {
            "mm_items": items,
            "input_ids": input_ids.tolist(),
            "audio_start_id": audio_start_id,
            "audio_token_id": audio_token_id,
            "audio_end_id": audio_end_id,
        }
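The two length formulas in `process_mm_data_async` track how the Qwen2-Audio (Whisper-style) encoder shrinks the sequence: a stride-2 convolution roughly halves the mel-frame count, and a stride-2 average pool halves it again before projection into the language model. A quick sanity check of the arithmetic (the 3000-frame figure is the standard 30 s Whisper window, used here purely as an example):

```python
# Sanity check of the length math above; plain integers, no sglang API.
mel_frames = 3000                       # e.g. 30 s of audio at 100 mel frames/s
after_conv = (mel_frames - 1) // 2 + 1  # stride-2 conv  -> 1500
after_pool = (after_conv - 2) // 2 + 1  # stride-2 pool  -> 750
print(after_conv, after_pool)           # 1500 750
```

So a 30 s clip contributes 750 audio embeddings, which must match the number of `<|AUDIO|>` placeholder tokens the processor expands in the prompt.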
1 change: 1 addition & 0 deletions python/sglang/srt/models/qwen2.py
@@ -398,6 +398,7 @@ def __init__(
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        else:
            # ranks other than the last rank will have a placeholder layer
            self.lm_head = PPMissingLayer()
200 changes: 200 additions & 0 deletions python/sglang/srt/models/qwen2_audio.py
@@ -0,0 +1,200 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
import logging
import math
from functools import lru_cache, partial
from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
from transformers.activations import ACT2FN
from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
from transformers.models.qwen2_audio.modeling_qwen2_audio import (
    Qwen2AudioEncoder,
    Qwen2AudioMultiModalProjector,
)

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class Qwen2AudioForConditionalGeneration(nn.Module):
    # BitsAndBytes-specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen2AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if getattr(self.config, "audio_config", None) is None:
            self.config.audio_config = Qwen2AudioEncoderConfig(
                self.config._name_or_path
            )

        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config(self.config._name_or_path)

        self.audio_tower = Qwen2AudioEncoder(
            config.audio_config,
        )
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        # Get the audio token id (falls back to the generic im_token_id)
        audio_token_id: int = getattr(
            mm_inputs, "audio_token_id", mm_inputs.im_token_id
        )

        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # Extract audio features from input items
        input_features = torch.cat(
            [item.audio_features for item in items], dim=0
        ).type(self.audio_tower.dtype)

        audio_embeds = self.audio_tower(input_features).last_hidden_state
        audio_embeds = self.multi_modal_projector(audio_embeds)

        # Trim each clip's embeddings to its true (unpadded) length
        audio_feature_lens = torch.cat([item.audio_feature_lens for item in items])
        new_embeds = []
        for i, d in zip(audio_feature_lens, audio_embeds):
            new_embeds.append(d[: i.item()])

        return torch.cat(new_embeds, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            audio_data_embedding_func=self.get_audio_feature,
            positions=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Audio tower weights are not stacked; only remap LLM weights.
                if weight_name not in name or "audio_tower" in name:
                    continue
                name_tmp = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name_tmp.endswith(".bias") and name_tmp not in params_dict:
                    continue
                param = params_dict[name_tmp]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: load it directly by name.
                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = Qwen2AudioForConditionalGeneration
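With all four pieces registered (architecture name in `model_config.py`, chat template in `conversation.py`, multimodal processor, and model class ending in `EntryClass`), Qwen2-Audio should be servable like sglang's other multimodal models. A hedged usage sketch follows; the launch command is the standard sglang entry point, but the exact audio payload fields of the OpenAI-compatible endpoint may differ by version, so treat the `audio_url` content type as an assumption and check the current docs:

```python
# Launch the server first (shell), assuming a standard install:
#   python -m sglang.launch_server --model-path Qwen/Qwen2-Audio-7B-Instruct --port 30000
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2-Audio-7B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    # "audio_url" is the assumed content type for audio inputs.
                    {"type": "audio_url", "audio_url": {"url": "https://example.com/clip.wav"}},
                    {"type": "text", "text": "What is said in this clip?"},
                ],
            }
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```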