Commit 208cabb

use fused kernel
1 parent 273fdfe commit 208cabb

1 file changed: examples/microgpt_colab.ipynb (37 additions & 98 deletions)

@@ -337,115 +337,56 @@
 "    gpu_linear_kernel[(out_f,)](W_np.flatten().copy(), x.copy(), y, in_f)\n",
 "    return y\n",
 "\n",
-"# --- rmsnorm (4 kernel launches) --------------------------------------------\n",
+"# --- fused rmsnorm (1 kernel launch) ----------------------------------------\n",
 "\n",
 "@tt.jit\n",
-"def rn_square(src, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    tt.store(dst + off, x * x, mask=mask)\n",
-"\n",
-"@tt.jit\n",
-"def rn_reduce_sum(src, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    total = tt.reduce_sum(x)\n",
-"    tt.store(dst + pid, total)\n",
-"\n",
-"@tt.jit\n",
-"def rn_rsqrt_mean(sum_ptr, n_ptr, out_ptr):\n",
-"    tid = tt.arange(0, 64)\n",
-"    s = tt.load(sum_ptr)\n",
-"    n = tt.load(n_ptr)\n",
-"    mean_eps = s / n + 1e-5\n",
-"    scale = tt.rsqrt(mean_eps)\n",
-"    tt.store(out_ptr, scale)\n",
-"\n",
-"@tt.jit\n",
-"def rn_mul_scalar(src, scalar_ptr, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    s = tt.load(scalar_ptr)\n",
-"    tt.store(dst + off, x * s, mask=mask)\n",
+"def fused_rmsnorm_kernel(src, dst, N, n_ptr):\n",
+"    tid = tt.arange(0, 64)\n",
+"    mask = tid < N\n",
+"    x = tt.load(src + tid, mask=mask)\n",
+"    sq = x * x\n",
+"    s = tt.reduce_sum(sq)\n",
+"    n = tt.load(n_ptr)\n",
+"    scale = tt.rsqrt(s / n + 1e-5)\n",
+"    tt.store(dst + tid, x * scale, mask=mask)\n",
 "\n",
 "def gpu_rmsnorm(x, N):\n",
-"    grid = (max(1, (N + 63) // 64),)\n",
-"    tmp_sq = np.zeros(N, dtype=np.float32)\n",
-"    tmp_sum = np.zeros(1, dtype=np.float32)\n",
-"    tmp_scl = np.zeros(1, dtype=np.float32)\n",
 "    n_arr = np.array([float(N)], dtype=np.float32)\n",
 "    out = np.zeros(N, dtype=np.float32)\n",
-"    rn_square[grid](x, tmp_sq, N)\n",
-"    rn_reduce_sum[(1,)](tmp_sq, tmp_sum, N)\n",
-"    rn_rsqrt_mean[(1,)](tmp_sum, n_arr, tmp_scl)\n",
-"    rn_mul_scalar[grid](x, tmp_scl, out, N)\n",
+"    fused_rmsnorm_kernel[(1,)](x, out, N, n_arr)\n",
 "    return out\n",
 "\n",
-"# --- softmax (5 kernel launches) --------------------------------------------\n",
+"# --- fused softmax (1 kernel launch) ----------------------------------------\n",
 "\n",
 "@tt.jit\n",
-"def sm_reduce_max(src, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    mx = tt.reduce_max(x)\n",
-"    tt.store(dst + pid, mx)\n",
-"\n",
-"@tt.jit\n",
-"def sm_sub_scalar(src, scalar_ptr, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    s = tt.load(scalar_ptr)\n",
-"    tt.store(dst + off, x - s, mask=mask)\n",
-"\n",
-"@tt.jit\n",
-"def sm_exp(src, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    tt.store(dst + off, tt.exp(x), mask=mask)\n",
-"\n",
-"@tt.jit\n",
-"def sm_reduce_sum(src, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    total = tt.reduce_sum(x)\n",
-"    tt.store(dst + pid, total)\n",
-"\n",
-"@tt.jit\n",
-"def sm_div_scalar(src, scalar_ptr, dst, N):\n",
-"    pid = tt.program_id(0)\n",
-"    off = pid * 64 + tt.arange(0, 64)\n",
-"    mask = off < N\n",
-"    x = tt.load(src + off, mask=mask)\n",
-"    s = tt.load(scalar_ptr)\n",
-"    tt.store(dst + off, x / s, mask=mask)\n",
+"def fused_softmax_kernel(src, dst, N):\n",
+"    tid = tt.arange(0, 64)\n",
+"    mask = tid < N\n",
+"    x = tt.load(src + tid, mask=mask, other=-float('inf'))\n",
+"    mx = tt.reduce_max(x)\n",
+"    e = tt.exp(x - mx)\n",
+"    s = tt.reduce_sum(e)\n",
+"    tt.store(dst + tid, e / s, mask=mask)\n",
 "\n",
 "def gpu_softmax(x, N):\n",
-"    grid = (max(1, (N + 63) // 64),)\n",
-"    tmp_max = np.zeros(1, dtype=np.float32)\n",
-"    tmp_exp = np.zeros(N, dtype=np.float32)\n",
-"    tmp_sum = np.zeros(1, dtype=np.float32)\n",
 "    out = np.zeros(N, dtype=np.float32)\n",
-"    sm_reduce_max[(1,)](x, tmp_max, N)\n",
-"    sm_sub_scalar[grid](x, tmp_max, tmp_exp, N)\n",
-"    sm_exp[grid](tmp_exp, tmp_exp, N)\n",
-"    sm_reduce_sum[(1,)](tmp_exp, tmp_sum, N)\n",
-"    sm_div_scalar[grid](tmp_exp, tmp_sum, out, N)\n",
+"    fused_softmax_kernel[(1,)](x, out, N)\n",
 "    return out\n",
 "\n",
+"# --- fused scaled softmax (score/sqrt_d + softmax, 1 kernel launch) ---------\n",
+"\n",
+"@tt.jit\n",
+"def fused_scaled_softmax_kernel(src, dst, N, sqrt_d_ptr):\n",
+"    tid = tt.arange(0, 64)\n",
+"    mask = tid < N\n",
+"    x = tt.load(src + tid, mask=mask, other=-float('inf'))\n",
+"    sd = tt.load(sqrt_d_ptr)\n",
+"    x = x / sd\n",
+"    mx = tt.reduce_max(x)\n",
+"    e = tt.exp(x - mx)\n",
+"    s = tt.reduce_sum(e)\n",
+"    tt.store(dst + tid, e / s, mask=mask)\n",
+"\n",
 "# --- relu -------------------------------------------------------------------\n",
 "\n",
 "@tt.jit\n",
@@ -512,10 +453,8 @@
 "    gpu_linear_kernel[(seq_len,)](np.ascontiguousarray(K_h).flatten().copy(), q_h, scores, head_dim)\n",
 "\n",
 "    sqrt_d = np.array([np.sqrt(float(head_dim))], dtype=np.float32)\n",
-"    scores_scaled = np.zeros(seq_len, dtype=np.float32)\n",
-"    sm_div_scalar[(1,)](scores, sqrt_d, scores_scaled, seq_len)\n",
-"\n",
-"    attn_weights = gpu_softmax(scores_scaled, seq_len)\n",
+"    attn_weights = np.zeros(seq_len, dtype=np.float32)\n",
+"    fused_scaled_softmax_kernel[(1,)](scores, attn_weights, seq_len, sqrt_d)\n",
 "\n",
 "    V_h_T = np.ascontiguousarray(V_h.T)\n",
 "    head_out = np.zeros(head_dim, dtype=np.float32)\n",

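Note on the fused rmsnorm: the single kernel folds the old square, reduce-sum, rsqrt-mean, and scale steps into one launch, which works because the whole vector is handled by a single 64-lane block (grid (1,), offsets from tt.arange(0, 64)), so it assumes N <= 64. A minimal NumPy sketch of the computation the kernel should match; ref_rmsnorm is an illustrative name, not part of the notebook, and the check assumes the notebook's gpu_rmsnorm is in scope:

import numpy as np

def ref_rmsnorm(x):
    # Matches fused_rmsnorm_kernel: x * rsqrt(mean(x^2) + 1e-5).
    x = np.asarray(x, dtype=np.float32)
    return x / np.sqrt(np.mean(x * x) + np.float32(1e-5))

x = np.random.randn(48).astype(np.float32)  # N <= 64: fits the single block
# gpu_rmsnorm comes from the notebook cell shown in the diff above.
assert np.allclose(ref_rmsnorm(x), gpu_rmsnorm(x, 48), atol=1e-5)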
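
Note on the fused softmax kernels: both use the standard max-subtraction trick, and the other=-float('inf') fill on the masked load matters twice: masked lanes cannot perturb reduce_max, and they contribute exp(-inf) = 0 to the row sum rather than exp(0) = 1 (the old multi-kernel path loaded masked lanes without an explicit fill). A minimal NumPy sketch of what fused_scaled_softmax_kernel computes; ref_scaled_softmax is an illustrative name:

import numpy as np

def ref_scaled_softmax(scores, sqrt_d):
    # Matches fused_scaled_softmax_kernel: softmax(scores / sqrt_d).
    x = np.asarray(scores, dtype=np.float32) / np.float32(sqrt_d)
    e = np.exp(x - np.max(x))  # max-subtraction keeps exp() finite
    return e / e.sum()

w = ref_scaled_softmax(np.random.randn(10).astype(np.float32), np.sqrt(16.0))
assert abs(w.sum() - 1.0) < 1e-6  # attention weights form a distribution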
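
Launch-count arithmetic for the attention hot path: the old code ran sm_div_scalar once plus the five kernels inside gpu_softmax, six launches per call, while the new path is a single fused_scaled_softmax_kernel launch. At these tiny sizes launch overhead dominates, which is the point of the fusion. A hedged micro-benchmark sketch, assuming the notebook's kernel and launch syntax are in scope; numbers are illustrative:

import time
import numpy as np

def avg_time(fn, iters=200):
    # Average wall-clock seconds per call, launch overhead included.
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

scores = np.random.randn(64).astype(np.float32)
out = np.zeros(64, dtype=np.float32)
sqrt_d = np.array([8.0], dtype=np.float32)
t = avg_time(lambda: fused_scaled_softmax_kernel[(1,)](scores, out, 64, sqrt_d))
print(f"fused scaled softmax: {t * 1e6:.1f} us/launch")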