Skip to content

Commit 6ff6073

Browse files
Copilotnjzjz
andcommitted
fix: resolve TorchScript compilation errors in deepmd.pt descriptor modules
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 1415e50 commit 6ff6073

4 files changed

Lines changed: 23 additions & 13 deletions

File tree

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -717,10 +717,12 @@ def forward(
717717

718718
return (
719719
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
720-
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
720+
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
721+
if rot_mat is not None
722+
else None,
721723
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
722-
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
723-
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
724+
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if h2 is not None else None,
725+
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
724726
)
725727

726728
@classmethod

deepmd/pt/model/descriptor/dpa2.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -829,10 +829,12 @@ def forward(
829829
g1 = torch.cat([g1, g1_inp], dim=-1)
830830
return (
831831
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
832-
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
833-
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
834-
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
835-
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
832+
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
833+
if rot_mat is not None
834+
else None,
835+
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
836+
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if h2 is not None else None,
837+
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
836838
)
837839

838840
@classmethod

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,10 +518,14 @@ def forward(
518518
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)
519519
return (
520520
node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
521-
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
522-
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
523-
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
524-
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
521+
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
522+
if rot_mat is not None
523+
else None,
524+
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
525+
if edge_ebd is not None
526+
else None,
527+
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if h2 is not None else None,
528+
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
525529
)
526530

527531
@classmethod

deepmd/pt/model/descriptor/se_a.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,12 @@ def forward(
354354
)
355355
return (
356356
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
357-
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
357+
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
358+
if rot_mat is not None
359+
else None,
358360
None,
359361
None,
360-
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
362+
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
361363
)
362364

363365
def set_stat_mean_and_stddev(

0 commit comments

Comments
 (0)