def get_leading_dims(
    vv: np.ndarray,
    vdef: OutputVariableDef,
):
    """Get the dimensions of nf x nloc.

    Parameters
    ----------
    vv : np.ndarray
        The input array from which to compute the leading dimensions.
    vdef : OutputVariableDef
        The output variable definition containing the shape to exclude from `vv`.

    Returns
    -------
    list
        A list of leading dimensions of `vv`, excluding the last
        `len(vdef.shape)` dimensions.
    """
    # Everything before the per-variable trailing shape is a "leading" dim
    # (e.g. batch / frame and atom axes).
    vshape = vv.shape
    return list(vshape[: (len(vshape) - len(vdef.shape))])
@@ -76,11 +89,11 @@ def communicate_extended_output(
7689 if vdef .reducible :
7790 kk_redu = get_reduce_name (kk )
7891 new_ret [kk_redu ] = model_ret [kk_redu ]
92+ kk_derv_r , kk_derv_c = get_deriv_name (kk )
93+ mldims = list (mapping .shape )
94+ vldims = get_leading_dims (vv , vdef )
7995 if vdef .r_differentiable :
80- kk_derv_r , kk_derv_c = get_deriv_name (kk )
8196 if model_ret [kk_derv_r ] is not None :
82- mldims = list (mapping .shape )
83- vldims = get_leading_dims (vv , vdef )
8497 derv_r_ext_dims = list (vdef .shape ) + [3 ] # noqa:RUF005
8598 mapping = xp .reshape (mapping , (mldims + [1 ] * len (derv_r_ext_dims )))
8699 mapping = xp .tile (mapping , [1 ] * len (mldims ) + derv_r_ext_dims )
@@ -109,9 +122,35 @@ def communicate_extended_output(
109122 new_ret [kk_derv_r ] = None
110123 if vdef .c_differentiable :
111124 assert vdef .r_differentiable
112- kk_derv_r , kk_derv_c = get_deriv_name (kk )
113- new_ret [kk_derv_c ] = None
114- new_ret [kk_derv_c + "_redu" ] = None
125+ if model_ret [kk_derv_c ] is not None :
126+ derv_c_ext_dims = list (vdef .shape ) + [9 ] # noqa:RUF005
127+ mapping = xp .tile (
128+ mapping , [1 ] * (len (mldims ) + len (vdef .shape )) + [3 ]
129+ )
130+ virial = xp .zeros (
131+ vldims + derv_c_ext_dims , dtype = vv .dtype , device = vv .device
132+ )
133+ # jax only
134+ if array_api_compat .is_jax_array (virial ):
135+ from deepmd .jax .env import (
136+ jnp ,
137+ )
138+
139+ v_idx = xp .arange (virial .size , dtype = xp .int64 ).reshape (
140+ virial .shape
141+ )
142+ new_idx = jnp .take_along_axis (v_idx , mapping , axis = 1 ).ravel ()
143+ v_shape = virial .shape
144+ virial = virial .ravel ()
145+ virial = virial .at [new_idx ].add (model_ret [kk_derv_c ].ravel ())
146+ virial = virial .reshape (v_shape )
147+ else :
148+ raise NotImplementedError ("Only JAX arrays are supported." )
149+ new_ret [kk_derv_c ] = virial
150+ new_ret [kk_derv_c + "_redu" ] = xp .sum (new_ret [kk_derv_c ], axis = 1 )
151+ else :
152+ new_ret [kk_derv_c ] = None
153+ new_ret [kk_derv_c + "_redu" ] = None
115154 if not do_atomic_virial :
116155 # pop atomic virial, because it is not correctly calculated.
117156 new_ret .pop (kk_derv_c )
0 commit comments