From 9126c36e83f8cdf564d810a31e62b00b9d0a51d0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 10 Oct 2024 18:03:03 -0400 Subject: [PATCH] feat(jax/array-api): energy fitting Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 48 +++++++------ deepmd/dpmodel/utils/exclude_mask.py | 8 ++- deepmd/jax/fitting/__init__.py | 1 + deepmd/jax/fitting/fitting.py | 39 +++++++++++ deepmd/jax/utils/exclude_mask.py | 9 +++ .../array_api_strict/fitting/__init__.py | 1 + .../tests/array_api_strict/fitting/fitting.py | 38 +++++++++++ .../array_api_strict/utils/exclude_mask.py | 8 +++ source/tests/consistent/fitting/test_ener.py | 67 +++++++++++++++++++ 9 files changed, 197 insertions(+), 22 deletions(-) create mode 100644 deepmd/jax/fitting/__init__.py create mode 100644 deepmd/jax/fitting/fitting.py create mode 100644 source/tests/array_api_strict/fitting/__init__.py create mode 100644 source/tests/array_api_strict/fitting/fitting.py diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a587f69449..fd80ccb4aa 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -9,12 +9,16 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( DEFAULT_PRECISION, NativeOP, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( AtomExcludeMask, FittingNet, @@ -283,11 +287,11 @@ def serialize(self) -> dict: "exclude_types": self.exclude_types, "nets": self.nets.serialize(), "@variables": { - "bias_atom_e": self.bias_atom_e, - "fparam_avg": self.fparam_avg, - "fparam_inv_std": self.fparam_inv_std, - "aparam_avg": self.aparam_avg, - "aparam_inv_std": self.aparam_inv_std, + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), }, "type_map": self.type_map, # not supported @@ -344,6 +348,7 @@ def _call_common( The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nf, nloc, nd = descriptor.shape net_dim_out = self._net_out_dim() # check input dim @@ -359,7 +364,7 @@ def _call_common( # we consider it as always zero for convenience. # Needs a compute_input_stats for vaccum passed from the # descriptor. - xx_zeros = np.zeros_like(xx) + xx_zeros = xp.zeros_like(xx) else: xx_zeros = None # check fparam dim, concate to input descriptor @@ -371,13 +376,15 @@ def _call_common( "which is not consistent with {self.numb_fparam}.", ) fparam = (fparam - self.fparam_avg) * self.fparam_inv_std - fparam = np.tile(fparam.reshape([nf, 1, self.numb_fparam]), [1, nloc, 1]) - xx = np.concatenate( + fparam = xp.tile( + xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1) + ) + xx = xp.concat( [xx, fparam], axis=-1, ) if xx_zeros is not None: - xx_zeros = np.concatenate( + xx_zeros = xp.concat( [xx_zeros, fparam], axis=-1, ) @@ -389,24 +396,24 @@ def _call_common( "get an input aparam of dim {aparam.shape[-1]}, ", "which is not consistent with {self.numb_aparam}.", ) - aparam = aparam.reshape([nf, nloc, self.numb_aparam]) + aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam]) aparam = (aparam - self.aparam_avg) * self.aparam_inv_std - xx = np.concatenate( + xx = xp.concat( [xx, aparam], axis=-1, ) if xx_zeros is not None: - xx_zeros = np.concatenate( + xx_zeros = xp.concat( [xx_zeros, aparam], axis=-1, ) # calcualte the prediction if not self.mixed_types: - outs = np.zeros([nf, nloc, net_dim_out]) # pylint: disable=no-explicit-dtype + outs = xp.zeros([nf, nloc, net_dim_out]) # pylint: disable=no-explicit-dtype for type_i in range(self.ntypes): - mask = np.tile( - (atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out] + mask = xp.tile( + xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out) ) atom_property = self.nets[(type_i,)](xx) if self.remove_vaccum_contribution is not None and not ( @@ -415,15 +422,18 @@ def _call_common( ): assert xx_zeros is not None atom_property -= self.nets[(type_i,)](xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask + atom_property = atom_property + self.bias_atom_e[type_i, ...] + atom_property = atom_property * xp.astype(mask, atom_property.dtype) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] else: - outs = self.nets[()](xx) + self.bias_atom_e[atype] + outs = self.nets[()](xx) + xp.reshape( + xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), + [nf, nloc, net_dim_out], + ) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod - outs = outs * exclude_mask[:, :, None] + outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype) return {self.var_name: outs} diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 5469e66d97..b09a9b3e47 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -18,12 +18,12 @@ def __init__( ): self.ntypes = ntypes self.exclude_types = exclude_types - self.type_mask = np.array( + type_mask = np.array( [1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)], dtype=np.int32, ) # (ntypes) - self.type_mask = self.type_mask.reshape([-1]) + self.type_mask = type_mask.reshape([-1]) def get_exclude_types(self): return self.exclude_types @@ -52,7 +52,9 @@ def build_type_exclude_mask( """ xp = array_api_compat.array_namespace(atype) nf, natom = atype.shape - return xp.reshape(self.type_mask[atype], (nf, natom)) + return xp.reshape( + xp.take(self.type_mask, xp.reshape(atype, [-1]), axis=0), (nf, natom) + ) class PairExcludeMask: diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/fitting/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py new file mode 100644 index 0000000000..27ad791db9 --- /dev/null +++ b/deepmd/jax/fitting/fitting.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.jax.common import ( + flax_module, + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + AtomExcludeMask, +) +from deepmd.jax.utils.network import ( + NetworkCollection, +) + + +def setattr_for_general_fitting(name: str, value: Any) -> Any: + if name in { + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + }: + value = to_jax_array(value) + elif name == "emask": + value = AtomExcludeMask(value.ntypes, value.exclude_types) + elif name == "nets": + value = NetworkCollection.deserialize(value.serialize()) + return value + + +@flax_module +class EnergyFittingNet(EnergyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py index cac4cee092..a6cf210f94 100644 --- a/deepmd/jax/utils/exclude_mask.py +++ b/deepmd/jax/utils/exclude_mask.py @@ -3,6 +3,7 @@ Any, ) +from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.jax.common import ( flax_module, @@ -10,6 +11,14 @@ ) +@flax_module +class AtomExcludeMask(AtomExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"type_mask"}: + value = to_jax_array(value) + return super().__setattr__(name, value) + + @flax_module class PairExcludeMask(PairExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/source/tests/array_api_strict/fitting/__init__.py b/source/tests/array_api_strict/fitting/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/array_api_strict/fitting/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py new file mode 100644 index 0000000000..2e6bd9fe25 --- /dev/null +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + AtomExcludeMask, +) +from ..utils.network import ( + NetworkCollection, +) + + +def setattr_for_general_fitting(name: str, value: Any) -> Any: + if name in { + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + }: + value = to_array_api_strict_array(value) + elif name == "emask": + value = AtomExcludeMask(value.ntypes, value.exclude_types) + elif name == "nets": + value = NetworkCollection.deserialize(value.serialize()) + return value + + +class EnergyFittingNet(EnergyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/utils/exclude_mask.py b/source/tests/array_api_strict/utils/exclude_mask.py index 06f2e94b52..7f5c29e0a8 100644 --- a/source/tests/array_api_strict/utils/exclude_mask.py +++ b/source/tests/array_api_strict/utils/exclude_mask.py @@ -3,6 +3,7 @@ Any, ) +from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from ..common import ( @@ -10,6 +11,13 @@ ) +class AtomExcludeMask(AtomExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"type_mask"}: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) + + class PairExcludeMask(PairExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index ac4f7ae543..ba2be1d86b 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -36,6 +38,22 @@ fitting_ener, ) +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import EnergyFittingNet as EnerFittingJAX +else: + EnerFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + EnergyFittingNet as EnerFittingStrict, + ) +else: + EnerFittingStrict = None + @parameterized( (True, False), # resnet_dt @@ -74,9 +92,25 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + skip_jax = not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + atom_ener, + ) = self.param + # TypeError: The array_api_strict namespace does not support the dtype 'bfloat16' + return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16" + tf_class = EnerFittingTF dp_class = EnerFittingDP pt_class = EnerFittingPT + jax_class = EnerFittingJAX + array_api_strict_class = EnerFittingStrict args = fitting_ener() def setUp(self): @@ -157,6 +191,39 @@ def eval_dp(self, dp_obj: Any) -> Any: fparam=self.fparam if numb_fparam else None, )["energy"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + atom_ener, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + )["energy"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + atom_ener, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + )["energy"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same