vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py (3 additions, 3 deletions)
@@ -342,9 +342,9 @@ def forward(
                 and thinker_tts_embeds.shape[1] == 3
             ):
                 bos_eos_pad = thinker_tts_embeds.to(text_hidden_states.device).chunk(3, dim=1)  # 3 * [1,1,H]
-                multimodal_outputs["tts_bos_embed"] = bos_eos_pad[0]
-                multimodal_outputs["tts_eos_embed"] = bos_eos_pad[1]
-                multimodal_outputs["tts_pad_embed"] = bos_eos_pad[2]
+                multimodal_outputs["tts_bos_embed"] = [bos_eos_pad[0]]
+                multimodal_outputs["tts_eos_embed"] = [bos_eos_pad[1]]
+                multimodal_outputs["tts_pad_embed"] = [bos_eos_pad[2]]
         except Exception:
             # Best-effort; absence will be handled by talker with fallbacks
             pass
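Why the embeds are now wrapped in single-element lists: the list branch in gpu_ar_model_runner.py (second file below) pops one element per step, detaches it, and moves it to CPU. A minimal runnable sketch of that consumption path; the helper name pop_list_payload is hypothetical and the random tensor stands in for the real [1, 1, H] embed (assumes torch is installed):

import torch

def pop_list_payload(multimodal_outputs: dict, mm_payload: dict, k: str) -> None:
    # Hypothetical helper mirroring the list branch in gpu_ar_model_runner.py.
    v = multimodal_outputs[k]
    element = v[0]
    if isinstance(element, torch.Tensor):
        # Detach and move to CPU so the payload can leave the GPU worker.
        element = element.detach().to("cpu").contiguous()
    # Drop the consumed element; if only one is left, keep re-using it.
    multimodal_outputs[k] = v[1:] if len(v) > 1 else v
    mm_payload[k] = element

outs = {"tts_bos_embed": [torch.randn(1, 1, 8)]}  # [1, 1, H] as in the chunk above
payload: dict = {}
pop_list_payload(outs, payload, "tts_bos_embed")
print(payload["tts_bos_embed"].shape)  # torch.Size([1, 1, 8]), now a CPU tensor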
vllm_omni/worker/gpu_ar_model_runner.py (8 additions, 1 deletion)
@@ -291,6 +291,11 @@ def propose_draft_token_ids(sampled_token_ids):
                 # Case 1: tensor aligned on token dimension
                 if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
                     mm_payload[k] = v.detach().to("cpu")[prev_logits_index : logits_index + 1].contiguous()
+                elif isinstance(v, torch.Tensor) and v.shape[0] != hidden_states_cpu.shape[0]:
+                    logger.error(
+                        "Error merging multimodal outputs: tensor dimension mismatch, "
+                        f"{v.shape} != {hidden_states_cpu.shape} for {k}"
+                    )
                 # Case 2: nested dict of tensors aligned on token dimension (e.g., selected_hidden_layers)
                 elif isinstance(v, dict):
                     sub_dict: dict[str, torch.Tensor] = {}
@@ -302,7 +307,9 @@ def propose_draft_token_ids(sampled_token_ids):
                     if sub_dict:
                         mm_payload[k] = sub_dict
                 elif isinstance(v, list):
-                    element: torch.Tensor = v[0]
+                    element = v[0]
+                    if isinstance(element, torch.Tensor):
+                        element = element.detach().to("cpu").contiguous()
                     multimodal_outputs[k] = v[1:] if len(v) > 1 else v
                     mm_payload[k] = element
         except Exception as e:
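The intent of the new Case 1 guard: tensors aligned on the token dimension are sliced into the per-step payload, while mismatched tensors are reported instead of silently dropped. A minimal standalone sketch of that branch; the logger setup, slice indices, and tensor shapes are illustrative and not the runner's real state:

import logging

import torch

logger = logging.getLogger("gpu_ar_model_runner")
logging.basicConfig(level=logging.ERROR)

hidden_states_cpu = torch.zeros(4, 8)  # 4 tokens in this step (illustrative)
prev_logits_index, logits_index = 1, 2
mm_payload: dict = {}

for k, v in {"aligned": torch.randn(4, 8), "misaligned": torch.randn(7, 8)}.items():
    if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
        # Aligned on the token dimension: slice the step's rows into the payload.
        mm_payload[k] = v.detach().to("cpu")[prev_logits_index : logits_index + 1].contiguous()
    elif isinstance(v, torch.Tensor) and v.shape[0] != hidden_states_cpu.shape[0]:
        # Mismatched tensors are only logged, as in the new branch above.
        logger.error(
            "Error merging multimodal outputs: tensor dimension mismatch, "
            f"{v.shape} != {hidden_states_cpu.shape} for {k}"
        )

print(sorted(mm_payload))  # ['aligned']; 'misaligned' is reported but not merged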