Skip to content

Commit e4e82bc

Browse files
committed
make linter happy
1 parent 53d6c11 commit e4e82bc

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

pykeops/pykeops/test/test_bf16_extensive.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# Helpers
2525
# -----------------------------------------------------------------------------
2626

27+
2728
def rand_tensor(shape: Tuple[int, ...], *, scale: float = 1.0) -> torch.Tensor:
2829
"""Random bf16 tensor of the given *shape* and *scale*."""
2930
return (torch.randn(shape, dtype=torch.float32, device=device) * scale).to(dtype)
@@ -34,7 +35,13 @@ def to_bf16(t: torch.Tensor) -> torch.Tensor:
3435
return t.to(dtype)
3536

3637

37-
def assert_close(a: torch.Tensor, b: torch.Tensor, *, rtol: float | None = None, atol: float | None = None) -> None:
38+
def assert_close(
39+
a: torch.Tensor,
40+
b: torch.Tensor,
41+
*,
42+
rtol: float | None = None,
43+
atol: float | None = None,
44+
) -> None:
3845
"""Wrapper around *torch.allclose* with nicer error messages."""
3946
rtol = default_rtol if rtol is None else rtol
4047
atol = default_atol if atol is None else atol
@@ -45,10 +52,12 @@ def assert_close(a: torch.Tensor, b: torch.Tensor, *, rtol: float | None = None,
4552
f"max abs err {delta.max():.3e}, max rel err {(delta / a.abs().clamp_min(1)).max():.3e}"
4653
)
4754

55+
4856
# -----------------------------------------------------------------------------
4957
# Kernels test
5058
# -----------------------------------------------------------------------------
5159

