1515"""
1616
1717import os
18+ import queue
1819import random
1920import time
21+ from threading import Thread
2022from typing import List , Optional
2123
2224import numpy as np
2325import paddle
26+ import zmq
2427from paddle import nn
2528
2629from fastdeploy import envs
2730from fastdeploy .config import FDConfig
2831from fastdeploy .engine .request import Request , RequestType
2932from fastdeploy .input .ernie4_5_vl_processor import DataProcessor
30- from fastdeploy .inter_communicator import IPCSignal
33+ from fastdeploy .inter_communicator import IPCSignal , ZmqIpcClient
3134from fastdeploy .model_executor .forward_meta import ForwardMeta
3235from fastdeploy .model_executor .graph_optimization .utils import (
3336 profile_run_guard ,
5962from fastdeploy .spec_decode import MTPProposer
6063from fastdeploy .utils import get_logger
6164from fastdeploy .worker .model_runner_base import ModelRunnerBase
62- from fastdeploy .worker .output import ModelOutputData , ModelRunnerOutput
65+ from fastdeploy .worker .output import LogprobsTensors , ModelOutputData , ModelRunnerOutput
6366
6467logger = get_logger ("xpu_model_runner" , "xpu_model_runner.log" )
6568
@@ -156,6 +159,106 @@ def __init__(

        self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode

+        # Initialize ZMQ client for async output
+        self.zmq_client = None
+        self.async_output_queue = None
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            logger.info(f"zmq client get_save_output_rank{local_rank}")
+            self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
+            self.zmq_client.connect()
+            self.zmq_client.socket.SNDTIMEO = 3000
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self._async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy",
+            )
+            self.async_output_copy_thread.start()
+        # prompt logprobs state
+        self.prompt_logprobs_reqs: dict[str, Request] = {}
+        self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
+
+    def _async_output_busy_loop(self):
+        """Entrypoint for the thread which handles outputs asynchronously."""
+        while True:
+            try:
+                if self.async_output_queue is None or self.zmq_client is None:
+                    break
+                output = self.async_output_queue.get()
+                if self.zmq_client is not None:
+                    self.zmq_client.send_pyobj(output)
+            except Exception as e:
+                logger.exception("Exception in async output loop: %s", e)
+    def _get_prompt_logprobs_list(self, hidden_states: paddle.Tensor) -> list[Optional[LogprobsTensors]]:
+        """
+        Build prompt_logprobs for requests that asked for it.
+        """
+        if len(self.prompt_logprobs_reqs) > 0:
+            assert (
+                not self.fd_config.cache_config.enable_prefix_caching
+            ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching."
+
+        if len(self.prompt_logprobs_reqs) == 0:
+            return self.scheduler_config.max_num_seqs * [None]
+
+        logprobs_mode = self.fd_config.model_config.logprobs_mode
+        prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
+        completed_prefill_reqs: list[Request] = []
+
+        for req_id, request in self.prompt_logprobs_reqs.items():
+            if not hasattr(request, "sampling_params") or request.sampling_params is None:
+                continue
+            num_prompt_logprobs = request.sampling_params.prompt_logprobs
+            if request.prompt_token_ids is None or num_prompt_logprobs is None:
+                continue
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.ori_vocab_size
+
+            num_tokens = request.prefill_end_index - request.prefill_start_index
+            num_prompt_tokens = len(request.prompt_token_ids)
+
+            logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
+            if not logprobs_tensors:
+                logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
+
+            start_idx = request.prefill_start_index
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                num_logits = num_tokens
+            else:
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(request)
+                prompt_logprobs_list[request.idx] = logprobs_tensors
+            if num_logits <= 0:
+                continue
+
+            offset = self.share_inputs["cu_seqlens_q"][request.idx]
+            prompt_hidden_states = hidden_states[offset : offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states)
+            prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
+            prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
+            if logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            elif logprobs_mode == "raw_logits":
+                raw_logprobs = logits
+            else:
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
+            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)
+
+        for req in completed_prefill_reqs:
+            del self.prompt_logprobs_reqs[req.request_id]
+            del self.in_progress_prompt_logprobs[req.request_id]
+        return prompt_logprobs_list
+
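
For each prompt chunk, `gather_logprobs` is expected to return, per position, the actual prompt token plus the top `num_prompt_logprobs` candidates, their logprobs, and the rank of the actual token. A rough sketch of that gather step in plain Paddle, assuming this output shape (the real sampler implementation may differ):

    import paddle
    import paddle.nn.functional as F

    def gather_logprobs_sketch(logits, num_logprobs, token_ids):
        # logits: [num_tokens, vocab_size]; token_ids: [num_tokens], int64
        logprobs = F.log_softmax(logits.astype("float32"), axis=-1)
        topk_logprobs, topk_ids = paddle.topk(logprobs, num_logprobs, axis=-1)
        tok = token_ids.unsqueeze(-1)
        tok_logprobs = paddle.take_along_axis(logprobs, tok, axis=-1)
        # rank = number of vocabulary entries with a strictly higher logprob
        ranks = (logprobs > tok_logprobs).astype("int64").sum(axis=-1)
        return (
            paddle.concat([tok, topk_ids], axis=-1),                # logprob_token_ids
            paddle.concat([tok_logprobs, topk_logprobs], axis=-1),  # logprobs
            ranks,                                                  # selected_token_ranks
        )
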
    def exist_prefill(self):
        """
        check whether prefill stage exist
@@ -405,6 +508,13 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
            self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
            self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0

+            if (
+                hasattr(request, "sampling_params")
+                and request.sampling_params is not None
+                and request.sampling_params.prompt_logprobs is not None
+            ):
+                self.prompt_logprobs_reqs[request.request_id] = request
+
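
Requests opt in through their sampling parameters. As a hypothetical usage example (assuming FastDeploy's SamplingParams exposes the field exactly as this hook reads it):

    # prompt_logprobs=5  -> return the top-5 candidates per prompt position
    # prompt_logprobs=-1 -> return the full vocabulary (expanded to ori_vocab_size
    #                       in _get_prompt_logprobs_list above)
    sampling_params = SamplingParams(prompt_logprobs=5)
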
408518 if len (request .output_token_ids ) == 0 :
409519 input_ids = request .prompt_token_ids
410520 else :
@@ -1296,6 +1406,10 @@ class at the server level, which is too granular for ModelRunner.
        # 5. Speculative decode

        # 6. Post Process
+        prompt_logprobs_list = None
+        if not self.speculative_decoding:
+            prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
+
        model_output_data = ModelOutputData(
            next_tokens=self.share_inputs["next_tokens"],
            stop_flags=self.share_inputs["stop_flags"],
@@ -1323,6 +1437,7 @@ class at the server level, which is too granular for ModelRunner.
            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
+            prompt_logprobs_list=prompt_logprobs_list,
        )
        if self.speculative_decoding:
            # base model post process
@@ -1334,6 +1449,7 @@ class at the server level, which is too granular for ModelRunner.
            share_inputs=self.share_inputs,
            block_size=self.cache_config.block_size,
            skip_save_output=is_dummy_run,
+            async_output_queue=self.async_output_queue,
            think_end_id=self.model_config.think_end_id,
            line_break_id=self.model_config.line_break_id,
        )
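
With the queue threaded into post-processing, the runner hands results off without blocking on the ZMQ send: the post-process step enqueues the output and the WorkerAsyncOutputCopy thread pushes it out. A condensed sketch of that hand-off, assuming xpu_post_process simply enqueues when a queue is provided:

    # post-process side (runner thread): non-blocking hand-off
    if async_output_queue is not None:
        async_output_queue.put(model_runner_output)

    # WorkerAsyncOutputCopy thread (_async_output_busy_loop above)
    output = async_output_queue.get()  # blocks until an output is available
    zmq_client.send_pyobj(output)      # PUSH to get_save_output_rank{local_rank}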