Skip to content

Commit 2c9788f

Browse files
authored
fix typo in flops calculation for local attention (Dao-AILab#1883)
1 parent 1a8d8a4 commit 2c9788f

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

benchmarks/benchmark_attn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w
70 70
else:
71 71
row_idx = torch.arange(seqlen_q, device='cuda')
72 72
col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx)
73-
col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1)
73+
col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1)
74 74
avg_seqlen = (col_right - col_left + 1).float().mean().item()
75 75
return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
76 76

hopper/benchmark_attn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w
68 68
else:
69 69
row_idx = torch.arange(seqlen_q, device='cuda')
70 70
col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
71-
col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1))
71+
col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1))
72 72
avg_seqlen = (col_right - col_left + 1).float().mean().item()
73 73
return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
74 74

0 commit comments

Comments
 (0)