diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 9ab141bdc2..4eeeafb3f0 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -93,7 +93,7 @@ def call( natoms: int, model_dict: dict[str, Array], label_dict: dict[str, Array], - ) -> dict[str, Array]: + ) -> tuple[Array, dict[str, Array]]: """Calculate loss from model results and labeled results.""" energy = model_dict["energy"] force = model_dict["force"] @@ -244,15 +244,16 @@ def call( if self.has_gf: find_drdq = label_dict["find_drdq"] drdq = label_dict["drdq"] - force_reshape_nframes = xp.reshape(force, (-1, natoms[0] * 3)) - force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms[0] * 3)) + force_reshape_nframes = xp.reshape(force, (-1, natoms * 3)) + force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms * 3)) drdq_reshape = xp.reshape( - drdq, (-1, natoms[0] * 3, self.numb_generalized_coord) + drdq, (-1, natoms * 3, self.numb_generalized_coord) ) - gen_force_hat = xp.einsum( - "bij,bi->bj", drdq_reshape, force_hat_reshape_nframes + # "bij,bi->bj" einsum replaced with array-API-compatible ops + gen_force_hat = xp.sum( + drdq_reshape * force_hat_reshape_nframes[:, :, None], axis=1 ) - gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes) + gen_force = xp.sum(drdq_reshape * force_reshape_nframes[:, :, None], axis=1) diff_gen_force = gen_force_hat - gen_force l2_gen_force_loss = xp.mean(xp.square(diff_gen_force)) pref_gf = find_drdq * ( diff --git a/deepmd/dpmodel/loss/loss.py b/deepmd/dpmodel/loss/loss.py index 4b9831c344..05878deabc 100644 --- a/deepmd/dpmodel/loss/loss.py +++ b/deepmd/dpmodel/loss/loss.py @@ -28,8 +28,16 @@ def call( natoms: int, model_dict: dict[str, Array], label_dict: dict[str, Array], - ) -> dict[str, Array]: - """Calculate loss from model results and labeled results.""" + ) -> tuple[Array, dict[str, Array]]: + """Calculate loss from model results and labeled results. + + Returns + ------- + loss : Array + The scalar loss to minimize. + more_loss : dict[str, Array] + Additional loss terms/metrics for logging. + """ @property @abstractmethod diff --git a/source/tests/common/dpmodel/test_loss_ener.py b/source/tests/common/dpmodel/test_loss_ener.py new file mode 100644 index 0000000000..ebf9ba0a64 --- /dev/null +++ b/source/tests/common/dpmodel/test_loss_ener.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.loss.ener import ( + EnergyLoss, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestEnergyLossBase(unittest.TestCase): + """Base class providing common setup for dpmodel EnergyLoss tests.""" + + def _make_data(self, natoms=5, nframes=2, numb_generalized_coord=0): + """Generate fake model predictions and labels.""" + rng = np.random.default_rng(GLOBAL_SEED) + model_dict = { + "energy": rng.random((nframes, 1)), + "force": rng.random((nframes, natoms, 3)), + "virial": rng.random((nframes, 9)), + "atom_energy": rng.random((nframes, natoms, 1)), + } + label_dict = { + "energy": rng.random((nframes, 1)), + "force": rng.random((nframes, natoms, 3)), + "virial": rng.random((nframes, 9)), + "atom_ener": rng.random((nframes, natoms, 1)), + "atom_pref": rng.random((nframes, natoms * 3)), + "find_energy": 1.0, + "find_force": 1.0, + "find_virial": 1.0, + "find_atom_ener": 1.0, + "find_atom_pref": 1.0, + } + if numb_generalized_coord > 0: + label_dict["drdq"] = rng.random( + (nframes, natoms * 3 * numb_generalized_coord) + ) + label_dict["find_drdq"] = 1.0 + if hasattr(self, "enable_atom_ener_coeff") and self.enable_atom_ener_coeff: + label_dict["atom_ener_coeff"] = rng.random((nframes, natoms, 1)) + return model_dict, label_dict, natoms + + +class TestEnergyLossBasic(TestEnergyLossBase): + """Test basic energy loss (e, f, v, ae).""" + + def test_forward(self) -> None: + loss_fn = EnergyLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=0.5, + start_pref_f=1.0, + limit_pref_f=0.5, + start_pref_v=1.0, + limit_pref_v=0.5, + start_pref_ae=1.0, + limit_pref_ae=0.5, + ) + model_dict, label_dict, natoms = self._make_data() + loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict) + self.assertIsNotNone(loss) + self.assertIn("rmse_e", more_loss) + self.assertIn("rmse_f", more_loss) + self.assertIn("rmse_v", more_loss) + self.assertIn("rmse_ae", more_loss) + + +class TestEnergyLossAecoeff(TestEnergyLossBase): + """Test energy loss with atom_ener_coeff.""" + + enable_atom_ener_coeff = True + + def test_forward(self) -> None: + loss_fn = EnergyLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=0.5, + start_pref_f=1.0, + limit_pref_f=0.5, + start_pref_v=1.0, + limit_pref_v=0.5, + enable_atom_ener_coeff=True, + ) + model_dict, label_dict, natoms = self._make_data() + loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict) + self.assertIsNotNone(loss) + + +class TestEnergyLossGeneralizedForce(TestEnergyLossBase): + """Test energy loss with generalized force (numb_generalized_coord > 0). + + This exercises the code path with natoms used as int scalar + (not array), which previously had a natoms[0] bug. + """ + + def test_forward(self) -> None: + numb_generalized_coord = 2 + loss_fn = EnergyLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=0.5, + start_pref_f=1.0, + limit_pref_f=0.5, + start_pref_v=1.0, + limit_pref_v=0.5, + start_pref_ae=1.0, + limit_pref_ae=0.5, + start_pref_pf=1.0, + limit_pref_pf=0.5, + start_pref_gf=1.0, + limit_pref_gf=0.5, + numb_generalized_coord=numb_generalized_coord, + ) + model_dict, label_dict, natoms = self._make_data( + numb_generalized_coord=numb_generalized_coord, + ) + loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict) + self.assertIsNotNone(loss) + self.assertIn("rmse_gf", more_loss) + self.assertIn("rmse_pf", more_loss) + + +class TestEnergyLossHuber(TestEnergyLossBase): + """Test energy loss with Huber loss.""" + + def test_forward(self) -> None: + loss_fn = EnergyLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=0.5, + start_pref_f=1.0, + limit_pref_f=0.5, + start_pref_v=1.0, + limit_pref_v=0.5, + use_huber=True, + huber_delta=0.01, + ) + model_dict, label_dict, natoms = self._make_data() + loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict) + self.assertIsNotNone(loss) + + +class TestEnergyLossSerialize(TestEnergyLossBase): + """Test serialize/deserialize round-trip.""" + + def test_serialize_deserialize(self) -> None: + loss_fn = EnergyLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=0.5, + start_pref_f=1.0, + limit_pref_f=0.5, + start_pref_v=1.0, + limit_pref_v=0.5, + start_pref_gf=1.0, + limit_pref_gf=0.5, + numb_generalized_coord=2, + ) + data = loss_fn.serialize() + loss_fn2 = EnergyLoss.deserialize(data) + model_dict, label_dict, natoms = self._make_data(numb_generalized_coord=2) + loss1, more1 = loss_fn.call(1.0, natoms, model_dict, label_dict) + loss2, more2 = loss_fn2.call(1.0, natoms, model_dict, label_dict) + np.testing.assert_allclose(loss1, loss2) + for key in more1: + np.testing.assert_allclose(more1[key], more2[key]) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/consistent/loss/test_ener.py b/source/tests/consistent/loss/test_ener.py index 1cc662fc5b..008aa9892c 100644 --- a/source/tests/consistent/loss/test_ener.py +++ b/source/tests/consistent/loss/test_ener.py @@ -272,3 +272,213 @@ def rtol(self) -> float: def atol(self) -> float: """Absolute tolerance for comparing the return value.""" return 1e-10 + + +class TestEnerGF(CommonTest, LossTest, unittest.TestCase): + """Test energy loss with generalized force (numb_generalized_coord > 0). + + This exercises the code path that previously had a natoms[0] bug. + """ + + @property + def data(self) -> dict: + return { + "start_pref_e": 0.02, + "limit_pref_e": 1.0, + "start_pref_f": 1000.0, + "limit_pref_f": 1.0, + "start_pref_v": 1.0, + "limit_pref_v": 1.0, + "start_pref_ae": 1.0, + "limit_pref_ae": 1.0, + "start_pref_pf": 1.0, + "limit_pref_pf": 1.0, + "start_pref_gf": 1.0, + "limit_pref_gf": 1.0, + "numb_generalized_coord": 2, + } + + skip_tf = CommonTest.skip_tf + skip_pt = CommonTest.skip_pt + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + skip_pd = not INSTALLED_PD + + tf_class = EnerLossTF + dp_class = EnerLossDP + pt_class = EnerLossPT + jax_class = EnerLossDP + pd_class = EnerLossPD + array_api_strict_class = EnerLossDP + args = loss_ener() + + def setUp(self) -> None: + CommonTest.setUp(self) + self.learning_rate = 1e-3 + rng = np.random.default_rng(20250105) + self.nframes = 2 + self.natoms = 6 + numb_generalized_coord = 2 + self.predict = { + "energy": rng.random((self.nframes,)), + "force": rng.random((self.nframes, self.natoms, 3)), + "virial": rng.random((self.nframes, 9)), + "atom_ener": rng.random((self.nframes, self.natoms)), + } + self.predict_dpmodel_style = { + "energy": self.predict["energy"], + "force": self.predict["force"], + "virial": self.predict["virial"], + "atom_energy": self.predict["atom_ener"], + } + self.label = { + "energy": rng.random((self.nframes,)), + "force": rng.random((self.nframes, self.natoms, 3)), + "virial": rng.random((self.nframes, 9)), + "atom_ener": rng.random((self.nframes, self.natoms)), + "atom_pref": np.ones((self.nframes, self.natoms, 3)), + "drdq": rng.random( + (self.nframes, self.natoms * 3 * numb_generalized_coord) + ), + "find_energy": 1.0, + "find_force": 1.0, + "find_virial": 1.0, + "find_atom_ener": 1.0, + "find_atom_pref": 1.0, + "find_drdq": 1.0, + } + + @property + def additional_data(self) -> dict: + return { + "starter_learning_rate": 1e-3, + } + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + predict = { + kk: tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, vv.shape, name="i_predict_" + kk + ) + for kk, vv in self.predict.items() + } + label = { + kk: tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, vv.shape, name="i_label_" + kk + ) + if isinstance(vv, np.ndarray) + else vv + for kk, vv in self.label.items() + } + + loss, more_loss = obj.build( + self.learning_rate, + [self.natoms], + predict, + label, + suffix=suffix, + ) + return [loss], { + **{ + vv: self.predict[kk] + for kk, vv in predict.items() + if isinstance(vv, tf.Tensor) + }, + **{ + vv: self.label[kk] + for kk, vv in label.items() + if isinstance(vv, tf.Tensor) + }, + } + + def eval_pt(self, pt_obj: Any) -> Any: + predict = {kk: numpy_to_torch(vv) for kk, vv in self.predict.items()} + label = {kk: numpy_to_torch(vv) for kk, vv in self.label.items()} + predict["atom_energy"] = predict.pop("atom_ener") + _, loss, more_loss = pt_obj( + {}, + lambda: predict, + label, + self.natoms, + self.learning_rate, + ) + loss = torch_to_numpy(loss) + more_loss = {kk: torch_to_numpy(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def eval_dp(self, dp_obj: Any) -> Any: + return dp_obj( + self.learning_rate, + self.natoms, + self.predict_dpmodel_style, + self.label, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + predict = {kk: jnp.asarray(vv) for kk, vv in self.predict_dpmodel_style.items()} + label = {kk: jnp.asarray(vv) for kk, vv in self.label.items()} + + loss, more_loss = jax_obj( + self.learning_rate, + self.natoms, + predict, + label, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + predict = { + kk: array_api_strict.asarray(vv) + for kk, vv in self.predict_dpmodel_style.items() + } + label = {kk: array_api_strict.asarray(vv) for kk, vv in self.label.items()} + + loss, more_loss = array_api_strict_obj( + self.learning_rate, + self.natoms, + predict, + label, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def eval_pd(self, pd_obj: Any) -> Any: + predict = { + kk: paddle.to_tensor(vv).to(PD_DEVICE) for kk, vv in self.predict.items() + } + label = { + kk: paddle.to_tensor(vv).to(PD_DEVICE) for kk, vv in self.label.items() + } + predict["atom_energy"] = predict.pop("atom_ener") + _, loss, more_loss = pd_obj( + {}, + lambda: predict, + label, + self.natoms, + self.learning_rate, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def extract_ret(self, ret: Any, backend) -> dict[str, np.ndarray]: + loss = ret[0] + result = {"loss": np.atleast_1d(np.asarray(loss, dtype=np.float64))} + if len(ret) > 1: + more_loss = ret[1] + for k in sorted(more_loss): + if k.startswith("rmse_"): + result[k] = np.atleast_1d( + np.asarray(more_loss[k], dtype=np.float64) + ) + return result + + @property + def rtol(self) -> float: + return 1e-10 + + @property + def atol(self) -> float: + return 1e-10