Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,14 @@ def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]

def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
rcuts = torch.tensor(
self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
)
nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device)
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
torch.tensor(rcuts, device=env.DEVICE),
torch.tensor(nsels, device=env.DEVICE),
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
],
dim=0,
).T
Expand Down Expand Up @@ -148,7 +146,7 @@ def forward_atomic(
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device)
nlists = build_multiple_neighbor_list(
extended_coord,
nlist,
Expand Down
9 changes: 4 additions & 5 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -160,15 +157,17 @@ def forward_atomic(
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.to(device=env.DEVICE).view(
self.tab_data = self.tab_data.to(device=extended_coord.device).view(
int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4
)

# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr
# i_type : (nframes, nloc), this is atype.
# j_type : (nframes, nloc, nnei)
j_type = extended_atype[
torch.arange(extended_atype.size(0), device=env.DEVICE)[:, None, None],
torch.arange(extended_atype.size(0), device=extended_coord.device)[
:, None, None
],
masked_nlist,
]

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _update_g1_conv(
else:
gg1 = _apply_switch(gg1, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device
)
# nb x nloc x ng2
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
Expand Down Expand Up @@ -474,7 +474,7 @@ def _cal_h2g2(
else:
g2 = _apply_switch(g2, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device
)
# nb x nloc x 3 x ng2
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ def forward(
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
[nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
[nfnl, 4, self.filter_neuron[-1]],
dtype=self.prec,
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def _forward_common(
outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
device=descriptor.device,
) # jit assertion
if self.old_impl:
assert self.filter_layers_old is not None
Expand Down
21 changes: 11 additions & 10 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ def extend_coord_with_ghosts(
maping extended index to the local index

"""
device = coord.device
nf, nloc = atype.shape
aidx = torch.tile(torch.arange(nloc, device=env.DEVICE).unsqueeze(0), [nf, 1])
aidx = torch.tile(torch.arange(nloc, device=device).unsqueeze(0), [nf, 1])
if cell is None:
nall = nloc
extend_coord = coord.clone()
Expand All @@ -306,17 +307,17 @@ def extend_coord_with_ghosts(
nbuff = torch.ceil(rcut / to_face).to(torch.long)
# 3
nbuff = torch.max(nbuff, dim=0, keepdim=False).values
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=env.DEVICE)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=env.DEVICE)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=env.DEVICE)
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz.view(-1, 3)
# ns x 3
Expand All @@ -333,7 +334,7 @@ def extend_coord_with_ghosts(
extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1])

return (
extend_coord.reshape([nf, nall * 3]).to(env.DEVICE),
extend_atype.view([nf, nall]).to(env.DEVICE),
extend_aidx.view([nf, nall]).to(env.DEVICE),
extend_coord.reshape([nf, nall * 3]).to(device),
extend_atype.view([nf, nall]).to(device),
extend_aidx.view([nf, nall]).to(device),
)
8 changes: 8 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,11 @@ def setUp(self):
)
freeze(ns)
self.model = frozen_model

# Note: this can not actually disable cuda device to be used
# only can be used to test whether devices are mismatched
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.mock.patch("deepmd.pt.utils.env.DEVICE", torch.device("cpu"))
@unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu"))
def test_dp_test_cpu(self):
self.test_dp_test()
28 changes: 19 additions & 9 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)

from .test_env_mat import (
Expand Down Expand Up @@ -298,10 +299,10 @@ def setUp(self):
self.rcut_smth = 0.5
self.sel = [46, 92, 4]
self.nf = 1
self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device="cpu")
cell = torch.rand([3, 3], dtype=dtype, device="cpu")
self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu")
self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu")
self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device=env.DEVICE)
cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE)
self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE)
self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu").to(env.DEVICE)
self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE)
self.ft0 = DipoleFittingNet(
"dipole",
Expand All @@ -322,17 +323,26 @@ def test_auto_diff(self):
atype = self.atype.view(self.nf, self.natoms)

def ff(coord, atype):
return self.model(coord, atype)["global_dipole"].detach().cpu().numpy()
return (
self.model(to_torch_tensor(coord), to_torch_tensor(atype))[
"global_dipole"
]
.detach()
.cpu()
.numpy()
)

fdf = -finite_difference(ff, self.coord, atype, delta=delta)
fdf = -finite_difference(
ff, to_numpy_array(self.coord), to_numpy_array(atype), delta=delta
)
rff = self.model(self.coord, atype)["force"].detach().cpu().numpy()

np.testing.assert_almost_equal(fdf, rff.transpose(0, 2, 1, 3), decimal=places)

def test_deepdipole_infer(self):
atype = self.atype.view(self.nf, self.natoms)
coord = self.coord.reshape(1, 5, 3)
cell = self.cell.reshape(1, 9)
atype = to_numpy_array(self.atype.view(self.nf, self.natoms))
coord = to_numpy_array(self.coord.reshape(1, 5, 3))
cell = to_numpy_array(self.cell.reshape(1, 9))
jit_md = torch.jit.script(self.model)
torch.jit.save(jit_md, self.file_path)
load_md = DeepDipole(self.file_path)
Expand Down