@@ -478,11 +478,12 @@ def forward(
478478 self .rcut_smth ,
479479 protection = self .env_protection ,
480480 )
481+ # nb x nloc x nnei
482+ exclude_mask = self .emask (nlist , extended_atype )
483+ nlist = torch .where (exclude_mask != 0 , nlist , - 1 )
481484 nlist_mask = nlist != - 1
482485 nlist = torch .where (nlist == - 1 , 0 , nlist )
483486 sw = torch .squeeze (sw , - 1 )
484- # beyond the cutoff sw should be 0.0
485- sw = sw .masked_fill (~ nlist_mask , 0.0 )
486487 # nf x nloc x nt -> nf x nloc x nnei x nt
487488 atype_tebd = extended_atype_embd [:, :nloc , :]
488489 atype_tebd_nnei = atype_tebd .unsqueeze (2 ).expand (- 1 , - 1 , self .nnei , - 1 )
@@ -495,8 +496,10 @@ def forward(
495496 atype_tebd_nlist = torch .gather (atype_tebd_ext , dim = 1 , index = index )
496497 # nb x nloc x nnei x nt
497498 atype_tebd_nlist = atype_tebd_nlist .view (nb , nloc , nnei , nt )
499+ # beyond the cutoff sw should be 0.0
500+ sw = sw .masked_fill (~ nlist_mask , 0.0 )
498501 # (nb x nloc) x nnei
499- exclude_mask = self . emask ( nlist , extended_atype ) .view (nb * nloc , nnei )
502+ exclude_mask = exclude_mask .view (nb * nloc , nnei )
500503 if self .old_impl :
501504 assert self .filter_layers_old is not None
502505 dmatrix = dmatrix .view (
0 commit comments