Skip to content

Commit f9e0a5a

Browse files
committed
fix(tf): fix modifier_type in DeepEval
A downgrade in #3213. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent eb474d4 commit f9e0a5a

2 files changed

Lines changed: 66 additions & 1 deletion

File tree

deepmd/tf/infer/deep_eval.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def __init__(
139139
self.has_fparam = self.tensors["fparam"] is not None
140140
self.has_aparam = self.tensors["aparam"] is not None
141141
self.has_spin = self.ntypes_spin > 0
142-
self.modifier_type = None
143142

144143
# looks ugly...
145144
if self.modifier_type == "dipole_charge":
@@ -201,6 +200,8 @@ def _init_tensors(self):
201200
"ntypes_spin": "spin_attr/ntypes_spin:0",
202201
# descriptor
203202
"descriptor": "o_descriptor:0",
203+
# modifier
204+
"modifier_type": "modifier_attr/type:0",
204205
}
205206
# output tensors
206207
output_tensor_names = {}
@@ -260,6 +261,10 @@ def _init_attr(self):
260261
else:
261262
self.numb_dos = 0
262263
self.tmap = tmap.decode("utf-8").split()
264+
if self.tensors["modifier_type"] is not None:
265+
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[0]
266+
else:
267+
self.modifier_type = None
263268

264269
@property
265270
@lru_cache(maxsize=None)

source/tests/tf/test_dplr.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
from pathlib import (
4+
Path,
5+
)
6+
7+
import numpy as np
8+
9+
from deepmd.infer.deep_pot import (
10+
DeepPot,
11+
)
12+
from deepmd.tf.utils.convert import (
13+
convert_pbtxt_to_pb,
14+
)
15+
16+
17+
class TestDPLR(unittest.TestCase):
18+
def setUp(self):
19+
# a bit strange path, need to move to the correct directory
20+
pbtxt_file = (
21+
Path(__file__).parent.parent.parent / "lmp" / "tests" / "lrmodel.pbtxt"
22+
)
23+
convert_pbtxt_to_pb(pbtxt_file, "lrmodel.pb")
24+
25+
self.expected_e_lr_efield_variable = -40.56538550
26+
self.expected_f_lr_efield_variable = np.array(
27+
[
28+
[0.35019748, 0.27802691, -0.38443156],
29+
[-0.42115581, -0.20474826, -0.02701100],
30+
[-0.56357653, 0.34154004, 0.78389512],
31+
[0.21023870, -0.60684635, -0.39875165],
32+
[0.78732106, 0.00610023, 0.17197636],
33+
[-0.36302488, 0.18592742, -0.14567727],
34+
]
35+
)
36+
37+
self.box = np.eye(3).reshape(1, 9) * 20.0
38+
self.coord = np.array(
39+
[
40+
[1.25545000, 1.27562200, 0.98873000],
41+
[0.96101000, 3.25750000, 1.33494000],
42+
[0.66417000, 1.31153700, 1.74354000],
43+
[1.29187000, 0.33436000, 0.73085000],
44+
[1.88885000, 3.51130000, 1.42444000],
45+
[0.51617000, 4.04330000, 0.90904000],
46+
[1.25545000, 1.27562200, 0.98873000],
47+
[0.96101000, 3.25750000, 1.33494000],
48+
]
49+
).reshape(1, 8, 3)
50+
self.atype = np.array([0, 0, 1, 1, 1, 1, 2, 2])
51+
52+
def test_eval(self):
53+
dp = DeepPot("lrmodel.pb")
54+
e, f, v, ae, av = dp.eval(
55+
self.coord[:, :6], self.box, self.atype[:6], atomic=True
56+
)
57+
np.testing.assert_allclose(e, self.expected_e_lr_efield_variable, atol=1e-6)
58+
np.testing.assert_allclose(
59+
f.ravel(), self.expected_f_lr_efield_variable.ravel(), atol=1e-6
60+
)

0 commit comments

Comments
 (0)