Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
Union,
)

import array_api_compat
Comment thread
njzjz marked this conversation as resolved.
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
NativeOP,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
AtomExcludeMask,
FittingNet,
Expand Down Expand Up @@ -283,11 +287,11 @@ def serialize(self) -> dict:
"exclude_types": self.exclude_types,
"nets": self.nets.serialize(),
"@variables": {
"bias_atom_e": self.bias_atom_e,
"fparam_avg": self.fparam_avg,
"fparam_inv_std": self.fparam_inv_std,
"aparam_avg": self.aparam_avg,
"aparam_inv_std": self.aparam_inv_std,
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"fparam_avg": to_numpy_array(self.fparam_avg),
"fparam_inv_std": to_numpy_array(self.fparam_inv_std),
"aparam_avg": to_numpy_array(self.aparam_avg),
"aparam_inv_std": to_numpy_array(self.aparam_inv_std),
},
"type_map": self.type_map,
# not supported
Expand Down Expand Up @@ -344,6 +348,7 @@ def _call_common(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`

"""
xp = array_api_compat.array_namespace(descriptor, atype)
nf, nloc, nd = descriptor.shape
net_dim_out = self._net_out_dim()
# check input dim
Expand All @@ -359,7 +364,7 @@ def _call_common(
# we consider it as always zero for convenience.
# Needs a compute_input_stats for vaccum passed from the
# descriptor.
xx_zeros = np.zeros_like(xx)
xx_zeros = xp.zeros_like(xx)
else:
xx_zeros = None
# check fparam dim, concate to input descriptor
Expand All @@ -371,13 +376,15 @@ def _call_common(
"which is not consistent with {self.numb_fparam}.",
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = np.tile(fparam.reshape([nf, 1, self.numb_fparam]), [1, nloc, 1])
xx = np.concatenate(
fparam = xp.tile(
xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1)
)
Comment thread
njzjz marked this conversation as resolved.
xx = xp.concat(
Comment thread
njzjz marked this conversation as resolved.
[xx, fparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
xx_zeros = xp.concat(
[xx_zeros, fparam],
axis=-1,
)
Expand All @@ -389,24 +396,24 @@ def _call_common(
"get an input aparam of dim {aparam.shape[-1]}, ",
"which is not consistent with {self.numb_aparam}.",
)
aparam = aparam.reshape([nf, nloc, self.numb_aparam])
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
xx = np.concatenate(
xx = xp.concat(
[xx, aparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
xx_zeros = xp.concat(
[xx_zeros, aparam],
axis=-1,
)

# calcualte the prediction
if not self.mixed_types:
outs = np.zeros([nf, nloc, net_dim_out]) # pylint: disable=no-explicit-dtype
outs = xp.zeros([nf, nloc, net_dim_out]) # pylint: disable=no-explicit-dtype
for type_i in range(self.ntypes):
mask = np.tile(
(atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out]
mask = xp.tile(
xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out)
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
Expand All @@ -415,15 +422,18 @@ def _call_common(
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
atom_property = atom_property + self.bias_atom_e[type_i, ...]
atom_property = atom_property * xp.astype(mask, atom_property.dtype)
Comment thread
njzjz marked this conversation as resolved.
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + self.bias_atom_e[atype]
outs = self.nets[()](xx) + xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
Comment thread
njzjz marked this conversation as resolved.
[nf, nloc, net_dim_out],
)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * exclude_mask[:, :, None]
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
return {self.var_name: outs}
8 changes: 5 additions & 3 deletions deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def __init__(
):
self.ntypes = ntypes
self.exclude_types = exclude_types
self.type_mask = np.array(
type_mask = np.array(
[1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)],
dtype=np.int32,
)
# (ntypes)
self.type_mask = self.type_mask.reshape([-1])
self.type_mask = type_mask.reshape([-1])

def get_exclude_types(self):
return self.exclude_types
Expand Down Expand Up @@ -52,7 +52,9 @@ def build_type_exclude_mask(
"""
xp = array_api_compat.array_namespace(atype)
nf, natom = atype.shape
return xp.reshape(self.type_mask[atype], (nf, natom))
return xp.reshape(
xp.take(self.type_mask, xp.reshape(atype, [-1]), axis=0), (nf, natom)
)
Comment thread
njzjz marked this conversation as resolved.


class PairExcludeMask:
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
39 changes: 39 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
AtomExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


def setattr_for_general_fitting(name: str, value: Any) -> Any:
if name in {
"bias_atom_e",
"fparam_avg",
"fparam_inv_std",
"aparam_avg",
"aparam_inv_std",
}:
value = to_jax_array(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
value = NetworkCollection.deserialize(value.serialize())
return value


@flax_module
class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)
Comment thread
njzjz marked this conversation as resolved.
9 changes: 9 additions & 0 deletions deepmd/jax/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@
Any,
)

from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)


@flax_module
class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
return super().__setattr__(name, value)


@flax_module
class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
Expand Down
1 change: 1 addition & 0 deletions source/tests/array_api_strict/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
38 changes: 38 additions & 0 deletions source/tests/array_api_strict/fitting/fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
AtomExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


def setattr_for_general_fitting(name: str, value: Any) -> Any:
if name in {
"bias_atom_e",
"fparam_avg",
"fparam_inv_std",
"aparam_avg",
"aparam_inv_std",
}:
value = to_array_api_strict_array(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
value = NetworkCollection.deserialize(value.serialize())
return value


class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)
8 changes: 8 additions & 0 deletions source/tests/array_api_strict/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
Any,
)

from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP

from ..common import (
to_array_api_strict_array,
)


class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_array_api_strict_array(value)
return super().__setattr__(name, value)


class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
Expand Down
67 changes: 67 additions & 0 deletions source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -36,6 +38,22 @@
fitting_ener,
)

if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import EnergyFittingNet as EnerFittingJAX
else:
EnerFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
EnergyFittingNet as EnerFittingStrict,
)
else:
EnerFittingStrict = None


@parameterized(
(True, False), # resnet_dt
Expand Down Expand Up @@ -74,9 +92,25 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

skip_jax = not INSTALLED_JAX

Comment thread
njzjz marked this conversation as resolved.
@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
# TypeError: The array_api_strict namespace does not support the dtype 'bfloat16'
return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16"

tf_class = EnerFittingTF
dp_class = EnerFittingDP
pt_class = EnerFittingPT
jax_class = EnerFittingJAX
array_api_strict_class = EnerFittingStrict
args = fitting_ener()

def setUp(self):
Expand Down Expand Up @@ -157,6 +191,39 @@ def eval_dp(self, dp_obj: Any) -> Any:
fparam=self.fparam if numb_fparam else None,
)["energy"]

def eval_jax(self, jax_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
)["energy"]
)

Comment thread
njzjz marked this conversation as resolved.
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
)["energy"]
)

Comment thread
njzjz marked this conversation as resolved.
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down