|
337 | 337 | " gpu_linear_kernel[(out_f,)](W_np.flatten().copy(), x.copy(), y, in_f)\n", |
338 | 338 | " return y\n", |
339 | 339 | "\n", |
340 | | - "# --- rmsnorm (4 kernel launches) --------------------------------------------\n", |
| 340 | + "# --- fused rmsnorm (1 kernel launch) ----------------------------------------\n", |
341 | 341 | "\n", |
@tt.jit
def fused_rmsnorm_kernel(src, dst, N, n_ptr):
    """Fused RMSNorm in one launch: dst = src * rsqrt(mean(src^2) + 1e-5).

    Handles a single row of up to 64 elements (one program, one 64-lane
    tile — launched with grid (1,) by gpu_rmsnorm).

    src:   pointer to N input floats
    dst:   pointer to N output floats
    N:     element count (int, used for masking)
    n_ptr: pointer to a 1-element float buffer holding float(N)
    """
    tid = tt.arange(0, 64)
    mask = tid < N
    # other=0.0 is required: without it, masked-out lanes load undefined
    # values that get squared and folded into reduce_sum, corrupting the
    # mean for any N < 64. (The fused softmax kernel below uses `other=`
    # the same way for its max-reduction.)
    x = tt.load(src + tid, mask=mask, other=0.0)
    sq = x * x
    s = tt.reduce_sum(sq)
    # Element count is read as a float from a buffer so no int->float
    # conversion happens in-kernel.
    n = tt.load(n_ptr)
    # 1e-5 epsilon guards against rsqrt(0) for an all-zero input row.
    scale = tt.rsqrt(s / n + 1e-5)
    tt.store(dst + tid, x * scale, mask=mask)
376 | 352 | "\n", |
def gpu_rmsnorm(x, N):
    """RMS-normalize the first N elements of x via a single fused kernel launch.

    Returns a fresh float32 array of length N; x is not modified.
    """
    # The kernel reads the element count as a float from a 1-element buffer.
    count_buf = np.array([float(N)], dtype=np.float32)
    result = np.zeros(N, dtype=np.float32)
    fused_rmsnorm_kernel[(1,)](x, result, N, count_buf)
    return result
389 | 358 | "\n", |
390 | | - "# --- softmax (5 kernel launches) --------------------------------------------\n", |
| 359 | + "# --- fused softmax (1 kernel launch) ----------------------------------------\n", |
391 | 360 | "\n", |
@tt.jit
def fused_softmax_kernel(src, dst, N):
    """Numerically stable softmax over one row of up to 64 elements, fused
    into a single launch (launched with grid (1,) by gpu_softmax).

    src: pointer to N input floats
    dst: pointer to N output floats
    N:   element count (int, used for masking)
    """
    offs = tt.arange(0, 64)
    in_bounds = offs < N
    # Padding masked lanes with -inf keeps them out of the max and makes
    # exp() yield exactly 0 for them, so the sum is unaffected.
    vals = tt.load(src + offs, mask=in_bounds, other=-float('inf'))
    # Subtract the row max before exponentiating for numerical stability.
    shifted = vals - tt.reduce_max(vals)
    num = tt.exp(shifted)
    denom = tt.reduce_sum(num)
    tt.store(dst + offs, num / denom, mask=in_bounds)
435 | 370 | "\n", |
def gpu_softmax(x, N):
    """Softmax over the first N elements of x via one fused kernel launch.

    Returns a fresh float32 array of length N; x is not modified.
    """
    result = np.zeros(N, dtype=np.float32)
    fused_softmax_kernel[(1,)](x, result, N)
    return result
448 | 375 | "\n", |
| 376 | + "# --- fused scaled softmax (score/sqrt_d + softmax, 1 kernel launch) ---------\n", |
| 377 | + "\n", |
@tt.jit
def fused_scaled_softmax_kernel(src, dst, N, sqrt_d_ptr):
    """Attention-score softmax fused into one launch: divide the row by
    sqrt(d), then apply a numerically stable softmax.

    src:        pointer to N raw scores
    dst:        pointer to N output weights
    N:          element count (int, used for masking)
    sqrt_d_ptr: pointer to a 1-element float buffer holding sqrt(head_dim)
    """
    offs = tt.arange(0, 64)
    in_bounds = offs < N
    raw = tt.load(src + offs, mask=in_bounds, other=-float('inf'))
    # Scaling first is safe: -inf / sqrt(d) stays -inf for positive sqrt(d),
    # so masked lanes still vanish under exp().
    scaled = raw / tt.load(sqrt_d_ptr)
    num = tt.exp(scaled - tt.reduce_max(scaled))
    tt.store(dst + offs, num / tt.reduce_sum(num), mask=in_bounds)
| 389 | + "\n", |
449 | 390 | "# --- relu -------------------------------------------------------------------\n", |
450 | 391 | "\n", |
451 | 392 | "@tt.jit\n", |
|
512 | 453 | " gpu_linear_kernel[(seq_len,)](np.ascontiguousarray(K_h).flatten().copy(), q_h, scores, head_dim)\n", |
513 | 454 | "\n", |
514 | 455 | " sqrt_d = np.array([np.sqrt(float(head_dim))], dtype=np.float32)\n", |
515 | | - " scores_scaled = np.zeros(seq_len, dtype=np.float32)\n", |
516 | | - " sm_div_scalar[(1,)](scores, sqrt_d, scores_scaled, seq_len)\n", |
517 | | - "\n", |
518 | | - " attn_weights = gpu_softmax(scores_scaled, seq_len)\n", |
| 456 | + " attn_weights = np.zeros(seq_len, dtype=np.float32)\n", |
| 457 | + " fused_scaled_softmax_kernel[(1,)](scores, attn_weights, seq_len, sqrt_d)\n", |
519 | 458 | "\n", |
520 | 459 | " V_h_T = np.ascontiguousarray(V_h.T)\n", |
521 | 460 | " head_out = np.zeros(head_dim, dtype=np.float32)\n", |
|
0 commit comments