Skip to content

Commit 6a4bf68

Browse files
shyuep and claude authored
Port message-passing convention fix to warp TensorNet (#766)
The warp embedding (_embedding_warp.py) was aggregating radial messages onto destination nodes via col_data/col_indptr — the old convention that was corrected on main in PRs #758/#759 for the PyG/DGL implementations. Switch to row_data/row_indptr so warp matches the corrected aggregation onto source nodes, and add a parity test that loads the pretrained TensorNet-PES-MatPES-PBE-2025.2 weights into both warp and non-warp TensorNet instances and asserts identical outputs. Re-calibrate the existing warp regression-check expected values affected by the fix. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 435e448 commit 6a4bf68

3 files changed

Lines changed: 61 additions & 20 deletions

File tree

src/matgl/layers/_embedding_warp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def forward(
115115
edge_weight: torch.Tensor,
116116
edge_vec: torch.Tensor,
117117
edge_attr: torch.Tensor,
118-
col_data: torch.Tensor,
119-
col_indptr: torch.Tensor,
118+
row_data: torch.Tensor,
119+
row_indptr: torch.Tensor,
120120
) -> torch.Tensor:
121121
"""Forward pass.
122122
@@ -126,8 +126,8 @@ def forward(
126126
edge_weight: Edge weights (distances), shape (num_edges,)
127127
edge_vec: Edge vectors, shape (num_edges, 3)
128128
edge_attr: Edge attributes (RBF), shape (num_edges, num_rbf)
129-
col_data: CSC col data for destination aggregation, shape (num_edges,)
130-
col_indptr: CSC col indptr for destination aggregation, shape (num_nodes+1,)
129+
row_data: CSR row data for source aggregation, shape (num_edges,)
130+
row_indptr: CSR row indptr for source aggregation, shape (num_nodes+1,)
131131
132132
Returns:
133133
X: Tensor representation, shape (num_nodes, 3, 3, units)
@@ -143,7 +143,7 @@ def forward(
143143
edge_attr_processed = edge_attr.view(-1, 3, self.units) * C.view(-1, 1, 1) * Zij.view(-1, 1, self.units)
144144

145145
edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6)
146-
I, A, S = fn_radial_message_passing(edge_vec_norm, edge_attr_processed, col_data, col_indptr) # noqa: E741
146+
I, A, S = fn_radial_message_passing(edge_vec_norm, edge_attr_processed, row_data, row_indptr) # noqa: E741
147147

148148
X = fn_compose_tensor(I, A, S) # (num_nodes, 3, 3, units)
149149

src/matgl/models/_tensornet_pyg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def forward_features(
283283
col_indices,
284284
col_indptr,
285285
) = graph_transform(edge_index.int(), z.shape[0]) # type: ignore[union-attr]
286-
X = self.tensor_embedding(z, edge_index, bond_dist, bond_vec, edge_attr, col_data, col_indptr)
286+
X = self.tensor_embedding(z, edge_index, bond_dist, bond_vec, edge_attr, row_data, row_indptr)
287287
fea_dict["embedding"] = X
288288
for i, layer in enumerate(self.layers):
289289
X = layer(

tests/models/test_tensornet_warp.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
from pathlib import Path
45

56
import numpy as np
67
import pytest
@@ -12,6 +13,7 @@
1213
pytest.skip("Skipping PYG tests", allow_module_level=True)
1314

1415
from matgl.models._tensornet_pyg import TensorNet, _warp_available
16+
from matgl.utils.io import _get_file_paths
1517

1618
if not _warp_available:
1719
pytest.skip("Skipping warp tests: nvalchemiops not installed", allow_module_level=True)
@@ -23,11 +25,11 @@ def test_model(graph_MoS_pyg):
2325

2426
# Optional regression-check values
2527
EXPECTED = {
26-
"swish": torch.tensor(0.0813),
27-
"tanh": torch.tensor(-0.0189),
28-
"sigmoid": torch.tensor(0.0353),
29-
"softplus2": torch.tensor(0.1164),
30-
"softexp": torch.tensor(0.1148),
28+
"swish": torch.tensor(0.0827),
29+
"tanh": torch.tensor(-0.0258),
30+
"sigmoid": torch.tensor(0.0360),
31+
"softplus2": torch.tensor(0.1165),
32+
"softexp": torch.tensor(0.1100),
3133
}
3234

3335
_, graph, _ = graph_MoS_pyg
@@ -78,7 +80,7 @@ def test_model_intensive(graph_MoS_pyg):
7880
graph.pos = graph.frac_coords @ lat[0]
7981
model = TensorNet(element_types=["Mo", "S"], is_intensive=True)
8082
output = model(g=graph)
81-
assert torch.allclose(output, torch.tensor([-0.0897]), atol=1e-4)
83+
assert torch.allclose(output, torch.tensor([-0.0906]), atol=1e-4)
8284

8385

8486
def test_model_intensive_with_weighted_atom(graph_MoS_pyg):
@@ -88,7 +90,7 @@ def test_model_intensive_with_weighted_atom(graph_MoS_pyg):
8890
graph.pos = graph.frac_coords @ lat[0]
8991
model = TensorNet(element_types=["Mo", "S"], is_intensive=True, readout_type="weighted_atom")
9092
output = model(g=graph)
91-
assert torch.allclose(output, torch.tensor([-0.0217]), atol=1e-4)
93+
assert torch.allclose(output, torch.tensor([-0.0210]), atol=1e-4)
9294

9395

9496
def test_model_intensive_with_ReduceReadOut(graph_MoS_pyg):
@@ -98,7 +100,7 @@ def test_model_intensive_with_ReduceReadOut(graph_MoS_pyg):
98100
graph.pos = graph.frac_coords @ lat[0]
99101
model = TensorNet(is_intensive=True, readout_type="reduce_atom")
100102
output = model(g=graph)
101-
assert torch.allclose(output, torch.tensor([-0.1045]), atol=1e-4)
103+
assert torch.allclose(output, torch.tensor([-0.1075]), atol=1e-4)
102104

103105

104106
def test_model_intensive_with_classification(graph_MoS_pyg):
@@ -122,9 +124,9 @@ def test_backward(graph_MoS_pyg):
122124

123125
EXPECTED_CELL_GRAD = torch.tensor(
124126
[
125-
[-0.000967, 0.000000, 0.000000],
126-
[0.000000, -0.000967, 0.000000],
127-
[0.000000, 0.000000, -0.000967],
127+
[-0.000909, 0.000000, 0.000000],
128+
[0.000000, -0.000909, 0.000000],
129+
[0.000000, 0.000000, -0.000909],
128130
]
129131
)
130132

@@ -150,9 +152,9 @@ def test_double_backward(graph_MoS_pyg):
150152

151153
EXPECTED_CELL_GRAD2 = torch.tensor(
152154
[
153-
[-0.000010, -0.000000, -0.000000],
154-
[-0.000000, -0.000010, -0.000000],
155-
[-0.000000, -0.000000, -0.000010],
155+
[-0.0000037, -0.000000, -0.000000],
156+
[-0.000000, -0.0000037, -0.000000],
157+
[-0.000000, -0.000000, -0.0000037],
156158
]
157159
)
158160

@@ -172,3 +174,42 @@ def test_double_backward(graph_MoS_pyg):
172174
loss.backward()
173175

174176
assert torch.allclose(cell.grad, EXPECTED_CELL_GRAD2, atol=1e-6)
177+
178+
179+
def _build_pair_from_pretrained(repo_id: str) -> tuple[TensorNet, TensorNet]:
180+
"""Build a (warp, non-warp) pair of TensorNet models loaded with identical pretrained weights."""
181+
fpaths = _get_file_paths(Path(repo_id))
182+
map_location = "cpu" if not torch.cuda.is_available() else None
183+
state = torch.load(fpaths["state.pt"], map_location=map_location, weights_only=False)
184+
init_blob = torch.load(fpaths["model.pt"], map_location=map_location, weights_only=False)
185+
inner_init_args = dict(init_blob["model"]["init_args"])
186+
187+
inner_state = {k[len("model.") :]: v for k, v in state.items() if k.startswith("model.")}
188+
189+
model_warp = TensorNet(**{**inner_init_args, "use_warp": True})
190+
model_pyg = TensorNet(**{**inner_init_args, "use_warp": False})
191+
model_warp.load_state_dict(inner_state, strict=False)
192+
model_pyg.load_state_dict(inner_state, strict=False)
193+
model_warp.eval()
194+
model_pyg.eval()
195+
return model_warp, model_pyg
196+
197+
198+
def test_warp_pyg_parity_pretrained(MoS):
199+
"""Warp and non-warp TensorNet must produce identical outputs from the same pretrained weights."""
200+
model_warp, model_pyg = _build_pair_from_pretrained("materialyze/TensorNet-PES-MatPES-PBE-2025.2")
201+
202+
from matgl.ext._pymatgen_pyg import Structure2Graph
203+
204+
converter = Structure2Graph(element_types=model_pyg.element_types, cutoff=model_pyg.cutoff)
205+
g, lat, _ = converter.get_graph(MoS)
206+
g.pbc_offshift = torch.matmul(g.pbc_offset, lat[0])
207+
g.pos = g.frac_coords @ lat[0]
208+
209+
with torch.no_grad():
210+
out_warp = model_warp(g=g)
211+
out_pyg = model_pyg(g=g)
212+
213+
assert torch.allclose(out_warp, out_pyg, atol=1e-5, rtol=1e-5), (
214+
f"warp={out_warp.detach().cpu()} vs pyg={out_pyg.detach().cpu()}"
215+
)

0 commit comments

Comments (0)