Skip to content

Commit 4aa24ae

Browse files
committed
fix se_e2_r prod_force_virial zero input
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent dd56fe9 commit 4aa24ae

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

deepmd/tf/descriptor/se_r.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,10 +512,11 @@ def prod_force_virial(
512512
"""
513513
[net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
514514
tf.summary.histogram("net_derivative", net_deriv)
515+
nf = tf.shape(self.nlist)[0]
515516
net_deriv_reshape = tf.reshape(
516517
net_deriv,
517518
[
518-
np.asarray(-1, dtype=np.int64),
519+
nf,
519520
natoms[0] * np.asarray(self.ndescrpt, dtype=np.int64),
520521
],
521522
)

0 commit comments

Comments
 (0)