2424# Helpers
2525# -----------------------------------------------------------------------------
2626
27+
def rand_tensor(shape: tuple[int, ...], *, scale: float = 1.0) -> torch.Tensor:
    """Random bf16 tensor of the given *shape* and *scale*.

    Samples from N(0, 1) in float32 on the module-level *device*, multiplies
    by *scale*, then casts to the module-level *dtype* (bf16).  Sampling in
    float32 first keeps the scaling exact before the lossy bf16 cast.

    NOTE: annotation modernized from ``typing.Tuple`` to the builtin
    ``tuple[int, ...]`` for consistency with the PEP 585/604 syntax used
    elsewhere in this file (``list[...]``, ``float | None``).
    """
    return (torch.randn(shape, dtype=torch.float32, device=device) * scale).to(dtype)
@@ -34,7 +35,13 @@ def to_bf16(t: torch.Tensor) -> torch.Tensor:
3435 return t .to (dtype )
3536
3637
37- def assert_close (a : torch .Tensor , b : torch .Tensor , * , rtol : float | None = None , atol : float | None = None ) -> None :
38+ def assert_close (
39+ a : torch .Tensor ,
40+ b : torch .Tensor ,
41+ * ,
42+ rtol : float | None = None ,
43+ atol : float | None = None ,
44+ ) -> None :
3845 """Wrapper around *torch.allclose* with nicer error messages."""
3946 rtol = default_rtol if rtol is None else rtol
4047 atol = default_atol if atol is None else atol
@@ -45,10 +52,12 @@ def assert_close(a: torch.Tensor, b: torch.Tensor, *, rtol: float | None = None,
4552 f"max abs err { delta .max ():.3e} , max rel err { (delta / a .abs ().clamp_min (1 )).max ():.3e} "
4653 )
4754
55+
4856# -----------------------------------------------------------------------------
4957# Kernels test
5058# -----------------------------------------------------------------------------
5159
60+
5261def k_sum (x : torch .Tensor , y : torch .Tensor , * , backend : str ) -> torch .Tensor :
5362 """Sum kernel → output shape (N,).
5463
@@ -64,15 +73,14 @@ def k_sum(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Tensor:
6473
6574
def k_exp_sqnorm(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Tensor:
    """RBF-like kernel → output shape (N,)."""
    if backend == "keops":
        x = LazyTensor(x)
        y = LazyTensor(y)

    # Squared Euclidean distance between every (x_i, y_j) pair: shape (M, N).
    sq_dist = ((x - y) ** 2).sum(dim=-1)
    # Gaussian weights, reduced over the M axis.
    return (-sq_dist).exp().sum(dim=0).squeeze()
7582
83+
7684ALL_FUNS : list [Callable [[torch .Tensor , torch .Tensor , str ], torch .Tensor ]] = [
7785 k_sum ,
7886 k_exp_sqnorm ,
@@ -82,11 +90,17 @@ def k_exp_sqnorm(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Ten
8290# Reference vs Keops helper
8391# -----------------------------------------------------------------------------
8492
def reference_and_keops(
    fun: Callable[[torch.Tensor, torch.Tensor, str], torch.Tensor],
    x: torch.Tensor,
    y: torch.Tensor,
):
    """Evaluate *fun* on bf16 copies of (x, y) with both backends.

    Returns a ``(reference, keops)`` pair: the plain-torch result first,
    the KeOps result second.  The bf16 conversion is hoisted so both
    backends see the same tensors (``to_bf16`` is a pure ``.to(dtype)``).
    """
    x_bf16 = to_bf16(x)
    y_bf16 = to_bf16(y)
    return (
        fun(x_bf16, y_bf16, backend="torch"),
        fun(x_bf16, y_bf16, backend="keops"),
    )
89102
103+
90104# -----------------------------------------------------------------------------
91105# Parameter grids – moderate sizes
92106# -----------------------------------------------------------------------------
@@ -99,6 +113,7 @@ def reference_and_keops(fun: Callable[[torch.Tensor, torch.Tensor, str], torch.T
99113# Scaling factors (keep within representable range)
100114SCALES = {"unit" : 1.0 , "tiny" : 1e-2 , "huge" : 1e2 }
101115
116+
102117# -----------------------------------------------------------------------------
103118# Forward tests
104119# -----------------------------------------------------------------------------
@@ -113,7 +128,7 @@ def test_forward(fun, M: int, N: int, D: int, scale_key: str):
113128 # ------------------------------------------------------------------
114129 # Adaptive tolerances: bf16 has ε ≈ 2**-7 ≃ 7.8e-3.
115130 # ------------------------------------------------------------------
116- bf16_eps = 2 ** - 7 # ≈7.8e-3
131+ bf16_eps = 2 ** - 7 # ≈7.8e-3
117132 terms_sqrt = (M * D ) ** 0.5
118133
119134 rtol = max (default_rtol , 2.0 * bf16_eps * terms_sqrt )
@@ -122,11 +137,13 @@ def test_forward(fun, M: int, N: int, D: int, scale_key: str):
122137 ref , ko = reference_and_keops (fun , x , y )
123138 assert_close (ref , ko , rtol = rtol , atol = atol )
124139
140+
125141# -----------------------------------------------------------------------------
126142# Backward tests on a subset of shapes
127143# -----------------------------------------------------------------------------
128144BACKWARD_SHAPES = [(5 , 5 , 4 ), (25 , 7 , 1 )]
129145
146+
130147@pytest .mark .skipif (not torch .cuda .is_available (), reason = "Requires GPU" )
131148@pytest .mark .parametrize ("fun" , ALL_FUNS )
132149@pytest .mark .parametrize ("M,N,D" , BACKWARD_SHAPES )
@@ -135,10 +152,11 @@ def test_backward(fun, M: int, N: int, D: int):
135152 y = rand_tensor ((1 , N , D ))
136153 ref , ko = reference_and_keops (fun , x , y )
137154 grad = torch .randn_like (ref )
138- g_ref , = torch .autograd .grad (ref , x , grad_outputs = grad )
139- g_ko , = torch .autograd .grad (ko , x , grad_outputs = grad )
155+ ( g_ref ,) = torch .autograd .grad (ref , x , grad_outputs = grad )
156+ ( g_ko ,) = torch .autograd .grad (ko , x , grad_outputs = grad )
140157 assert_close (g_ref , g_ko , rtol = 5e-2 , atol = 5e-2 )
141158
159+
142160# -----------------------------------------------------------------------------
143161# Gradcheck – double precision, tiny shape
144162# -----------------------------------------------------------------------------
@@ -147,26 +165,29 @@ def test_backward(fun, M: int, N: int, D: int):
def test_gradcheck(fun):
    """Finite-difference gradcheck of the KeOps backend in float64.

    Uses a tiny (3, 4, 2) problem so the check stays fast.
    """
    x = torch.randn(3, 1, 2, dtype=torch.float64, device=device, requires_grad=True)
    y = torch.randn(1, 4, 2, dtype=torch.float64, device=device)

    def keops_fun(u):
        return fun(u, y, backend="keops")

    torch.autograd.gradcheck(keops_fun, (x,), eps=1e-6, atol=1e-3, rtol=1e-3)
171+
151172
152173# -----------------------------------------------------------------------------
153- # Small-shape tests
174+ # Small-shape tests
154175# -----------------------------------------------------------------------------
155176
# Tiny (M, N, D) shapes plus human-readable pytest ids for each one.
SMALL_SHAPES = [(1, 1, 1), (2, 3, 2), (3, 4, 3), (10, 10, 4)]
SMALL_IDS = ["M{}_N{}_D{}".format(m, n, d) for (m, n, d) in SMALL_SHAPES]
158179
180+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
@pytest.mark.parametrize("M,N,D", SMALL_SHAPES, ids=SMALL_IDS)
def test_small_sum(M: int, N: int, D: int):
    """Quick sanity check on small random shapes."""
    lhs = rand_tensor((M, 1, D)).requires_grad_(True)
    rhs = rand_tensor((1, N, D))
    reference, keops_result = reference_and_keops(k_sum, lhs, rhs)
    assert_close(reference, keops_result)
169189
190+
170191# -----------------------------------------------------------------------------
171192# TF32 interaction test
172193# -----------------------------------------------------------------------------
0 commit comments