diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 37b677cf23..2ce05bd160 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -899,6 +899,7 @@ def call( exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # nfnl x nnei exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + exclude_mask = xp.astype(exclude_mask, xp.bool) # nfnl x nnei nlist = xp.reshape(nlist, (nf * nloc, nnei)) nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index ae6b5de511..e15a20926f 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -393,6 +393,7 @@ def call( ): xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + exclude_mask = xp.astype(exclude_mask, xp.bool) nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 570f9a47e8..1b0f44ec97 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -682,6 +682,7 @@ def call( exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) # nfnl x nnei nlist = xp.reshape(nlist, (nf * nloc, nnei)) + exclude_mask = xp.astype(exclude_mask, xp.bool) nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nfnl x nnei nlist_mask = nlist != -1 diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 388932f297..7342663141 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -488,6 +488,7 @@ def _call_common( ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) + exclude_mask = xp.astype(exclude_mask, xp.bool) # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) return {self.var_name: outs}