Skip to content

Commit fb4c1b4

Browse files
author
Han Wang
committed
fix(dpmodel): fix natoms[0] bug, einsum, and return type in EnergyLoss
- Fix natoms[0] -> natoms in generalized force branch (natoms is int) - Replace xp.einsum with array-API-compatible xp.sum + broadcasting - Fix return type annotation of Loss.call and EnergyLoss.call from dict[str, Array] to tuple[Array, dict[str, Array]] - Add TestEnerGF consistency test for generalized force code path - Add dpmodel-level unit tests for EnergyLoss (basic, aecoeff, generalized force, huber, serialize round-trip)
1 parent 24e54bf commit fb4c1b4

4 files changed

Lines changed: 403 additions & 9 deletions

File tree

deepmd/dpmodel/loss/ener.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def call(
9393
natoms: int,
9494
model_dict: dict[str, Array],
9595
label_dict: dict[str, Array],
96-
) -> dict[str, Array]:
96+
) -> tuple[Array, dict[str, Array]]:
9797
"""Calculate loss from model results and labeled results."""
9898
energy = model_dict["energy"]
9999
force = model_dict["force"]
@@ -244,15 +244,16 @@ def call(
244244
if self.has_gf:
245245
find_drdq = label_dict["find_drdq"]
246246
drdq = label_dict["drdq"]
247-
force_reshape_nframes = xp.reshape(force, (-1, natoms[0] * 3))
248-
force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms[0] * 3))
247+
force_reshape_nframes = xp.reshape(force, (-1, natoms * 3))
248+
force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms * 3))
249249
drdq_reshape = xp.reshape(
250-
drdq, (-1, natoms[0] * 3, self.numb_generalized_coord)
250+
drdq, (-1, natoms * 3, self.numb_generalized_coord)
251251
)
252-
gen_force_hat = xp.einsum(
253-
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
252+
# "bij,bi->bj" einsum replaced with array-API-compatible ops
253+
gen_force_hat = xp.sum(
254+
drdq_reshape * force_hat_reshape_nframes[:, :, None], axis=1
254255
)
255-
gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
256+
gen_force = xp.sum(drdq_reshape * force_reshape_nframes[:, :, None], axis=1)
256257
diff_gen_force = gen_force_hat - gen_force
257258
l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))
258259
pref_gf = find_drdq * (

deepmd/dpmodel/loss/loss.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,16 @@ def call(
2828
natoms: int,
2929
model_dict: dict[str, Array],
3030
label_dict: dict[str, Array],
31-
) -> dict[str, Array]:
32-
"""Calculate loss from model results and labeled results."""
31+
) -> tuple[Array, dict[str, Array]]:
32+
"""Calculate loss from model results and labeled results.
33+
34+
Returns
35+
-------
36+
loss : Array
37+
The scalar loss to minimize.
38+
more_loss : dict[str, Array]
39+
Additional loss terms/metrics for logging.
40+
"""
3341

