Skip to content

Commit 004b89a

Browse files
committed
support virial
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent d0b576f commit 004b89a

4 files changed

Lines changed: 72 additions & 13 deletions

File tree

deepmd/dpmodel/model/transform_output.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,20 @@ def get_leading_dims(
5252
vv: np.ndarray,
5353
vdef: OutputVariableDef,
5454
):
55-
"""Get the dimensions of nf x nloc."""
55+
"""Get the dimensions of nf x nloc.
56+
57+
Parameters
58+
----------
59+
vv : np.ndarray
60+
The input array from which to compute the leading dimensions.
61+
vdef : OutputVariableDef
62+
The output variable definition containing the shape to exclude from `vv`.
63+
64+
Returns
65+
-------
66+
list
67+
A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
68+
"""
5669
vshape = vv.shape
5770
return list(vshape[: (len(vshape) - len(vdef.shape))])
5871

@@ -76,11 +89,11 @@ def communicate_extended_output(
7689
if vdef.reducible:
7790
kk_redu = get_reduce_name(kk)
7891
new_ret[kk_redu] = model_ret[kk_redu]
92+
kk_derv_r, kk_derv_c = get_deriv_name(kk)
93+
mldims = list(mapping.shape)
94+
vldims = get_leading_dims(vv, vdef)
7995
if vdef.r_differentiable:
80-
kk_derv_r, kk_derv_c = get_deriv_name(kk)
8196
if model_ret[kk_derv_r] is not None:
82-
mldims = list(mapping.shape)
83-
vldims = get_leading_dims(vv, vdef)
8497
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
8598
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
8699
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
@@ -109,9 +122,35 @@ def communicate_extended_output(
109122
new_ret[kk_derv_r] = None
110123
if vdef.c_differentiable:
111124
assert vdef.r_differentiable
112-
kk_derv_r, kk_derv_c = get_deriv_name(kk)
113-
new_ret[kk_derv_c] = None
114-
new_ret[kk_derv_c + "_redu"] = None
125+
if model_ret[kk_derv_c] is not None:
126+
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
127+
mapping = xp.tile(
128+
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
129+
)
130+
virial = xp.zeros(
131+
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
132+
)
133+
# jax only
134+
if array_api_compat.is_jax_array(virial):
135+
from deepmd.jax.env import (
136+
jnp,
137+
)
138+
139+
v_idx = xp.arange(virial.size, dtype=xp.int64).reshape(
140+
virial.shape
141+
)
142+
new_idx = jnp.take_along_axis(v_idx, mapping, axis=1).ravel()
143+
v_shape = virial.shape
144+
virial = virial.ravel()
145+
virial = virial.at[new_idx].add(model_ret[kk_derv_c].ravel())
146+
virial = virial.reshape(v_shape)
147+
else:
148+
raise NotImplementedError("Only JAX arrays are supported.")
149+
new_ret[kk_derv_c] = virial
150+
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
151+
else:
152+
new_ret[kk_derv_c] = None
153+
new_ret[kk_derv_c + "_redu"] = None
115154
if not do_atomic_virial:
116155
# pop atomic virial, because it is not correctly calculated.
117156
new_ret.pop(kk_derv_c)

deepmd/jax/model/base_model.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def forward_common_atomic(
5353
size *= ii
5454

5555
split_ff = []
56+
split_vv = []
5657
for ss in range(size):
5758

5859
def eval_output(
@@ -76,13 +77,25 @@ def eval_output(
7677
fparam,
7778
aparam,
7879
)
80+
aviri = ffi[..., None] @ extended_coord[..., None, :]
7981
ffi = ffi[..., None, :]
8082
split_ff.append(ffi)
83+
aviri = aviri[..., None, :]
84+
split_vv.append(aviri)
8185
out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
82-
ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3)
86+
extended_force = jnp.concat(split_ff, axis=-2).reshape(
87+
*out_lead_shape, 3
88+
)
8389

84-
model_predict[kk_derv_r] = ff
90+
model_predict[kk_derv_r] = extended_force
8591
if vdef.c_differentiable:
8692
assert vdef.r_differentiable
87-
model_predict[kk_derv_c] = None
93+
extended_virial = jnp.concat(split_vv, axis=-2).reshape(
94+
*out_lead_shape, 9
95+
)
96+
# the correction sums to zero, which does not contribute to global virial
97+
if do_atomic_virial:
98+
raise NotImplementedError("Atomic virial is not implemented yet.")
99+
# to [...,3,3] -> [...,9]
100+
model_predict[kk_derv_c] = extended_virial
88101
return model_predict

source/tests/consistent/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
5151
{},
5252
suffix=suffix,
5353
)
54-
return [ret["energy"], ret["atom_ener"], ret["force"]], {
54+
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
5555
t_coord: coords,
5656
t_type: atype,
5757
t_natoms: natoms,

source/tests/consistent/model/test_ener.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,19 +211,26 @@ def eval_jax(self, jax_obj: Any) -> Any:
211211
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
212212
# shape not matched. ravel...
213213
if backend is self.RefBackend.DP:
214-
return (ret["energy_redu"].ravel(), ret["energy"].ravel(), SKIP_FLAG)
214+
return (
215+
ret["energy_redu"].ravel(),
216+
ret["energy"].ravel(),
217+
SKIP_FLAG,
218+
SKIP_FLAG,
219+
)
215220
elif backend is self.RefBackend.PT:
216221
return (
217222
ret["energy"].ravel(),
218223
ret["atom_energy"].ravel(),
219224
ret["force"].ravel(),
225+
ret["virial"].ravel(),
220226
)
221227
elif backend is self.RefBackend.TF:
222-
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel())
228+
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
223229
elif backend is self.RefBackend.JAX:
224230
return (
225231
ret["energy_redu"].ravel(),
226232
ret["energy"].ravel(),
227233
ret["energy_derv_r"].ravel(),
234+
ret["energy_derv_c_redu"].ravel(),
228235
)
229236
raise ValueError(f"Unknown backend: {backend}")

0 commit comments

Comments
 (0)