File tree Expand file tree Collapse file tree
deepmd/pt/model/descriptor Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -717,10 +717,12 @@ def forward(
717717
718718 return (
719719 g1 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
720- rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
720+ rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION )
721+ if rot_mat is not None
722+ else None ,
721723 g2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if g2 is not None else None ,
722- h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
723- sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
724+ h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if h2 is not None else None ,
725+ sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if sw is not None else None ,
724726 )
725727
726728 @classmethod
Original file line number Diff line number Diff line change @@ -829,10 +829,12 @@ def forward(
829829 g1 = torch .cat ([g1 , g1_inp ], dim = - 1 )
830830 return (
831831 g1 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
832- rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
833- g2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
834- h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
835- sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
832+ rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION )
833+ if rot_mat is not None
834+ else None ,
835+ g2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if g2 is not None else None ,
836+ h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if h2 is not None else None ,
837+ sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if sw is not None else None ,
836838 )
837839
838840 @classmethod
Original file line number Diff line number Diff line change @@ -518,10 +518,14 @@ def forward(
518518 node_ebd = torch .cat ([node_ebd , node_ebd_inp ], dim = - 1 )
519519 return (
520520 node_ebd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
521- rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
522- edge_ebd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
523- h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
524- sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
521+ rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION )
522+ if rot_mat is not None
523+ else None ,
524+ edge_ebd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION )
525+ if edge_ebd is not None
526+ else None ,
527+ h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if h2 is not None else None ,
528+ sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if sw is not None else None ,
525529 )
526530
527531 @classmethod
Original file line number Diff line number Diff line change @@ -354,10 +354,12 @@ def forward(
354354 )
355355 return (
356356 g1 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
357- rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
357+ rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION )
358+ if rot_mat is not None
359+ else None ,
358360 None ,
359361 None ,
360- sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
362+ sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ) if sw is not None else None ,
361363 )
362364
363365 def set_stat_mean_and_stddev (
You can’t perform that action at this time.
0 commit comments