3442
@property
3543
@abstractmethod
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
4+
import numpy as np
5+
6+
from deepmd.dpmodel.loss.ener import (
7+
EnergyLoss,
8+
)
9+
10+
from ...seed import (
11+
GLOBAL_SEED,
12+
)
13+
14+
15+
class TestEnergyLossBase(unittest.TestCase):
16+
"""Base class providing common setup for dpmodel EnergyLoss tests."""
17+
18+
def _make_data(self, natoms=5, nframes=2, numb_generalized_coord=0):
19+
"""Generate fake model predictions and labels."""
20+
rng = np.random.default_rng(GLOBAL_SEED)
21+
model_dict = {
22+
"energy": rng.random((nframes, 1)),
23+
"force": rng.random((nframes, natoms, 3)),
24+
"virial": rng.random((nframes, 9)),
25+
"atom_energy": rng.random((nframes, natoms, 1)),
26+
}
27+
label_dict = {
28+
"energy": rng.random((nframes, 1)),
29+
"force": rng.random((nframes, natoms, 3)),
30+
"virial": rng.random((nframes, 9)),
31+
"atom_ener": rng.random((nframes, natoms, 1)),
32+
"atom_pref": rng.random((nframes, natoms * 3)),
33+
"find_energy": 1.0,
34+
"find_force": 1.0,
35+
"find_virial": 1.0,
36+
"find_atom_ener": 1.0,
37+
"find_atom_pref": 1.0,
38+
}
39+
if numb_generalized_coord > 0:
40+
label_dict["drdq"] = rng.random(
41+
(nframes, natoms * 3 * numb_generalized_coord)
42+
)
43+
label_dict["find_drdq"] = 1.0
44+
if hasattr(self, "enable_atom_ener_coeff") and self.enable_atom_ener_coeff:
45+
label_dict["atom_ener_coeff"] = rng.random((nframes, natoms, 1))
46+
return model_dict, label_dict, natoms
47+
48+
49+
class TestEnergyLossBasic(TestEnergyLossBase):
50+
"""Test basic energy loss (e, f, v, ae)."""
51+
52+
def test_forward(self) -> None:
53+
loss_fn = EnergyLoss(
54+
starter_learning_rate=1.0,
55+
start_pref_e=1.0,
56+
limit_pref_e=0.5,
57+
start_pref_f=1.0,
58+
limit_pref_f=0.5,
59+
start_pref_v=1.0,
60+
limit_pref_v=0.5,
61+
start_pref_ae=1.0,
62+
limit_pref_ae=0.5,
63+
)
64+
model_dict, label_dict, natoms = self._make_data()
65+
loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict)
66+
self.assertIsNotNone(loss)
67+
self.assertIn("rmse_e", more_loss)
68+
self.assertIn("rmse_f", more_loss)
69+
self.assertIn("rmse_v", more_loss)
70+
self.assertIn("rmse_ae", more_loss)
71+
72+
73+
class TestEnergyLossAecoeff(TestEnergyLossBase):
74+
"""Test energy loss with atom_ener_coeff."""
75+
76+
enable_atom_ener_coeff = True
77+
78+
def test_forward(self) -> None:
79+
loss_fn = EnergyLoss(
80+
starter_learning_rate=1.0,
81+
start_pref_e=1.0,
82+
limit_pref_e=0.5,
83+
start_pref_f=1.0,
84+
limit_pref_f=0.5,
85+
start_pref_v=1.0,
86+
limit_pref_v=0.5,
87+
enable_atom_ener_coeff=True,
88+
)
89+
model_dict, label_dict, natoms = self._make_data()
90+
loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict)
91+
self.assertIsNotNone(loss)
92+
93+
94+
class TestEnergyLossGeneralizedForce(TestEnergyLossBase):
95+
"""Test energy loss with generalized force (numb_generalized_coord > 0).
96+
97+
This exercises the code path with natoms used as int scalar
98+
(not array), which previously had a natoms[0] bug.
99+
"""
100+
101+
def test_forward(self) -> None:
102+
numb_generalized_coord = 2
103+
loss_fn = EnergyLoss(
104+
starter_learning_rate=1.0,
105+
start_pref_e=1.0,
106+
limit_pref_e=0.5,
107+
start_pref_f=1.0,
108+
limit_pref_f=0.5,
109+
start_pref_v=1.0,
110+
limit_pref_v=0.5,
111+
start_pref_ae=1.0,
112+
limit_pref_ae=0.5,
113+
start_pref_pf=1.0,
114+
limit_pref_pf=0.5,
115+
start_pref_gf=1.0,
116+
limit_pref_gf=0.5,
117+
numb_generalized_coord=numb_generalized_coord,
118+
)
119+
model_dict, label_dict, natoms = self._make_data(
120+
numb_generalized_coord=numb_generalized_coord,
121+
)
122+
loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict)
123+
self.assertIsNotNone(loss)
124+
self.assertIn("rmse_gf", more_loss)
125+
self.assertIn("rmse_pf", more_loss)
126+
127+
128+
class TestEnergyLossHuber(TestEnergyLossBase):
129+
"""Test energy loss with Huber loss."""
130+
131+
def test_forward(self) -> None:
132+
loss_fn = EnergyLoss(
133+
starter_learning_rate=1.0,
134+
start_pref_e=1.0,
135+
limit_pref_e=0.5,
136+
start_pref_f=1.0,
137+
limit_pref_f=0.5,
138+
start_pref_v=1.0,
139+
limit_pref_v=0.5,
140+
use_huber=True,
141+
huber_delta=0.01,
142+
)
143+
model_dict, label_dict, natoms = self._make_data()
144+
loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict)
145+
self.assertIsNotNone(loss)
146+
147+
148+
class TestEnergyLossSerialize(TestEnergyLossBase):
149+
"""Test serialize/deserialize round-trip."""
150+
151+
def test_serialize_deserialize(self) -> None:
152+
loss_fn = EnergyLoss(
153+
starter_learning_rate=1.0,
154+
start_pref_e=1.0,
155+
limit_pref_e=0.5,
156+
start_pref_f=1.0,
157+
limit_pref_f=0.5,
158+
start_pref_v=1.0,
159+
limit_pref_v=0.5,
160+
start_pref_gf=1.0,
161+
limit_pref_gf=0.5,
162+
numb_generalized_coord=2,
163+
)
164+
data = loss_fn.serialize()
165+
loss_fn2 = EnergyLoss.deserialize(data)
166+
model_dict, label_dict, natoms = self._make_data(numb_generalized_coord=2)
167+
loss1, more1 = loss_fn.call(1.0, natoms, model_dict, label_dict)
168+
loss2, more2 = loss_fn2.call(1.0, natoms, model_dict, label_dict)
169+
np.testing.assert_allclose(loss1, loss2)
170+
for key in more1:
171+
np.testing.assert_allclose(more1[key], more2[key])
172+
173+
174+
if __name__ == "__main__":
175+
unittest.main()

0 commit comments

Comments
 (0)