60+
5261
def k_sum(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Tensor:
5362
"""Sum kernel → output shape (N,).
5463
@@ -64,15 +73,14 @@ def k_sum(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Tensor:
6473

6574

6675
def k_exp_sqnorm(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Tensor:
67-
"""RBF-like kernel → output shape (N,).
68-
69-
"""
76+
"""RBF-like kernel → output shape (N,)."""
7077
if backend == "keops":
7178
x, y = LazyTensor(x), LazyTensor(y)
7279

7380
d2 = ((x - y) ** 2).sum(dim=-1) # shape (M, N)
7481
return (-d2).exp().sum(dim=0).squeeze()
7582

83+
7684
ALL_FUNS: list[Callable[[torch.Tensor, torch.Tensor, str], torch.Tensor]] = [
7785
k_sum,
7886
k_exp_sqnorm,
@@ -82,11 +90,17 @@ def k_exp_sqnorm(x: torch.Tensor, y: torch.Tensor, *, backend: str) -> torch.Ten
8290
# Reference vs Keops helper
8391
# -----------------------------------------------------------------------------
8492

85-
def reference_and_keops(fun: Callable[[torch.Tensor, torch.Tensor, str], torch.Tensor], x: torch.Tensor, y: torch.Tensor):
93+
94+
def reference_and_keops(
95+
fun: Callable[[torch.Tensor, torch.Tensor, str], torch.Tensor],
96+
x: torch.Tensor,
97+
y: torch.Tensor,
98+
):
8699
ref = fun(to_bf16(x), to_bf16(y), backend="torch")
87100
ko = fun(to_bf16(x), to_bf16(y), backend="keops")
88101
return ref, ko
89102

103+
90104
# -----------------------------------------------------------------------------
91105
# Parameter grids – moderate sizes
92106
# -----------------------------------------------------------------------------
@@ -99,6 +113,7 @@ def reference_and_keops(fun: Callable[[torch.Tensor, torch.Tensor, str], torch.T
99113
# Scaling factors (keep within representable range)
100114
SCALES = {"unit": 1.0, "tiny": 1e-2, "huge": 1e2}
101115

116+
102117
# -----------------------------------------------------------------------------
103118
# Forward tests
104119
# -----------------------------------------------------------------------------
@@ -113,7 +128,7 @@ def test_forward(fun, M: int, N: int, D: int, scale_key: str):
113128
# ------------------------------------------------------------------
114129
# Adaptive tolerances: bf16 has ε ≈ 2**-7 ≃ 7.8e-3.
115130
# ------------------------------------------------------------------
116-
bf16_eps = 2 ** -7 # ≈7.8e-3
131+
bf16_eps = 2**-7 # ≈7.8e-3
117132
terms_sqrt = (M * D) ** 0.5
118133

119134
rtol = max(default_rtol, 2.0 * bf16_eps * terms_sqrt)
@@ -122,11 +137,13 @@ def test_forward(fun, M: int, N: int, D: int, scale_key: str):
122137
ref, ko = reference_and_keops(fun, x, y)
123138
assert_close(ref, ko, rtol=rtol, atol=atol)
124139

140+
125141
# -----------------------------------------------------------------------------
126142
# Backward tests on a subset of shapes
127143
# -----------------------------------------------------------------------------
128144
BACKWARD_SHAPES = [(5, 5, 4), (25, 7, 1)]
129145

146+
130147
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
131148
@pytest.mark.parametrize("fun", ALL_FUNS)
132149
@pytest.mark.parametrize("M,N,D", BACKWARD_SHAPES)
@@ -135,10 +152,11 @@ def test_backward(fun, M: int, N: int, D: int):
135152
y = rand_tensor((1, N, D))
136153
ref, ko = reference_and_keops(fun, x, y)
137154
grad = torch.randn_like(ref)
138-
g_ref, = torch.autograd.grad(ref, x, grad_outputs=grad)
139-
g_ko, = torch.autograd.grad(ko, x, grad_outputs=grad)
155+
(g_ref,) = torch.autograd.grad(ref, x, grad_outputs=grad)
156+
(g_ko,) = torch.autograd.grad(ko, x, grad_outputs=grad)
140157
assert_close(g_ref, g_ko, rtol=5e-2, atol=5e-2)
141158

159+
142160
# -----------------------------------------------------------------------------
143161
# Gradcheck – double precision, tiny shape
144162
# -----------------------------------------------------------------------------
@@ -147,26 +165,29 @@ def test_backward(fun, M: int, N: int, D: int):
147165
def test_gradcheck(fun):
148166
x = torch.randn(3, 1, 2, dtype=torch.float64, device=device, requires_grad=True)
149167
y = torch.randn(1, 4, 2, dtype=torch.float64, device=device)
150-
torch.autograd.gradcheck(lambda u: fun(u, y, backend="keops"), (x,), eps=1e-6, atol=1e-3, rtol=1e-3)
168+
torch.autograd.gradcheck(
169+
lambda u: fun(u, y, backend="keops"), (x,), eps=1e-6, atol=1e-3, rtol=1e-3
170+
)
171+
151172

152173
# -----------------------------------------------------------------------------
153-
# Small-shape tests
174+
# Small-shape tests
154175
# -----------------------------------------------------------------------------
155176

156177
SMALL_SHAPES = [(1, 1, 1), (2, 3, 2), (3, 4, 3), (10, 10, 4)]
157178
SMALL_IDS = [f"M{m}_N{n}_D{d}" for (m, n, d) in SMALL_SHAPES]
158179

180+
159181
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
160182
@pytest.mark.parametrize("M,N,D", SMALL_SHAPES, ids=SMALL_IDS)
161183
def test_small_sum(M: int, N: int, D: int):
162-
"""Quick sanity check on small random shapes.
163-
164-
"""
184+
"""Quick sanity check on small random shapes."""
165185
x = rand_tensor((M, 1, D)).requires_grad_(True)
166186
y = rand_tensor((1, N, D))
167187
ref, ko = reference_and_keops(k_sum, x, y)
168188
assert_close(ref, ko)
169189

190+
170191
# -----------------------------------------------------------------------------
171192
# TF32 interaction test
172193
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)