Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def xp_take_along_axis(arr, indices, axis):

shape = list(arr.shape)
shape.pop(-1)
shape = [*shape, n]
shape = (*shape, n)

arr = xp.reshape(arr, (-1,))
if n != 0:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def call(
type_embedding = self.type_embedding.call()
# nf x nall x tebd_dim
atype_embd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
Expand Down Expand Up @@ -1027,7 +1027,7 @@ def call(
xp.tile(
(xp.reshape(atype, (-1, 1)) * ntypes_with_padding), (1, nnei)
),
(-1),
(-1,),
)
idx_j = xp.reshape(nei_type, (-1,))
# (nf x nl x nnei) x ng
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ def call(
type_embedding = self.type_embedding.call()
# repinit
g1_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,12 @@ def call(
type_embedding = self.type_embedding.call()
if self.use_loc_mapping:
node_ebd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], (-1,)), axis=0),
(nframes, nloc, self.tebd_dim),
)
else:
node_ebd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
(nframes, nall, self.tebd_dim),
)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def call(
type_embedding = self.type_embedding.call()
# nf x nall x tebd_dim
atype_embd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
Expand Down
12 changes: 6 additions & 6 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _call_common(
)
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
fparam = xp.tile(
xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1)
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
)
xx = xp.concat(
[xx, fparam],
Expand All @@ -431,7 +431,7 @@ def _call_common(
f"get an input aparam of dim {aparam.shape[-1]}, "
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
xx = xp.concat(
[xx, aparam],
Expand All @@ -446,7 +446,7 @@ def _call_common(
if self.dim_case_embd > 0:
assert self.case_embd is not None
case_embd = xp.tile(
xp.reshape(self.case_embd[...], [1, 1, -1]), [nf, nloc, 1]
xp.reshape(self.case_embd[...], (1, 1, -1)), (nf, nloc, 1)
)
xx = xp.concat(
[xx, case_embd],
Expand All @@ -465,7 +465,7 @@ def _call_common(
)
for type_i in range(self.ntypes):
mask = xp.tile(
xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out)
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
Expand All @@ -485,10 +485,10 @@ def _call_common(
outs += xp.reshape(
xp.take(
xp.astype(self.bias_atom_e[...], outs.dtype),
xp.reshape(atype, [-1]),
xp.reshape(atype, (-1,)),
axis=0,
),
[nf, nloc, net_dim_out],
(nf, nloc, net_dim_out),
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def call(
]
# out = out * self.scale[atype, ...]
scale_atype = xp.reshape(
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0),
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0),
(*atype.shape, 1),
)
out = out * scale_atype
Expand All @@ -315,7 +315,7 @@ def call(
bias = xp.reshape(
xp.take(
xp.astype(self.constant_matrix, out.dtype),
xp.reshape(atype, [-1]),
xp.reshape(atype, (-1,)),
axis=0,
),
(nframes, nloc),
Expand Down
32 changes: 16 additions & 16 deletions deepmd/dpmodel/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,18 @@
atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
energy = xp.sum(atom_ener_coeff * atom_ener, 1)
if self.has_f or self.has_pf or self.relative_f or self.has_gf:
force_reshape = xp.reshape(force, [-1])
force_hat_reshape = xp.reshape(force_hat, [-1])
force_reshape = xp.reshape(force, (-1,))
force_hat_reshape = xp.reshape(force_hat, (-1,))
diff_f = force_hat_reshape - force_reshape
else:
diff_f = None

if self.relative_f is not None:
force_hat_3 = xp.reshape(force_hat, [-1, 3])
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f
diff_f_3 = xp.reshape(diff_f, [-1, 3])
force_hat_3 = xp.reshape(force_hat, (-1, 3))
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), (-1, 1)) + self.relative_f
diff_f_3 = xp.reshape(diff_f, (-1, 3))

[Codecov / codecov/patch warning on deepmd/dpmodel/loss/ener.py: added lines 142–144 were not covered by tests]
diff_f_3 = diff_f_3 / norm_f
diff_f = xp.reshape(diff_f_3, [-1])
diff_f = xp.reshape(diff_f_3, (-1,))

[Codecov / codecov/patch warning on deepmd/dpmodel/loss/ener.py: added line 146 was not covered by tests]

atom_norm = 1.0 / natoms
atom_norm_ener = 1.0 / natoms
Expand Down Expand Up @@ -184,15 +184,15 @@
loss += pref_f * l2_force_loss
else:
l_huber_loss = custom_huber_loss(
xp.reshape(force, [-1]),
xp.reshape(force_hat, [-1]),
xp.reshape(force, (-1,)),
xp.reshape(force_hat, (-1,)),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
if self.has_v:
virial_reshape = xp.reshape(virial, [-1])
virial_hat_reshape = xp.reshape(virial_hat, [-1])
virial_reshape = xp.reshape(virial, (-1,))
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)
Expand All @@ -207,8 +207,8 @@
loss += pref_v * l_huber_loss
more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial)
if self.has_ae:
atom_ener_reshape = xp.reshape(atom_ener, [-1])
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
l2_atom_ener_loss = xp.mean(
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
)
Expand All @@ -225,7 +225,7 @@
l2_atom_ener_loss, find_atom_ener
)
if self.has_pf:
atom_pref_reshape = xp.reshape(atom_pref, [-1])
atom_pref_reshape = xp.reshape(atom_pref, (-1,))
l2_pref_force_loss = xp.mean(
xp.multiply(xp.square(diff_f), atom_pref_reshape),
)
Expand All @@ -236,10 +236,10 @@
if self.has_gf:
find_drdq = label_dict["find_drdq"]
drdq = label_dict["drdq"]
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
force_reshape_nframes = xp.reshape(force, (-1, natoms[0] * 3))
force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms[0] * 3))

[Codecov / codecov/patch warning on deepmd/dpmodel/loss/ener.py: added lines 239–240 were not covered by tests]
drdq_reshape = xp.reshape(
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
drdq, (-1, natoms[0] * 3, self.numb_generalized_coord)
)
gen_force_hat = xp.einsum(
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def communicate_extended_output(
if vdef.r_differentiable:
if model_ret[kk_derv_r] is not None:
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.reshape(
mapping, tuple(mldims + [1] * len(derv_r_ext_dims))
)
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
force = xp_scatter_sum(
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def iter(
self.last_dim,
),
)
atype = xp.reshape(atype, (coord.shape[0] * coord.shape[1]))
atype = xp.reshape(atype, (coord.shape[0] * coord.shape[1],))
# (1, nloc) eq (ntypes, 1), so broadcast is possible
# shape: (ntypes, nloc)
type_idx = xp.equal(
Expand All @@ -189,7 +189,7 @@ def iter(
for type_i in range(self.descriptor.get_ntypes()):
dd = env_mat[type_idx[type_i, ...]]
dd = xp.reshape(
dd, [-1, self.last_dim]
dd, (-1, self.last_dim)
) # typen_atoms * unmasked_nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def build_type_exclude_mask(
xp = array_api_compat.array_namespace(atype)
nf, natom = atype.shape
return xp.reshape(
xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0),
xp.take(self.type_mask[...], xp.reshape(atype, (-1,)), axis=0),
(nf, natom),
)

Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def call(
nall = coord1.shape[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
xp.reshape(coord1, (nframes, -1, 3))[:, None, :, :]
- xp.reshape(coord0, (nframes, -1, 3))[:, :, None, :]
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def build_neighbor_list(
nsel = sum(sel)
coord0 = coord1[:, : nloc * 3]
diff = (
xp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :]
- xp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :]
xp.reshape(coord1, (batch_size, -1, 3))[:, None, :, :]
- xp.reshape(coord0, (batch_size, -1, 3))[:, :, None, :]
)
assert list(diff.shape) == [batch_size, nloc, nall, 3]
rr = xp.linalg.vector_norm(diff, axis=-1)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def to_face_distance(
"""
xp = array_api_compat.array_namespace(cell)
cshape = cell.shape
dist = b_to_face_distance(xp.reshape(cell, [-1, 3, 3]))
return xp.reshape(dist, list(cshape[:-2]) + [3]) # noqa:RUF005
dist = b_to_face_distance(xp.reshape(cell, (-1, 3, 3)))
return xp.reshape(dist, tuple(list(cshape[:-2]) + [3])) # noqa:RUF005


def b_to_face_distance(cell):
Expand Down