Commit df19910

MichelleWu351 authored and tonyluj committed
[Ascend] fix AscendAttnMaskBuilder bug to support float16 models (sgl-project#14271)
1 parent 0f0edd9 commit df19910

File tree

1 file changed

+5
-3
lines changed


python/sglang/srt/layers/attention/ascend_backend.py

Lines changed: 5 additions & 3 deletions
@@ -105,9 +105,11 @@ def generate_attn_mask(max_seq_len, mode, dtype=torch.float16):
             )
         else:
             mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
-            attn_mask = torch.zeros(
-                size=(max_seq_len, max_seq_len), dtype=dtype
-            ).masked_fill_(mask_flag, mask_value)
+            attn_mask = (
+                torch.zeros(size=(max_seq_len, max_seq_len))
+                .masked_fill_(mask_flag, mask_value)
+                .to(dtype)
+            )
         return attn_mask
 
     @staticmethod
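The likely reason for the change: the old code allocated the mask directly in `dtype` and then filled it with `torch.finfo(torch.float32).min`, a value far outside float16's representable range (roughly ±65504), which fails or overflows for float16 models. The patch fills in the default float32 tensor first and only casts to `dtype` at the end. As a stdlib-only sketch (not using torch; `fits_in_float16` is a hypothetical helper), the range mismatch can be shown with `struct`'s half-precision `'e'` format:

```python
import struct

# torch.finfo(torch.float32).min, written out as a literal
FLOAT32_MIN = -3.4028234663852886e38


def fits_in_float16(value: float) -> bool:
    """Return True if value can be packed as IEEE-754 binary16 (half)."""
    try:
        struct.pack("e", value)  # 'e' = half-precision float
        return True
    except OverflowError:
        return False


# float32's most negative finite value overflows the float16 range
# (max finite half is 65504.0) -- hence filling a float16 tensor with
# it is the failure mode the commit works around.
print(fits_in_float16(FLOAT32_MIN))  # False
print(fits_in_float16(-65504.0))     # True
```

After the fix, the overflow during the cast to float16 simply saturates the fill value, so the mask still behaves as a large negative bias for masked positions.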

0 commit comments