Hi all,
Thank you for your great work.
Do you have any recommendations on how to handle multi-query attention (multiple query heads sharing a single key/value head), as used in PaLM?
For example with the Triton Flash Attention Function:
"""
q: (batch_size, heads, seq_len_q, dim)
k, v: (batch_size, heads, seq_len_kv, dim)
"""
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 512, 64)
v = torch.randn(1, 512, 64)
out = flash_attn_func(q, k, v, causal = True) # (1, 8, 512, 64)
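One workaround I have been considering (just a sketch, assuming the kernel requires per-head K/V tensors) is to broadcast the shared key/value head across the query heads with `expand`, which creates a view rather than copying memory. Below I use plain PyTorch's `scaled_dot_product_attention` as a stand-in to check that the expanded shapes line up:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, dim = 1, 8, 512, 64

q = torch.randn(batch, heads, seq_len, dim)
k = torch.randn(batch, seq_len, dim)  # single shared key head
v = torch.randn(batch, seq_len, dim)  # single shared value head

# Broadcast the single K/V head across all query heads.
# expand() returns a view, so no extra memory is allocated here.
k_exp = k.unsqueeze(1).expand(batch, heads, seq_len, dim)
v_exp = v.unsqueeze(1).expand(batch, heads, seq_len, dim)

# Stand-in for the fused kernel, to verify the shapes work out.
out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 512, 64])
```

I am not sure whether passing an expanded (non-contiguous) view into the Triton kernel is safe, or whether it would need a `.contiguous()` copy first, which would defeat the memory savings of multi-query attention.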
I greatly appreciate any input you could provide.
Thank you,
Enrico