diff --git a/vllm/v1/worker/hpu_model_runner.py b/vllm/v1/worker/hpu_model_runner.py
index a6b0f534bfe0..4a6bcb961bb8 100644
--- a/vllm/v1/worker/hpu_model_runner.py
+++ b/vllm/v1/worker/hpu_model_runner.py
@@ -129,6 +129,10 @@ class DecodeInputData:
     logits_indices: Optional[torch.Tensor] = None
 
 
+class BucketingFailedException(Exception):
+    pass
+
+
 def bool_helper(value):
     value = value.lower()
     return value in ("y", "yes", "t", "true", "on", "1")
@@ -966,6 +970,8 @@ def _bucketize_merged_prompt(self, seq_lens, num_blocks):
 
     def _bucketize_2d_prompt(self, seq_lens, num_blocks):
         bs = len(seq_lens)
+        if bs > self.max_prefill_batch_size:
+            raise BucketingFailedException
         seq = max(seq_lens)
         num_blocks = max(num_blocks) if len(num_blocks) > 0 else 0
         bs, seq, num_blocks = self.bucketing_manager.find_prompt_bucket(
@@ -981,8 +987,11 @@ def _get_prompt_bucketing_fn(self):
     def _can_merge_prefill_contents(self, lhs, rhs):
         combined_num_tokens = lhs.get_num_tokens() + rhs.get_num_tokens()
         bucketing_fn = self._get_prompt_bucketing_fn()
-        target_bs, target_seq, target_blocks = bucketing_fn(
-            combined_num_tokens, [])
+        try:
+            target_bs, target_seq, target_blocks = bucketing_fn(
+                combined_num_tokens, [])
+        except BucketingFailedException:
+            return False
         return target_bs <= self.max_prefill_batch_size and\
             target_bs * target_seq <= self.max_num_tokens
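
The patch makes `_bucketize_2d_prompt` raise `BucketingFailedException` when the batch size exceeds `max_prefill_batch_size`, and `_can_merge_prefill_contents` catches that exception to reject the merge instead of proceeding with an invalid bucket. Below is a minimal, self-contained sketch of that control flow, assuming simplified stand-in names (`MAX_PREFILL_BATCH_SIZE`, `find_prompt_bucket`, `can_merge`) rather than the actual `hpu_model_runner` API.

```python
class BucketingFailedException(Exception):
    """Raised when no prompt bucket can accommodate the request."""


MAX_PREFILL_BATCH_SIZE = 4   # assumed limit, for illustration only
MAX_NUM_TOKENS = 8192        # assumed token budget, for illustration only


def find_prompt_bucket(seq_lens):
    """Toy stand-in for _bucketize_2d_prompt: bail out early
    instead of returning an out-of-range bucket."""
    bs = len(seq_lens)
    if bs > MAX_PREFILL_BATCH_SIZE:
        raise BucketingFailedException
    return bs, max(seq_lens)


def can_merge(lhs_lens, rhs_lens):
    """Toy stand-in for _can_merge_prefill_contents: a failed
    bucketing attempt simply means the two prefills cannot be merged."""
    try:
        bs, seq = find_prompt_bucket(lhs_lens + rhs_lens)
    except BucketingFailedException:
        return False
    return bs <= MAX_PREFILL_BATCH_SIZE and bs * seq <= MAX_NUM_TOKENS


if __name__ == "__main__":
    print(can_merge([128, 256], [512]))     # True: 3 sequences fit the limits
    print(can_merge([128] * 3, [128] * 3))  # False: batch size 6 exceeds 4
```

The design choice is that bucketing failure is now signalled with an exception rather than an out-of-range return value, so the existing size checks in the caller stay unchanged and the merge path degrades to "do not merge".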