Skip to content

Commit e88f183

Browse files
committed
chore: test consistency of rotation matrix
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 65ca05a commit e88f183

5 files changed

Lines changed: 5 additions & 5 deletions

File tree

source/tests/consistent/descriptor/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix):
7474
)
7575
# ensure get_dim_out gives the correct shape
7676
t_des = tf.reshape(t_des, [1, natoms[0], obj.get_dim_out()])
77-
return [t_des], {
77+
return [t_des, obj.get_rot_mat()], {
7878
t_coord: coords,
7979
t_type: atype,
8080
t_natoms: natoms,

source/tests/consistent/descriptor/test_dpa1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
442442
)
443443

444444
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
445-
return (ret[0],)
445+
return (ret[0], ret[1])
446446

447447
@property
448448
def rtol(self) -> float:

source/tests/consistent/descriptor/test_hybrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,4 @@ def eval_jax(self, jax_obj: Any) -> Any:
168168
)
169169

170170
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
171-
return (ret[0],)
171+
return (ret[0], ret[1])

source/tests/consistent/descriptor/test_se_atten_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
340340
)
341341

342342
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
343-
return (ret[0],)
343+
return (ret[0], ret[1])
344344

345345
@property
346346
def rtol(self) -> float:

source/tests/consistent/descriptor/test_se_e2_a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
259259
)
260260

261261
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
262-
return (ret[0],)
262+
return (ret[0], ret[1])
263263

264264
@property
265265
def rtol(self) -> float:

0 commit comments

Comments
 (0)