From babd11c0774da3716a79fc90814ef6b22cf2b5c8 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 15 May 2025 16:31:35 +0800 Subject: [PATCH 1/4] feat: use bfloat16 with `torch.autocast` on training --- deepmd/pt/model/descriptor/repflow_layer.py | 8 ++++++-- deepmd/pt/train/wrapper.py | 1 + deepmd/pt/utils/utils.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index f109109cfd..713e5e6857 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -424,7 +424,7 @@ def optim_angle_update( sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_ik) result_update = ( - bias + bias.to(sub_angle_update.dtype) + sub_node_update.unsqueeze(2).unsqueeze(3) + sub_edge_update_ij.unsqueeze(2) + sub_edge_update_ik.unsqueeze(3) @@ -463,7 +463,10 @@ def optim_edge_update( sub_edge_update = torch.matmul(edge_ebd, edge) result_update = ( - bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update + bias.to(sub_node_update.dtype) + + sub_node_update.unsqueeze(2) + + sub_edge_update + + sub_node_ext_update ) return result_update @@ -679,6 +682,7 @@ def forward( ) ) + a_sw.to(edge_angle_update.dtype) # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 9a2cbff295..29348363ea 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -136,6 +136,7 @@ def share_params(self, shared_links, resume=False) -> None: f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" ) + @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True) def forward( self, coord, diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 85988e3523..4377939275 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -93,7 +93,7 @@ def forward(ctx, x, threshold, slope, const_val): ctx.threshold = threshold ctx.slope = slope ctx.const_val = const_val - return silut_forward_script(x, threshold, slope, const_val) + return silut_forward_script(x, threshold, slope, const_val).to(x.dtype) @staticmethod def backward(ctx, grad_output): From 9025d0b313ca870f290f18c5d39e75ce94c34e04 Mon Sep 17 00:00:00 2001 From: caic99 Date: Thu, 15 May 2025 09:19:01 +0000 Subject: [PATCH 2/4] remove unused statement --- deepmd/pt/model/descriptor/repflow_layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 713e5e6857..cd2d79f52f 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -682,7 +682,6 @@ def forward( ) ) - a_sw.to(edge_angle_update.dtype) # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update From 00dcd58f072f129d0441dd941992f620af7974c7 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 19 May 2025 13:20:05 +0800 Subject: [PATCH 3/4] feat: enable configurable bfloat16 autocasting in ModelWrapper --- deepmd/pt/train/wrapper.py | 3 ++- deepmd/pt/utils/env.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 29348363ea..a58face8f6 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -5,6 +5,7 @@ Union, ) +from deepmd.pt.utils.env import BF16_AUTOCAST import torch if torch.__version__.startswith("2"): @@ -136,7 +137,7 @@ def share_params(self, shared_links, resume=False) -> None: f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" ) - @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True) + @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=BF16_AUTOCAST) def forward( self, coord, diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 185bb1add3..63b0ecca07 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -35,6 +35,7 @@ CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True CUSTOM_OP_USE_JIT = False +BF16_AUTOCAST = False PRECISION_DICT = { "float16": torch.float16, @@ -76,6 +77,7 @@ torch.set_num_threads(intra_nthreads) __all__ = [ + "BF16_AUTOCAST", "CACHE_PER_SYS", "CUSTOM_OP_USE_JIT", "DEFAULT_PRECISION", From 0a770e4efccd2bfff79336a43e3c182a8449ed99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 05:21:52 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/train/wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index a58face8f6..8666f642fc 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -5,9 +5,12 @@ Union, ) -from deepmd.pt.utils.env import BF16_AUTOCAST import torch +from deepmd.pt.utils.env import ( + BF16_AUTOCAST, +) + if torch.__version__.startswith("2"): import torch._dynamo