Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,8 @@ def call(
# n_angle x 1
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
edge_index = angle_index = xp.zeros([1, 3], dtype=nlist.dtype)
edge_index = xp.zeros([2, 1], dtype=nlist.dtype)
angle_index = xp.zeros([3, 1], dtype=nlist.dtype)

# get edge and angle embedding
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
Expand Down Expand Up @@ -622,7 +623,7 @@ def call(
edge_ebd,
h2,
sw,
owner=edge_index[:, 0],
owner=edge_index[0, :],
num_owner=nframes * nloc,
nb=nframes,
nloc=nloc,
Expand Down Expand Up @@ -1286,8 +1287,8 @@ def call(
a_nlist: np.ndarray, # nf x nloc x a_nnei
a_nlist_mask: np.ndarray, # nf x nloc x a_nnei
a_sw: np.ndarray, # switch func, nf x nloc x a_nnei
edge_index: np.ndarray, # n_edge x 2
angle_index: np.ndarray, # n_angle x 3
edge_index: np.ndarray, # 2 x n_edge
angle_index: np.ndarray, # 3 x n_angle
):
"""
Parameters
Expand All @@ -1312,12 +1313,12 @@ def call(
Masks of the neighbor list for angle. real nei 1 otherwise 0
a_sw : nf x nloc x a_nnei
Switch function for angle.
edge_index : Optional for dynamic sel, n_edge x 2
edge_index : Optional for dynamic sel, 2 x n_edge
n2e_index : n_edge
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
n_ext2e_index : n_edge
Broadcast indices from extended node(j) to edge(ij).
angle_index : Optional for dynamic sel, n_angle x 3
angle_index : Optional for dynamic sel, 3 x n_angle
n2a_index : n_angle
Broadcast indices from extended node(j) to angle(ijk).
eij2a_index : n_angle
Expand Down Expand Up @@ -1362,11 +1363,11 @@ def call(
assert (n_edge, 3) == h2.shape
del a_nlist # may be used in the future

n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :]
n2a_index, eij2a_index, eik2a_index = (
angle_index[:, 0],
angle_index[:, 1],
angle_index[:, 2],
angle_index[0, :],
angle_index[1, :],
angle_index[2, :],
)

# nb x nloc x nnei x n_dim [OR] n_edge x n_dim
Expand Down
8 changes: 4 additions & 4 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,12 +1036,12 @@ def get_graph_index(

Returns
-------
edge_index : n_edge x 2
edge_index : 2 x n_edge
n2e_index : n_edge
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
n_ext2e_index : n_edge
Broadcast indices from extended node(j) to edge(ij).
angle_index : n_angle x 3
angle_index : 3 x n_angle
n2a_index : n_angle
Broadcast indices from extended node(j) to angle(ijk).
eij2a_index : n_angle
Expand Down Expand Up @@ -1111,7 +1111,7 @@ def get_graph_index(
# n_angle
eik2a_index = edge_index_ik[a_nlist_mask_3d]

Comment thread
caic99 marked this conversation as resolved.
edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=-1)
angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=-1)
edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=0)
angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=0)

return edge_index_result, angle_index_result
28 changes: 16 additions & 12 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def _cal_hg_dynamic(
# n_edge x e_dim
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
# n_edge x 3 x e_dim
flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
-1, 3 * e_dim
)
# nf x nloc x 3 x e_dim
Expand Down Expand Up @@ -694,8 +694,8 @@ def forward(
a_nlist: torch.Tensor, # nf x nloc x a_nnei
a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei
a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei
edge_index: torch.Tensor, # n_edge x 2
angle_index: torch.Tensor, # n_angle x 3
edge_index: torch.Tensor, # 2 x n_edge
angle_index: torch.Tensor, # 3 x n_angle
):
"""
Parameters
Expand All @@ -720,12 +720,12 @@ def forward(
Masks of the neighbor list for angle. real nei 1 otherwise 0
a_sw : nf x nloc x a_nnei
Switch function for angle.
edge_index : Optional for dynamic sel, n_edge x 2
edge_index : Optional for dynamic sel, 2 x n_edge
n2e_index : n_edge
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
n_ext2e_index : n_edge
Broadcast indices from extended node(j) to edge(ij).
angle_index : Optional for dynamic sel, n_angle x 3
angle_index : Optional for dynamic sel, 3 x n_angle
n2a_index : n_angle
Broadcast indices from extended node(j) to angle(ijk).
eij2a_index : n_angle
Expand All @@ -745,19 +745,21 @@ def forward(
nb, nloc, nnei = nlist.shape
nall = node_ebd_ext.shape[1]
node_ebd = node_ebd_ext[:, :nloc, :]
n_edge = int(nlist_mask.sum().item())
assert (nb, nloc) == node_ebd.shape[:2]
if not self.use_dynamic_sel:
assert (nb, nloc, nnei, 3) == h2.shape
n_edge = None
else:
assert (n_edge, 3) == h2.shape
# n_edge = int(nlist_mask.sum().item())
# assert (n_edge, 3) == h2.shape
Comment thread
caic99 marked this conversation as resolved.
n_edge = h2.shape[0]
Comment thread
caic99 marked this conversation as resolved.
Comment thread
caic99 marked this conversation as resolved.
del a_nlist # may be used in the future

n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
n2a_index, eij2a_index, eik2a_index = (
angle_index[:, 0],
angle_index[:, 1],
angle_index[:, 2],
angle_index[0],
angle_index[1],
angle_index[2],
)
Comment thread
iProzd marked this conversation as resolved.

# nb x nloc x nnei x n_dim [OR] n_edge x n_dim
Expand Down Expand Up @@ -1026,7 +1028,9 @@ def forward(
if not self.use_dynamic_sel:
# nb x nloc x a_nnei x a_nnei x e_dim
weighted_edge_angle_update = (
a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
a_sw.unsqueeze(-1).unsqueeze(-1)
* a_sw.unsqueeze(-2).unsqueeze(-1)
* edge_angle_update
)
# nb x nloc x a_nnei x e_dim
reduced_edge_angle_update = torch.sum(
Expand Down
7 changes: 3 additions & 4 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,8 @@ def forward(
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
# avoid jit assertion
edge_index = angle_index = torch.zeros(
[1, 3], device=nlist.device, dtype=nlist.dtype
)
edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype)
angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype)
Comment thread
caic99 marked this conversation as resolved.
# get edge and angle embedding
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
if not self.edge_init_use_dist:
Expand Down Expand Up @@ -646,7 +645,7 @@ def forward(
edge_ebd,
h2,
sw,
owner=edge_index[:, 0],
owner=edge_index[0],
num_owner=nframes * nloc,
nb=nframes,
nloc=nloc,
Expand Down
14 changes: 6 additions & 8 deletions deepmd/pt/model/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ def get_graph_index(

Returns
-------
edge_index : n_edge x 2
edge_index : 2 x n_edge
n2e_index : n_edge
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
n_ext2e_index : n_edge
Broadcast indices from extended node(j) to edge(ij).
angle_index : n_angle x 3
angle_index : 3 x n_angle
n2a_index : n_angle
Broadcast indices from extended node(j) to angle(ijk).
eij2a_index : n_angle
Expand Down Expand Up @@ -135,9 +135,7 @@ def get_graph_index(
# n_angle
eik2a_index = edge_index_ik[a_nlist_mask_3d]

return torch.cat(
[n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1
), torch.cat(
[n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
dim=-1,
)
edge_index_result = torch.stack([n2e_index, n_ext2e_index], dim=0)
angle_index_result = torch.stack([n2a_index, eij2a_index, eik2a_index], dim=0)

return edge_index_result, angle_index_result