From 81182318611d3e65b716b61b78bb6daa73d17c14 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Jun 2025 21:31:24 +0800 Subject: [PATCH 1/4] update adaptive CINN --- deepmd/pd/train/training.py | 43 ++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index d72c270667..07cd738a1c 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -607,6 +607,35 @@ def warm_up_linear(step, warmup_steps): ) backend = "CINN" if CINN else None + + # NOTE: This is a trick to decide the right input_spec for wrapper.forward + _, label_dict, _ = self.get_data(is_train=True, task_key="Default") + label_dict_spec = { + "find_box": np.float32(1.0), + "find_coord": np.float32(1.0), + "find_numb_copy": np.float32(0.0), + "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"), + "find_energy": np.float32(1.0), + "energy": static.InputSpec([1, 1], "float64", name="energy"), + "find_force": np.float32(1.0), + "force": static.InputSpec([1, -1, 3], "float64", name="force"), + "find_virial": np.float32(0.0), + "virial": static.InputSpec([1, 9], "float64", name="virial"), + "natoms": static.InputSpec([1, -1], "int32", name="natoms"), + } + if "virial" not in label_dict: + label_dict_spec.pop("virial") + if "find_virial" not in label_dict: + label_dict_spec.pop("find_virial") + if "energy" not in label_dict: + label_dict_spec.pop("energy") + if "find_energy" not in label_dict: + label_dict_spec.pop("find_energy") + if "force" not in label_dict: + label_dict_spec.pop("force") + if "find_force" not in label_dict: + label_dict_spec.pop("find_force") + self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ @@ -615,19 +644,7 @@ def warm_up_linear(step, warmup_steps): None, # spin static.InputSpec([1, 9], "float64", name="box"), # box static.InputSpec([], "float64", name="cur_lr"), # cur_lr - { - "find_box": np.float32(1.0), - "find_coord": np.float32(1.0), - "find_numb_copy": np.float32(0.0), - "numb_copy": static.InputSpec( - [1, 1], "int64", name="numb_copy" - ), - "find_energy": np.float32(1.0), - "energy": static.InputSpec([1, 1], "float64", name="energy"), - "find_force": np.float32(1.0), - "force": static.InputSpec([1, -1, 3], "float64", name="force"), - "natoms": static.InputSpec([1, -1], "int32", name="natoms"), - }, # label, + label_dict_spec, # label, # None, # task_key # False, # inference_only # False, # do_atomic_virial From d264faef4fccc8d20d86540844ab8d251558b6ed Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 13 Jun 2025 11:18:24 +0800 Subject: [PATCH 2/4] Update deepmd/pd/train/training.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: HydrogenSulfate <490868991@qq.com> --- deepmd/pd/train/training.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 07cd738a1c..984535b5d7 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -607,10 +607,13 @@ def warm_up_linear(step, warmup_steps): ) backend = "CINN" if CINN else None - # NOTE: This is a trick to decide the right input_spec for wrapper.forward - _, label_dict, _ = self.get_data(is_train=True, task_key="Default") - label_dict_spec = { + # Use appropriate task_key for multi-task scenarios + sample_task_key = self.model_keys[0] if self.multi_task else "Default" + _, label_dict, _ = self.get_data(is_train=True, task_key=sample_task_key) + + # Define specification templates + spec_templates = { "find_box": np.float32(1.0), "find_coord": np.float32(1.0), "find_numb_copy": np.float32(0.0), @@ -623,19 +626,8 @@ def warm_up_linear(step, warmup_steps): "virial": static.InputSpec([1, 9], "float64", name="virial"), "natoms": static.InputSpec([1, -1], "int32", name="natoms"), } - if "virial" not in label_dict: - label_dict_spec.pop("virial") - if "find_virial" not in label_dict: - label_dict_spec.pop("find_virial") - if "energy" not in label_dict: - label_dict_spec.pop("energy") - if "find_energy" not in label_dict: - label_dict_spec.pop("find_energy") - if "force" not in label_dict: - label_dict_spec.pop("force") - if "find_force" not in label_dict: - label_dict_spec.pop("find_force") - + # Build spec only for keys present in sample data + label_dict_spec = {k: spec_templates[k] for k in label_dict.keys() if k in spec_templates} self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ From 6528a85b31bf178fb158845800ec8abc0b39bb88 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 03:20:01 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pd/train/training.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 984535b5d7..ca9e28726c 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -627,7 +627,9 @@ def warm_up_linear(step, warmup_steps): "natoms": static.InputSpec([1, -1], "int32", name="natoms"), } # Build spec only for keys present in sample data - label_dict_spec = {k: spec_templates[k] for k in label_dict.keys() if k in spec_templates} + label_dict_spec = { + k: spec_templates[k] for k in label_dict.keys() if k in spec_templates + } self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ From 3a6438e973c1de5c94e937cf839679005ab3090e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 13 Jun 2025 11:34:37 +0800 Subject: [PATCH 4/4] fix --- deepmd/pd/train/training.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index ca9e28726c..bb0a8987b8 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -608,9 +608,7 @@ def warm_up_linear(step, warmup_steps): backend = "CINN" if CINN else None # NOTE: This is a trick to decide the right input_spec for wrapper.forward - # Use appropriate task_key for multi-task scenarios - sample_task_key = self.model_keys[0] if self.multi_task else "Default" - _, label_dict, _ = self.get_data(is_train=True, task_key=sample_task_key) + _, label_dict, _ = self.get_data(is_train=True) # Define specification templates spec_templates = {