31 changes: 20 additions & 11 deletions vllm/worker/hpu_model_runner.py
@@ -3658,7 +3658,24 @@ def _patch_prev_output(self):
             return
         model_input = self.cached_step_inputs.pop(0)
         model_output = self.cached_step_outputs.pop(0)
-        delayed_tokens = model_output.token_ids.cpu().squeeze(-1).tolist()
+
+        assert model_output.sampling_metadata is not None, \
+            'Sampling metadata is required to patch the output!'
+        seq_groups = model_output.sampling_metadata.seq_groups
+        logprobs_required = any(seq_group.sampling_params.logprobs is not None
+                                for seq_group in seq_groups)
+        prompt_logprobs_required = any(
+            seq_group.sampling_params.prompt_logprobs is not None
+            for seq_group in seq_groups)
+
+        if model_output.is_prompt and prompt_logprobs_required:
+            sample_idx_tensor = torch.tensor(
+                [sdx for sg in seq_groups for sdx in sg.sample_indices])
+
+            sampled_tokens = model_output.token_ids[sample_idx_tensor, :]
+            delayed_tokens = sampled_tokens.cpu().squeeze(-1).tolist()
+        else:
+            delayed_tokens = model_output.token_ids.cpu().squeeze(-1).tolist()

         ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
         # If there's no output to patch with, which is usually the case when
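Aside on the new gather: when a prompt step also computes prompt logprobs, `model_output.token_ids` evidently carries one row per logits row (prompt positions plus the sampled position), so only the rows named by each group's `sample_indices` point at real samples. A minimal sketch of that selection, with made-up shapes and indices standing in for the real vLLM structures:

```python
import torch

# Hypothetical batch: two prompt sequences (3 and 2 prompt rows), each
# followed by one sampled token, flattened into a single [N, 1] tensor.
token_ids = torch.tensor([[11], [12], [13], [99],   # seq 0, sample at row 3
                          [21], [22], [98]])        # seq 1, sample at row 6
sample_indices_per_group = [[3], [6]]  # assumed sample_indices per seq_group

# Same flatten-and-index pattern as the diff.
sample_idx_tensor = torch.tensor(
    [sdx for sg in sample_indices_per_group for sdx in sg])
sampled_tokens = token_ids[sample_idx_tensor, :]
delayed_tokens = sampled_tokens.cpu().squeeze(-1).tolist()
print(delayed_tokens)  # [99, 98]
```

Without the gather, the old single-line path would flatten all seven rows into `delayed_tokens` instead of just the two sampled tokens.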
@@ -3682,21 +3699,13 @@ def _patch_prev_output(self):

         delayed_logprobs = None
         delayed_prompt_logprobs = None
-        assert model_output.sampling_metadata is not None, \
-            'Sampling metadata is required to patch the output!'
-        logprobs_required = any(
-            seq_group.sampling_params.logprobs is not None
-            for seq_group in model_output.sampling_metadata.seq_groups)
-        prompt_logprobs_required = any(
-            seq_group.sampling_params.prompt_logprobs is not None
-            for seq_group in model_output.sampling_metadata.seq_groups)
         if logprobs_required or prompt_logprobs_required:
             # We are one step ahead, so prompt is already marked as a computed.
             # We need to reset the computed tokens count to 0,
             # so that we can recompute the prompt logprobs.
             computed_tokens = []
             if model_output.is_prompt:
-                for seq_group in model_output.sampling_metadata.seq_groups:
+                for seq_group in seq_groups:
                     seq_ids = seq_group.seq_ids
                     assert len(seq_ids) == 1  # prompt has only 1 seq id.
                     seq_data = seq_group.seq_data[seq_ids[0]]
@@ -3710,7 +3719,7 @@ def _patch_prev_output(self):

         # Reset the computed tokens count to the original value.
         if model_output.is_prompt:
-            for seq_group in model_output.sampling_metadata.seq_groups:
+            for seq_group in seq_groups:
                 seq_ids = seq_group.seq_ids
                 seq_data = seq_group.seq_data[seq_ids[0]]
                 seq_data.update_num_computed_tokens(computed_tokens.pop(0))
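The two hunks above share a save/reset/restore pattern: the runner is one step ahead, so the prompt is already counted as computed and has to be rewound before prompt logprobs can be recomputed, then restored afterwards. A minimal sketch of that bookkeeping, using a stub in place of vLLM's sequence data (`get_num_computed_tokens` and `update_num_computed_tokens` mirror calls visible in the diff; `reset_num_computed_tokens` is an assumed name for the reset step hidden in the collapsed lines):

```python
class SeqDataStub:
    """Stand-in for vLLM's per-sequence data; not the real class."""

    def __init__(self, num_computed: int):
        self._num_computed = num_computed

    def get_num_computed_tokens(self) -> int:
        return self._num_computed

    def reset_num_computed_tokens(self) -> None:
        self._num_computed = 0

    def update_num_computed_tokens(self, delta: int) -> None:
        self._num_computed += delta


seq_data = SeqDataStub(num_computed=128)  # prompt already marked computed

computed_tokens = [seq_data.get_num_computed_tokens()]  # save original count
seq_data.reset_num_computed_tokens()  # rewind so prompt logprobs recompute
# ... prompt logprobs are recomputed here ...
seq_data.update_num_computed_tokens(computed_tokens.pop(0))  # restore
assert seq_data.get_num_computed_tokens() == 128
```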