
IG: fix multimodal reshape for Qwen2.5-VL (revert #691) #1081

Merged
michalkuligowski merged 3 commits into HabanaAI:habana_main from imangohari1:ig/habana_main_85c985e_withoutPR691 on Apr 16, 2025

Conversation


@imangohari1 commented on Apr 14, 2025

This PR reverts #691, which leads to AttributeError: 'tuple' object has no attribute 'reshape' for Qwen2.5-VL.

Test

Server:

python -m vllm.entrypoints.openai.api_server --port 8080 --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1 --max-num-seqs 128 --dtype bfloat16 --gpu-memory-util 0.9 --max-num-batched-tokens 32768 --max-model-len 32768 --block-size 128

Client:

python benchmark_serving.py --backend openai-chat --model Qwen/Qwen2.5-VL-3B-Instruct --trust-remote-code --port 8080 --endpoint /v1/chat/completions --dataset-path lmarena-ai/vision-arena-bench-v0.1 --dataset-name hf --hf-split train --num-prompts 40 --request-rate inf --seed 0 --ignore_eos
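
For a quick manual check of the same endpoint, independent of the benchmark, a single multimodal request can be sent with the openai Python client (a sketch; the image URL and prompt are placeholders):

from openai import OpenAI

# Point the client at the local vLLM OpenAI-compatible server started above.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            # Placeholder URL; substitute any reachable image.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/image.jpg"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)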

As is (without this PR), the benchmark fails:

ERROR 04-14 16:51:14 engine.py:139] AttributeError("'tuple' object has no attribute 'reshape'")
ERROR 04-14 16:51:14 engine.py:139] Traceback (most recent call last):
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/engine/multiprocessing/engine.py", line 137, in start
ERROR 04-14 16:51:14 engine.py:139]     self.run_engine_loop()
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/engine/multiprocessing/engine.py", line 200, in run_engine_loop
ERROR 04-14 16:51:14 engine.py:139]     request_outputs = self.engine_step()
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/engine/multiprocessing/engine.py", line 218, in engine_step
ERROR 04-14 16:51:14 engine.py:139]     raise e
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/engine/multiprocessing/engine.py", line 209, in engine_step
ERROR 04-14 16:51:14 engine.py:139]     return self.engine.step()
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/engine/llm_engine.py", line 1380, in step
ERROR 04-14 16:51:14 engine.py:139]     outputs = self.model_executor.execute_model(
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/executor/executor_base.py", line 138, in execute_model
ERROR 04-14 16:51:14 engine.py:139]     output = self.collective_rpc("execute_model",
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/executor/uniproc_executor.py", line 58, in collective_rpc
ERROR 04-14 16:51:14 engine.py:139]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/utils.py", line 2323, in run_method
ERROR 04-14 16:51:14 engine.py:139]     return func(*args, **kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/worker/hpu_worker.py", line 294, in execute_model
ERROR 04-14 16:51:14 engine.py:139]     output = LocalOrDistributedWorkerBase.execute_model(
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/worker/worker_base.py", line 418, in execute_model
ERROR 04-14 16:51:14 engine.py:139]     output = self.model_runner.execute_model(
ERROR 04-14 16:51:14 engine.py:139]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-14 16:51:14 engine.py:139]     return func(*args, **kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/worker/hpu_model_runner.py", line 2697, in execute_model
ERROR 04-14 16:51:14 engine.py:139]     hidden_states = self.model.forward(
ERROR 04-14 16:51:14 engine.py:139]   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 745, in forward
ERROR 04-14 16:51:14 engine.py:139]     return wrapped_hpugraph_forward(
ERROR 04-14 16:51:14 engine.py:139]   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 610, in wrapped_hpugraph_forward
ERROR 04-14 16:51:14 engine.py:139]     outputs = orig_fwd(*args, **kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/worker/hpu_model_runner.py", line 423, in forward
ERROR 04-14 16:51:14 engine.py:139]     hidden_states = self.model(*args, **kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1742, in _wrapped_call_impl
ERROR 04-14 16:51:14 engine.py:139]     return self._call_impl(*args, **kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1848, in _call_impl
ERROR 04-14 16:51:14 engine.py:139]     return inner()
ERROR 04-14 16:51:14 engine.py:139]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1796, in inner
ERROR 04-14 16:51:14 engine.py:139]     result = forward_call(*args, **kwargs)
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/model_executor/models/qwen2_5_vl.py", line 1104, in forward
ERROR 04-14 16:51:14 engine.py:139]     inputs_embeds = self.get_input_embeddings_v0(
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/model_executor/models/qwen2_5_vl.py", line 1037, in get_input_embeddings_v0
ERROR 04-14 16:51:14 engine.py:139]     inputs_embeds = merge_multimodal_embeddings(
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/model_executor/models/utils.py", line 448, in merge_multimodal_embeddings
ERROR 04-14 16:51:14 engine.py:139]     return _hpu_merge_multimodal_embeddings(
ERROR 04-14 16:51:14 engine.py:139]   File "/root/vllm-fork/vllm/model_executor/models/utils.py", line 674, in _hpu_merge_multimodal_embeddings
ERROR 04-14 16:51:14 engine.py:139]     multimodal_embeddings = multimodal_embeddings.reshape(-1, hidden_size)
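
The failing line assumes multimodal_embeddings is a single tensor, but for Qwen2.5-VL it arrives as a tuple of per-item tensors, so .reshape raises. A minimal sketch of a tuple-safe flattening (illustrative only; the helper name is hypothetical, and the actual fix in this PR is the revert itself):

import torch

def _flatten_multimodal_embeddings(multimodal_embeddings, hidden_size):
    # NestedTensors-style input: a tuple/list of per-image tensors has to be
    # concatenated into one tensor before any tensor-level reshape.
    if isinstance(multimodal_embeddings, (tuple, list)):
        multimodal_embeddings = torch.cat(
            [t.reshape(-1, hidden_size) for t in multimodal_embeddings])
    return multimodal_embeddings.reshape(-1, hidden_size)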

With this PR:

100%|██████████| 1/1 [00:01<00:00,  1.14s/it]
============ Serving Benchmark Result ============
Successful requests:                     1         
Benchmark duration (s):                  1.14      
Total input tokens:                      52        
Total generated tokens:                  128       
Request throughput (req/s):              0.88      
Output token throughput (tok/s):         112.62    
Total Token throughput (tok/s):          158.37    
---------------Time to First Token----------------
Mean TTFT (ms):                          169.75    
Median TTFT (ms):                        169.75    
P99 TTFT (ms):                           169.75    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.61      
Median TPOT (ms):                        7.61      
P99 TPOT (ms):                           7.61      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.55      
Median ITL (ms):                         7.59      
P99 ITL (ms):                            8.10      
==================================================

@malkomes

@michalkuligowski Some basic unit tests for qwen2_5-vl fail right now.

Unit Tests for qwen2_5:

VLLM_SKIP_WARMUP=true pytest tests/models/decoder_only/vision_language/test_models.py -s -v -k "[qwen2_5"

I think the original PR was modified too much during review, and perhaps its final version is no longer necessary. If we run the following command:

Example vision_language glm4v

python examples/offline_inference/vision_language.py -m glm4v

we get higher input throughput by reverting the changes.

Current main:

UT qwen2_5: 9 failed, 3 passed, 173 deselected, 4 warnings in 175.08s (0:02:55)
vision_language glm4v: [02:34<00:00, 38.71s/it, est. speed input: 41.75 toks/s, output: 1.65 toks/s]

Reverting #691

UT qwen2_5: 12 passed, 173 deselected, 40 warnings in 1203.63s (0:20:03)
vision_language glm4v: [01:31<00:00, 22.85s/it, est. speed input: 70.73 toks/s, output: 2.80 toks/s]

@imangohari1 imangohari1 marked this pull request as ready for review April 15, 2025 16:11
@imangohari1 (Author)

@michalkuligowski @kzawora-intel
Could you please take a look at this PR? We believe it is needed for Qwen2.5-VL functionality and does not introduce a performance regression for the glm4v model. Thanks.

@michalkuligowski

/run-gaudi-tests

@michalkuligowski michalkuligowski merged commit 456caf7 into HabanaAI:habana_main Apr 16, 2025
43 checks passed