diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index d72c270667..bb0a8987b8 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -607,6 +607,27 @@ def warm_up_linear(step, warmup_steps): ) backend = "CINN" if CINN else None + # NOTE: This is a trick to decide the right input_spec for wrapper.forward + _, label_dict, _ = self.get_data(is_train=True) + + # Define specification templates + spec_templates = { + "find_box": np.float32(1.0), + "find_coord": np.float32(1.0), + "find_numb_copy": np.float32(0.0), + "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"), + "find_energy": np.float32(1.0), + "energy": static.InputSpec([1, 1], "float64", name="energy"), + "find_force": np.float32(1.0), + "force": static.InputSpec([1, -1, 3], "float64", name="force"), + "find_virial": np.float32(0.0), + "virial": static.InputSpec([1, 9], "float64", name="virial"), + "natoms": static.InputSpec([1, -1], "int32", name="natoms"), + } + # Build spec only for keys present in sample data + label_dict_spec = { + k: spec_templates[k] for k in label_dict.keys() if k in spec_templates + } self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ @@ -615,19 +636,7 @@ def warm_up_linear(step, warmup_steps): None, # spin static.InputSpec([1, 9], "float64", name="box"), # box static.InputSpec([], "float64", name="cur_lr"), # cur_lr - { - "find_box": np.float32(1.0), - "find_coord": np.float32(1.0), - "find_numb_copy": np.float32(0.0), - "numb_copy": static.InputSpec( - [1, 1], "int64", name="numb_copy" - ), - "find_energy": np.float32(1.0), - "energy": static.InputSpec([1, 1], "float64", name="energy"), - "find_force": np.float32(1.0), - "force": static.InputSpec([1, -1, 3], "float64", name="force"), - "natoms": static.InputSpec([1, -1], "int32", name="natoms"), - }, # label, + label_dict_spec, # label, # None, # task_key # False, # inference_only # False, # do_atomic_virial