Commit 3002334

Authored by iosmers, qw86972190, and root
[Cherry-Pick] [XPU] Cherry-pick: support ZMQ logprobs (#5628) (#5852)

* update
* delete min_tokens

---------

Co-authored-by: qw86972190 <127910106+qw86972190@users.noreply.github.com>
Co-authored-by: root <root@gajl-bbc-onlinec-com-1498355.gajl.baidu.com>
1 parent 44e44ab commit 3002334
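
Both files below gate the new behavior behind envs.FD_USE_GET_SAVE_OUTPUT_V1. As a hedged sketch that is not part of this commit, enabling the new path would look roughly like this, assuming the switch is read from the process environment like other fastdeploy.envs flags and that "1" means enabled:

    import os

    # Assumption: fastdeploy.envs picks up FD_USE_GET_SAVE_OUTPUT_V1 from the
    # environment; "1" enables the ZMQ/stream output path on XPU workers.
    os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"

With the flag unset, the worker keeps using the legacy save_output / save_output_topk operators shown in the first diff.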

File tree

2 files changed (+191, -23 lines):
fastdeploy/model_executor/xpu_pre_and_post_process.py
fastdeploy/worker/xpu_model_runner.py


fastdeploy/model_executor/xpu_pre_and_post_process.py

Lines changed: 73 additions & 21 deletions
@@ -14,15 +14,18 @@
 # limitations under the License.
 """
 
-from typing import Dict, Optional
+import queue
+from typing import Dict, List, Optional
 
+import numpy as np
 import paddle
 
 from fastdeploy import envs
 from fastdeploy.model_executor.forward_meta import XPUForwardMeta
 from fastdeploy.model_executor.layers.sample.sampler import Sampler
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
-from fastdeploy.worker.output import ModelOutputData
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
 
 if current_platform.is_xpu():
     from fastdeploy.model_executor.ops.xpu import (
@@ -49,6 +52,43 @@
     )
 
 
+def _build_stream_transfer_data(
+    output_tokens: paddle.Tensor,
+    pooler_outputs: List = None,
+    logprobs: Optional[LogprobsTensors] = None,
+    prompt_logprobs_list: Optional[LogprobsTensors] = None,
+):
+    """Split output_tokens and output"""
+    stream_transfer_datas = []
+    if output_tokens is not None:
+        output_tokens = output_tokens.reshape([-1]).numpy()
+        output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
+
+        for bid, output_token_per_sample in enumerate(output_tokens_lists):
+            stream_transfer_data = StreamTransferData(
+                decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
+            )
+            if logprobs:
+                stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
+            if prompt_logprobs_list:
+                stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
+            stream_transfer_datas.append(stream_transfer_data)
+    elif pooler_outputs is not None:
+        for bid, pooler_output in enumerate(pooler_outputs):
+            if pooler_output is None:
+                continue
+            if pooler_output.dtype == paddle.bfloat16:
+                pooler_output = pooler_output.astype("float32")
+
+            pooler_output = pooler_output.numpy()
+
+            stream_transfer_data = StreamTransferData(
+                decoder_state=DecoderState.TEXT, pooler_output=pooler_output, batch_id=bid
+            )
+            stream_transfer_datas.append(stream_transfer_data)
+    return stream_transfer_datas
+
+
 def xpu_pre_process(
     input_ids: paddle.Tensor,
     seq_lens_this_time: int,
@@ -217,6 +257,8 @@ def xpu_post_process_normal(
     share_inputs: Dict[str, paddle.Tensor],
     block_size: int = 64,
     skip_save_output: bool = False,
+    save_each_rank: bool = False,
+    async_output_queue: queue.Queue = None,
     think_end_id: int = None,
     line_break_id: int = None,
 ) -> None:
@@ -314,27 +356,37 @@ def xpu_post_process_normal(
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
-        if sampler_output.logprobs_tensors is None:
-            save_output(
-                sampled_token_ids,
-                model_output.not_need_stop,
-                model_output.mp_rank,
-                False,  # use_ep
-            )
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            if save_each_rank or model_output.mp_rank == 0:
+                output = _build_stream_transfer_data(
+                    sampled_token_ids,
+                    logprobs=sampler_output.logprobs_tensors,
+                    prompt_logprobs_list=model_output.prompt_logprobs_list,
+                )
+                if async_output_queue is not None:
+                    async_output_queue.put(output)
         else:
-            if save_output_topk is None:
-                raise ImportError(
-                    "save_output_topk operator is not available. "
-                    "Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
+            if sampler_output.logprobs_tensors is None:
+                save_output(
+                    sampled_token_ids,
+                    model_output.not_need_stop,
+                    model_output.mp_rank,
+                    False,  # use_ep
+                )
+            else:
+                if save_output_topk is None:
+                    raise ImportError(
+                        "save_output_topk operator is not available. "
+                        "Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
+                    )
+                save_output_topk(
+                    sampled_token_ids,
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    model_output.not_need_stop,
+                    model_output.mp_rank,
                 )
-            save_output_topk(
-                sampled_token_ids,
-                sampler_output.logprobs_tensors.logprob_token_ids,
-                sampler_output.logprobs_tensors.logprobs,
-                sampler_output.logprobs_tensors.selected_token_ranks,
-                model_output.not_need_stop,
-                model_output.mp_rank,
-            )
 
 
 def xpu_post_process_specualate(
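
Note on the new helper: _build_stream_transfer_data flattens the sampled-token tensor to one token id per batch slot and wraps each in a StreamTransferData, attaching per-row logprobs and prompt logprobs when present. A minimal, self-contained sketch of that splitting pattern follows; the StreamTransferData dataclass here is only a stand-in for fastdeploy.output.stream_transfer_data.StreamTransferData, and the input shape [batch_size, 1] is an assumption.

    from dataclasses import dataclass
    from typing import List, Optional

    import numpy as np


    @dataclass
    class StreamTransferData:  # stand-in for the FastDeploy dataclass used above
        decoder_state: str
        tokens: np.ndarray
        batch_id: int
        logprobs: Optional[object] = None
        prompt_logprobs: Optional[object] = None


    def split_batch(output_tokens: np.ndarray) -> List[StreamTransferData]:
        # Mirrors output_tokens.reshape([-1]).numpy() followed by np.split in the
        # diff: one length-1 array of token ids per batch slot.
        flat = output_tokens.reshape(-1)
        per_sample = np.split(flat, flat.shape[0])
        return [
            StreamTransferData(decoder_state="TEXT", tokens=toks, batch_id=bid)
            for bid, toks in enumerate(per_sample)
        ]


    # A batch of four sampled token ids becomes four per-sample payloads.
    print(split_batch(np.array([[101], [7], [42], [9]])))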

fastdeploy/worker/xpu_model_runner.py

Lines changed: 118 additions & 2 deletions
@@ -15,19 +15,22 @@
 """
 
 import os
+import queue
 import random
 import time
+from threading import Thread
 from typing import List, Optional
 
 import numpy as np
 import paddle
+import zmq
 from paddle import nn
 
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
@@ -59,7 +62,7 @@
 from fastdeploy.spec_decode import MTPProposer
 from fastdeploy.utils import get_logger
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
-from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
 
 logger = get_logger("xpu_model_runner", "xpu_model_runner.log")
 
@@ -156,6 +159,106 @@ def __init__(
 
         self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode
 
+        # Initialize ZMQ client for async output
+        self.zmq_client = None
+        self.async_output_queue = None
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            logger.info(f"zmq client get_save_output_rank{local_rank}")
+            self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
+            self.zmq_client.connect()
+            self.zmq_client.socket.SNDTIMEO = 3000
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self._async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy",
+            )
+            self.async_output_copy_thread.start()
+        # prompt logprobs state
+        self.prompt_logprobs_reqs: dict[str, Request] = {}
+        self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
+
+    def _async_output_busy_loop(self):
+        """Entrypoint for the thread which handles outputs asynchronously."""
+        while True:
+            try:
+                if self.async_output_queue is None or self.zmq_client is None:
+                    break
+                output = self.async_output_queue.get()
+                if self.zmq_client is not None:
+                    self.zmq_client.send_pyobj(output)
+            except Exception as e:
+                logger.exception("Exception in async output loop: %s", e)
+
+    def _get_prompt_logprobs_list(self, hidden_states: paddle.Tensor) -> list[Optional[LogprobsTensors]]:
+        """
+        Build prompt_logprobs for requests that asked for it.
+        """
+        if len(self.prompt_logprobs_reqs) > 0:
+            assert (
+                not self.fd_config.cache_config.enable_prefix_caching
+            ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching."
+
+        if len(self.prompt_logprobs_reqs) == 0:
+            return self.scheduler_config.max_num_seqs * [None]
+
+        logprobs_mode = self.fd_config.model_config.logprobs_mode
+        prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
+        completed_prefill_reqs: list[Request] = []
+
+        for req_id, request in self.prompt_logprobs_reqs.items():
+            if not hasattr(request, "sampling_params") or request.sampling_params is None:
+                continue
+            num_prompt_logprobs = request.sampling_params.prompt_logprobs
+            if request.prompt_token_ids is None or num_prompt_logprobs is None:
+                continue
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.ori_vocab_size
+
+            num_tokens = request.prefill_end_index - request.prefill_start_index
+            num_prompt_tokens = len(request.prompt_token_ids)
+
+            logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
+            if not logprobs_tensors:
+                logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
+
+            start_idx = request.prefill_start_index
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                num_logits = num_tokens
+            else:
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(request)
+                prompt_logprobs_list[request.idx] = logprobs_tensors
+            if num_logits <= 0:
+                continue
+
+            offset = self.share_inputs["cu_seqlens_q"][request.idx]
+            prompt_hidden_states = hidden_states[offset : offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states)
+            prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
+            prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
+            if logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            elif logprobs_mode == "raw_logits":
+                raw_logprobs = logits
+            else:
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
+            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)
+
+        for req in completed_prefill_reqs:
+            del self.prompt_logprobs_reqs[req.request_id]
+            del self.in_progress_prompt_logprobs[req.request_id]
+        return prompt_logprobs_list
+
     def exist_prefill(self):
         """
         check whether prefill stage exist
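
Aside on the chunk bookkeeping in _get_prompt_logprobs_list above: token i of the prompt is scored by the logits at position i - 1, so a prompt of N tokens fills N - 1 rows of the preallocated LogprobsTensors buffer, accumulated chunk by chunk across prefill steps. A small illustrative helper (not FastDeploy API; the chunk sizes are hypothetical) showing only that index arithmetic:

    def prompt_logprob_slices(num_prompt_tokens: int, chunk_sizes: list):
        # Yield (destination_slice, is_last_chunk) per prefill chunk, mirroring the
        # start_idx / start_tok / num_remaining_tokens arithmetic in the diff.
        start_idx = 0  # request.prefill_start_index
        for num_tokens in chunk_sizes:
            start_tok = start_idx + 1
            num_remaining = num_prompt_tokens - start_tok
            num_logits = min(num_tokens, num_remaining)
            is_last = num_tokens > num_remaining
            if num_logits > 0:
                yield slice(start_idx, start_idx + num_logits), is_last
            start_idx += num_tokens


    # A 10-token prompt prefilled in chunks of 4, 4, 2 fills rows [0:4), [4:8), [8:9).
    for dest, last in prompt_logprob_slices(10, [4, 4, 2]):
        print(dest, "last chunk" if last else "")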
@@ -405,6 +508,13 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
            self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
            self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
 
+            if (
+                hasattr(request, "sampling_params")
+                and request.sampling_params is not None
+                and request.sampling_params.prompt_logprobs is not None
+            ):
+                self.prompt_logprobs_reqs[request.request_id] = request
+
            if len(request.output_token_ids) == 0:
                input_ids = request.prompt_token_ids
            else:
@@ -1296,6 +1406,10 @@ class at the server level, which is too granular for ModelRunner.
        # 5. Speculative decode
 
        # 6. Post Process
+        prompt_logprobs_list = None
+        if not self.speculative_decoding:
+            prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
+
        model_output_data = ModelOutputData(
            next_tokens=self.share_inputs["next_tokens"],
            stop_flags=self.share_inputs["stop_flags"],
@@ -1323,6 +1437,7 @@ class at the server level, which is too granular for ModelRunner.
            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
+            prompt_logprobs_list=prompt_logprobs_list,
        )
        if self.speculative_decoding:
            # base model post process
@@ -1334,6 +1449,7 @@ class at the server level, which is too granular for ModelRunner.
                share_inputs=self.share_inputs,
                block_size=self.cache_config.block_size,
                skip_save_output=is_dummy_run,
+                async_output_queue=self.async_output_queue,
                think_end_id=self.model_config.think_end_id,
                line_break_id=self.model_config.line_break_id,
            )
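
Taken together with the first file, the flow is: xpu_post_process_normal builds a list of StreamTransferData and puts it on async_output_queue; the WorkerAsyncOutputCopy thread started in __init__ drains that queue and pushes each batch over the zmq.PUSH socket named get_save_output_rank{local_rank}. The receiving side is not part of this diff; purely as an illustration, a generic pyzmq PULL consumer could look like the sketch below. The ipc path is an assumption derived from the client name (FastDeploy's real receiver goes through its own inter_communicator server class), and recv_pyobj pairs with the send_pyobj call in _async_output_busy_loop.

    import zmq


    def consume_worker_outputs(local_rank: int = 0) -> None:
        # Illustrative only: bind a PULL socket opposite the worker's PUSH socket.
        # The real endpoint naming is decided by FastDeploy's ZmqIpcClient/server;
        # this ipc path is an assumption for the sketch.
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PULL)
        sock.bind(f"ipc:///tmp/get_save_output_rank{local_rank}.ipc")
        while True:
            batch = sock.recv_pyobj()  # list of StreamTransferData built each step
            for item in batch:
                print(item.batch_id, item.tokens)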
