diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index f8c329b515..207c36d873 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -578,7 +578,8 @@ def call( # n_angle x 1 a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] else: - edge_index = angle_index = xp.zeros([1, 3], dtype=nlist.dtype) + edge_index = xp.zeros([2, 1], dtype=nlist.dtype) + angle_index = xp.zeros([3, 1], dtype=nlist.dtype) # get edge and angle embedding # nb x nloc x nnei x e_dim [OR] n_edge x e_dim @@ -622,7 +623,7 @@ def call( edge_ebd, h2, sw, - owner=edge_index[:, 0], + owner=edge_index[0, :], num_owner=nframes * nloc, nb=nframes, nloc=nloc, @@ -1286,8 +1287,8 @@ def call( a_nlist: np.ndarray, # nf x nloc x a_nnei a_nlist_mask: np.ndarray, # nf x nloc x a_nnei a_sw: np.ndarray, # switch func, nf x nloc x a_nnei - edge_index: np.ndarray, # n_edge x 2 - angle_index: np.ndarray, # n_angle x 3 + edge_index: np.ndarray, # 2 x n_edge + angle_index: np.ndarray, # 3 x n_angle ): """ Parameters @@ -1312,12 +1313,12 @@ def call( Masks of the neighbor list for angle. real nei 1 otherwise 0 a_sw : nf x nloc x a_nnei Switch function for angle. - edge_index : Optional for dynamic sel, n_edge x 2 + edge_index : Optional for dynamic sel, 2 x n_edge n2e_index : n_edge Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). n_ext2e_index : n_edge Broadcast indices from extended node(j) to edge(ij). - angle_index : Optional for dynamic sel, n_angle x 3 + angle_index : Optional for dynamic sel, 3 x n_angle n2a_index : n_angle Broadcast indices from extended node(j) to angle(ijk). eij2a_index : n_angle @@ -1362,11 +1363,11 @@ def call( assert (n_edge, 3) == h2.shape del a_nlist # may be used in the future - n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1] + n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :] n2a_index, eij2a_index, eik2a_index = ( - angle_index[:, 0], - angle_index[:, 1], - angle_index[:, 2], + angle_index[0, :], + angle_index[1, :], + angle_index[2, :], ) # nb x nloc x nnei x n_dim [OR] n_edge x n_dim diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index bf28b66b7b..c6da877d87 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -1036,12 +1036,12 @@ def get_graph_index( Returns ------- - edge_index : n_edge x 2 + edge_index : 2 x n_edge n2e_index : n_edge Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). n_ext2e_index : n_edge Broadcast indices from extended node(j) to edge(ij). - angle_index : n_angle x 3 + angle_index : 3 x n_angle n2a_index : n_angle Broadcast indices from extended node(j) to angle(ijk). eij2a_index : n_angle @@ -1111,7 +1111,7 @@ def get_graph_index( # n_angle eik2a_index = edge_index_ik[a_nlist_mask_3d] - edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=-1) - angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=-1) + edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=0) + angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=0) return edge_index_result, angle_index_result diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 37d4f07bb4..36e738b8b2 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -370,7 +370,7 @@ def _cal_hg_dynamic( # n_edge x e_dim flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1) # n_edge x 3 x e_dim - flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape( + flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape( -1, 3 * e_dim ) # nf x nloc x 3 x e_dim @@ -694,8 +694,8 @@ def forward( a_nlist: torch.Tensor, # nf x nloc x a_nnei a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei - edge_index: torch.Tensor, # n_edge x 2 - angle_index: torch.Tensor, # n_angle x 3 + edge_index: torch.Tensor, # 2 x n_edge + angle_index: torch.Tensor, # 3 x n_angle ): """ Parameters @@ -720,12 +720,12 @@ def forward( Masks of the neighbor list for angle. real nei 1 otherwise 0 a_sw : nf x nloc x a_nnei Switch function for angle. - edge_index : Optional for dynamic sel, n_edge x 2 + edge_index : Optional for dynamic sel, 2 x n_edge n2e_index : n_edge Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). n_ext2e_index : n_edge Broadcast indices from extended node(j) to edge(ij). - angle_index : Optional for dynamic sel, n_angle x 3 + angle_index : Optional for dynamic sel, 3 x n_angle n2a_index : n_angle Broadcast indices from extended node(j) to angle(ijk). eij2a_index : n_angle @@ -745,19 +745,21 @@ def forward( nb, nloc, nnei = nlist.shape nall = node_ebd_ext.shape[1] node_ebd = node_ebd_ext[:, :nloc, :] - n_edge = int(nlist_mask.sum().item()) assert (nb, nloc) == node_ebd.shape[:2] if not self.use_dynamic_sel: assert (nb, nloc, nnei, 3) == h2.shape + n_edge = None else: - assert (n_edge, 3) == h2.shape + # n_edge = int(nlist_mask.sum().item()) + # assert (n_edge, 3) == h2.shape + n_edge = h2.shape[0] del a_nlist # may be used in the future - n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1] + n2e_index, n_ext2e_index = edge_index[0], edge_index[1] n2a_index, eij2a_index, eik2a_index = ( - angle_index[:, 0], - angle_index[:, 1], - angle_index[:, 2], + angle_index[0], + angle_index[1], + angle_index[2], ) # nb x nloc x nnei x n_dim [OR] n_edge x n_dim @@ -1026,7 +1028,9 @@ def forward( if not self.use_dynamic_sel: # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = torch.sum( diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 5889b0a819..68b4a8d8c6 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -537,9 +537,8 @@ def forward( a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] else: # avoid jit assertion - edge_index = angle_index = torch.zeros( - [1, 3], device=nlist.device, dtype=nlist.dtype - ) + edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype) + angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype) # get edge and angle embedding # nb x nloc x nnei x e_dim [OR] n_edge x e_dim if not self.edge_init_use_dist: @@ -646,7 +645,7 @@ def forward( edge_ebd, h2, sw, - owner=edge_index[:, 0], + owner=edge_index[0], num_owner=nframes * nloc, nb=nframes, nloc=nloc, diff --git a/deepmd/pt/model/network/utils.py b/deepmd/pt/model/network/utils.py index 2047efec2b..34af976b76 100644 --- a/deepmd/pt/model/network/utils.py +++ b/deepmd/pt/model/network/utils.py @@ -74,12 +74,12 @@ def get_graph_index( Returns ------- - edge_index : n_edge x 2 + edge_index : 2 x n_edge n2e_index : n_edge Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). n_ext2e_index : n_edge Broadcast indices from extended node(j) to edge(ij). - angle_index : n_angle x 3 + angle_index : 3 x n_angle n2a_index : n_angle Broadcast indices from extended node(j) to angle(ijk). eij2a_index : n_angle @@ -135,9 +135,7 @@ def get_graph_index( # n_angle eik2a_index = edge_index_ik[a_nlist_mask_3d] - return torch.cat( - [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1 - ), torch.cat( - [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)], - dim=-1, - ) + edge_index_result = torch.stack([n2e_index, n_ext2e_index], dim=0) + angle_index_result = torch.stack([n2a_index, eij2a_index, eik2a_index], dim=0) + + return edge_index_result, angle_index_result