11from __future__ import annotations
22
33import os
4+ from pathlib import Path
45
56import numpy as np
67import pytest
1213 pytest .skip ("Skipping PYG tests" , allow_module_level = True )
1314
1415from matgl .models ._tensornet_pyg import TensorNet , _warp_available
16+ from matgl .utils .io import _get_file_paths
1517
1618if not _warp_available :
1719 pytest .skip ("Skipping warp tests: nvalchemiops not installed" , allow_module_level = True )
@@ -23,11 +25,11 @@ def test_model(graph_MoS_pyg):
2325
2426 # Optional regression-check values
2527 EXPECTED = {
26- "swish" : torch .tensor (0.0813 ),
27- "tanh" : torch .tensor (- 0.0189 ),
28- "sigmoid" : torch .tensor (0.0353 ),
29- "softplus2" : torch .tensor (0.1164 ),
30- "softexp" : torch .tensor (0.1148 ),
28+ "swish" : torch .tensor (0.0827 ),
29+ "tanh" : torch .tensor (- 0.0258 ),
30+ "sigmoid" : torch .tensor (0.0360 ),
31+ "softplus2" : torch .tensor (0.1165 ),
32+ "softexp" : torch .tensor (0.1100 ),
3133 }
3234
3335 _ , graph , _ = graph_MoS_pyg
@@ -78,7 +80,7 @@ def test_model_intensive(graph_MoS_pyg):
7880 graph .pos = graph .frac_coords @ lat [0 ]
7981 model = TensorNet (element_types = ["Mo" , "S" ], is_intensive = True )
8082 output = model (g = graph )
81- assert torch .allclose (output , torch .tensor ([- 0.0897 ]), atol = 1e-4 )
83+ assert torch .allclose (output , torch .tensor ([- 0.0906 ]), atol = 1e-4 )
8284
8385
8486def test_model_intensive_with_weighted_atom (graph_MoS_pyg ):
@@ -88,7 +90,7 @@ def test_model_intensive_with_weighted_atom(graph_MoS_pyg):
8890 graph .pos = graph .frac_coords @ lat [0 ]
8991 model = TensorNet (element_types = ["Mo" , "S" ], is_intensive = True , readout_type = "weighted_atom" )
9092 output = model (g = graph )
91- assert torch .allclose (output , torch .tensor ([- 0.0217 ]), atol = 1e-4 )
93+ assert torch .allclose (output , torch .tensor ([- 0.0210 ]), atol = 1e-4 )
9294
9395
9496def test_model_intensive_with_ReduceReadOut (graph_MoS_pyg ):
@@ -98,7 +100,7 @@ def test_model_intensive_with_ReduceReadOut(graph_MoS_pyg):
98100 graph .pos = graph .frac_coords @ lat [0 ]
99101 model = TensorNet (is_intensive = True , readout_type = "reduce_atom" )
100102 output = model (g = graph )
101- assert torch .allclose (output , torch .tensor ([- 0.1045 ]), atol = 1e-4 )
103+ assert torch .allclose (output , torch .tensor ([- 0.1075 ]), atol = 1e-4 )
102104
103105
104106def test_model_intensive_with_classification (graph_MoS_pyg ):
@@ -122,9 +124,9 @@ def test_backward(graph_MoS_pyg):
122124
123125 EXPECTED_CELL_GRAD = torch .tensor (
124126 [
125- [- 0.000967 , 0.000000 , 0.000000 ],
126- [0.000000 , - 0.000967 , 0.000000 ],
127- [0.000000 , 0.000000 , - 0.000967 ],
127+ [- 0.000909 , 0.000000 , 0.000000 ],
128+ [0.000000 , - 0.000909 , 0.000000 ],
129+ [0.000000 , 0.000000 , - 0.000909 ],
128130 ]
129131 )
130132
@@ -150,9 +152,9 @@ def test_double_backward(graph_MoS_pyg):
150152
151153 EXPECTED_CELL_GRAD2 = torch .tensor (
152154 [
153- [- 0.000010 , - 0.000000 , - 0.000000 ],
154- [- 0.000000 , - 0.000010 , - 0.000000 ],
155- [- 0.000000 , - 0.000000 , - 0.000010 ],
155+ [- 0.0000037 , - 0.000000 , - 0.000000 ],
156+ [- 0.000000 , - 0.0000037 , - 0.000000 ],
157+ [- 0.000000 , - 0.000000 , - 0.0000037 ],
156158 ]
157159 )
158160
@@ -172,3 +174,42 @@ def test_double_backward(graph_MoS_pyg):
172174 loss .backward ()
173175
174176 assert torch .allclose (cell .grad , EXPECTED_CELL_GRAD2 , atol = 1e-6 )
177+
178+
def _build_pair_from_pretrained(repo_id: str) -> tuple[TensorNet, TensorNet]:
    """Load one pretrained checkpoint into two TensorNets: warp-enabled and plain PYG.

    Both models share identical init args and identical weights, so any output
    difference between them is attributable to the warp kernels alone.
    """
    paths = _get_file_paths(Path(repo_id))
    # Keep tensors wherever they were saved when CUDA exists; otherwise force CPU.
    device = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(paths["state.pt"], map_location=device, weights_only=False)
    blob = torch.load(paths["model.pt"], map_location=device, weights_only=False)
    init_kwargs = dict(blob["model"]["init_args"])

    # The lightning-style checkpoint prefixes every key with "model."; strip it
    # so the state dict matches a bare TensorNet.
    prefix = "model."
    inner_state = {key[len(prefix):]: value for key, value in checkpoint.items() if key.startswith(prefix)}

    models = []
    for warp_flag in (True, False):
        net = TensorNet(**{**init_kwargs, "use_warp": warp_flag})
        # NOTE(review): strict=False tolerates key drift between checkpoint and
        # model; confirm the load is not silently skipping weights.
        net.load_state_dict(inner_state, strict=False)
        net.eval()
        models.append(net)
    return models[0], models[1]
196+
197+
def test_warp_pyg_parity_pretrained(MoS):
    """Warp and non-warp TensorNet must produce identical outputs from the same pretrained weights."""
    model_warp, model_pyg = _build_pair_from_pretrained("materialyze/TensorNet-PES-MatPES-PBE-2025.2")

    # Imported here because the PYG extension is only available once the
    # module-level skips above have passed.
    from matgl.ext._pymatgen_pyg import Structure2Graph

    converter = Structure2Graph(element_types=model_pyg.element_types, cutoff=model_pyg.cutoff)
    g, lat, _ = converter.get_graph(MoS)
    # Attach Cartesian positions and PBC shift vectors derived from the lattice.
    g.pbc_offshift = g.pbc_offset @ lat[0]
    g.pos = g.frac_coords @ lat[0]

    with torch.no_grad():
        out_warp = model_warp(g=g)
        out_pyg = model_pyg(g=g)

    assert torch.allclose(out_warp, out_pyg, atol=1e-5, rtol=1e-5), (
        f"warp={out_warp.detach().cpu()} vs pyg={out_pyg.detach().cpu()}"
    )
0 commit comments