@@ -230,6 +230,183 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
   }
 };
 
+struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case {
+  int b;
+  int s_attn;
+  int s_v;
+  int h;
+  int d;
+  size_t b_attn_stride;
+  size_t h_attn_stride;
+  size_t s_attn_stride;
+  size_t b_v_stride;
+  size_t h_v_stride;
+  size_t s_v_stride;
+  size_t b_v_qparams_stride;
+  size_t h_v_qparams_stride;
+  size_t s_v_qparams_stride;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> attn_scores;
+
+  std::vector<float> v;
+  std::vector<int8_t> v_qvals;
+  std::vector<float> v_scales;
+  std::vector<int8_t> v_zeros;
+
+  fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+      int b_,
+      int s_attn_,
+      int s_v_,
+      int h_,
+      int d_,
+      size_t b_attn_stride_,
+      size_t h_attn_stride_,
+      size_t s_attn_stride_,
+      size_t b_v_stride_,
+      size_t h_v_stride_,
+      size_t s_v_stride_,
+      size_t b_v_qparams_stride_,
+      size_t h_v_qparams_stride_,
+      size_t s_v_qparams_stride_,
+      std::vector<float> expected_output_,
+      std::vector<float> attn_scores_,
+      std::vector<float> v_,
+      std::vector<int8_t> v_qvals_,
+      std::vector<float> v_scales_,
+      std::vector<int8_t> v_zeros_)
+      : b(b_),
+        s_attn(s_attn_),
+        s_v(s_v_),
+        h(h_),
+        d(d_),
+        b_attn_stride(b_attn_stride_),
+        h_attn_stride(h_attn_stride_),
+        s_attn_stride(s_attn_stride_),
+        b_v_stride(b_v_stride_),
+        h_v_stride(h_v_stride_),
+        s_v_stride(s_v_stride_),
+        b_v_qparams_stride(b_v_qparams_stride_),
+        h_v_qparams_stride(h_v_qparams_stride_),
+        s_v_qparams_stride(s_v_qparams_stride_),
+        expected_output(expected_output_),
+        attn_scores(attn_scores_),
+        v(v_),
+        v_qvals(v_qvals_),
+        v_scales(v_scales_),
+        v_zeros(v_zeros_) {
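+    // Shapes: attn_scores is [B, H, S_attn, S_v]; v and v_qvals are
+    // [B, H, S_v, D] (or [B, S_v, H, D] when v is not transposed); scales and
+    // zeros carry one entry per [B, H, S_v] row; output is [B, S_attn, H, D].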
+    assert(expected_output.size() == b * s_attn * h * d);
+    assert(attn_scores.size() == b * h * s_attn * s_v);
+    assert(v.size() == b * h * s_v * d);
+    assert(v_qvals.size() == b * h * s_v * d);
+    assert(v_scales.size() == b * h * s_v);
+    assert(v_zeros.size() == b * h * s_v);
+  }
+
+  static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case
+  generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) {
+    // Generate activations
+    auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        torchao::test_utils::generate_per_token_quantized_tensor(
+            b * h * s_v, d);
+    // The function above produces an n x k matrix; to produce k x n you need
+    // transposed = true. We use !rhs_is_transposed because when
+    // rhs_is_transposed = true the shape should be n x k instead of k x n.
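+    // Quantization is per token (per row): each of the b * h * s_v rows of V
+    // gets one (scale, zero) pair, and dequantizes as scale * (qval - zero).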
+
+    size_t b_attn_stride = h * s_attn * s_v;
+    size_t h_attn_stride = s_attn * s_v;
+    size_t s_attn_stride = s_v;
+
+    size_t b_v_stride = h * s_v * d;
+    size_t h_v_stride = s_v * d;
+    size_t s_v_stride = d;
+
+    size_t b_v_qparams_stride = h * s_v;
+    size_t h_v_qparams_stride = s_v;
+    size_t s_v_qparams_stride = 1;
+
+    if (!transposed_v) {
+      h_v_stride = d;
+      s_v_stride = h * d;
+
+      s_v_qparams_stride = h;
+      h_v_qparams_stride = 1;
+    }
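+    // The defaults above describe V stored as [B, H, S_v, D]; this branch
+    // instead describes a [B, S_v, H, D] layout (likewise for the qparams).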
+
+    // Compute expected output.
+    // Note that while the inputs are in shape b x h x s_attn x s_v and
+    // b x h x s_v x d, the output is not in b x h x s_attn x d but rather
+    // b x s_attn x h x d. This is because the output of SDPA would normally
+    // be in b x h x s_attn x d, but we want to avoid any transposes, so we
+    // aim to output b x s_attn x h x d instead. This is just for testing
+    // purposes; the kernel can actually write output in [B, H, S, D] if
+    // needed.
+    std::vector<float> expected_output(b * s_attn * h * d);
+    size_t b_out_stride = s_attn * h * d;
+    size_t s_attn_out_stride = h * d;
+    size_t h_out_stride = d;
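+    // Reference: out[b][s][h][d] =
+    //     sum over s_v of attn[b][h][s][s_v] * dequant(v[b][h][s_v][d]).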
+
+    for (int b_idx = 0; b_idx < b; b_idx++) {
+      for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) {
+        for (int h_idx = 0; h_idx < h; h_idx++) {
+          for (int d_idx = 0; d_idx < d; d_idx++) {
+            float res = 0.0;
+            for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) {
+              int lhs_idx = b_idx * b_attn_stride +
+                  s_attn_idx * s_attn_stride + h_idx * h_attn_stride + s_v_idx;
+              int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx +
+                  s_v_idx * s_v_stride;
+              int rhs_scales_zp_idx = b_idx * b_v_qparams_stride +
+                  h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride;
+              float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
+                  (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);
+
+              res += lhs[lhs_idx] * rhs_dequant;
+            }
+            expected_output
+                [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride +
+                 h_idx * h_out_stride + d_idx] = res;
+          }
+        }
+      }
+    }
+
+    // Return test case
+    return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+        b,
+        s_attn,
+        s_v,
+        h,
+        d,
+        b_attn_stride,
+        h_attn_stride,
+        s_attn_stride,
+        b_v_stride,
+        h_v_stride,
+        s_v_stride,
+        b_v_qparams_stride,
+        h_v_qparams_stride,
+        s_v_qparams_stride,
+        expected_output,
+        lhs,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
 } // namespace torchao
 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
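
For reference, a minimal sketch of how a test might consume this helper (not part of the diff). The kernel invocation is left as a comment because the entry point and signature depend on the kernel under test, and `kTol` is an illustrative tolerance, not a value from the source.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Assumes the header containing the test case struct is included.
void attn_scores_at_v_smoke_test() {
  auto tc = torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::
      generate(/*b=*/1, /*s_attn=*/4, /*s_v=*/16, /*h=*/8, /*d=*/16);

  std::vector<float> out(tc.expected_output.size(), 0.0f);
  // Invoke the kernel under test here: it reads tc.attn_scores (fp32) and
  // tc.v_qvals / tc.v_scales / tc.v_zeros (channelwise 8-bit V), using the
  // strides carried by tc, and writes out in [B, S_attn, H, D] order.

  constexpr float kTol = 1e-4f;  // illustrative tolerance
  for (size_t i = 0; i < out.size(); ++i) {
    assert(std::abs(out[i] - tc.expected_output[i]) <= kTol);
  }
}
```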