From d533da32d01f08dc64b71951305e7fb7ee4132ab Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 13 May 2025 00:26:33 +0800 Subject: [PATCH] fix(tf): always use float64 for the global tensor Fix #4734. Signed-off-by: Jinzhe Zeng --- deepmd/tf/model/tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/model/tensor.py b/deepmd/tf/model/tensor.py index 8f361ce281..1e960907ef 100644 --- a/deepmd/tf/model/tensor.py +++ b/deepmd/tf/model/tensor.py @@ -6,6 +6,7 @@ from deepmd.tf.env import ( MODEL_VERSION, + global_cvt_2_ener_float, tf, ) from deepmd.tf.utils.type_embed import ( @@ -173,7 +174,7 @@ def build( if "global" not in self.model_type: gname = "global_" + self.model_type atom_out = tf.reshape(output, [-1, natomsel, nout]) - global_out = tf.reduce_sum(atom_out, axis=1) + global_out = tf.reduce_sum(global_cvt_2_ener_float(atom_out), axis=1) global_out = tf.reshape(global_out, [-1, nout], name="o_" + gname + suffix) out_cpnts = tf.split(atom_out, nout, axis=-1)