Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""
return self.type_map

def distinguish_types(self) -> bool:
"""Returns if model requires a neighbor list that distinguish different
atomic types or not.
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""
raise NotImplementedError("TODO: get_type_map should be implemented")

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def get_rcut(self) -> float:
"""Get the cut-off radius."""
pass

@abstractmethod
def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def fitting_output_def(self) -> FittingOutputDef:
def get_rcut(self) -> float:
return self.rcut

def get_type_map(self) -> Optional[List[str]]:
raise NotImplementedError("TODO: get_type_map should be implemented")

def get_sel(self) -> List[int]:
return [self.sel]

Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ def fitting_output_def(self) -> FittingOutputDef:
else self.coord_denoise_net.output_def()
)

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.rcut

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.sel
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ def distinguish_types(self) -> bool:
"""If distinguish different types by sorting."""
return False

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
raise NotImplementedError("TODO: implement this method")

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,14 @@ def fitting_output_def(self) -> FittingOutputDef:
]
)

@torch.jit.export
def get_rcut(self) -> float:
return self.rcut

@torch.jit.export
def get_type_map(self) -> Optional[List[str]]:
raise NotImplementedError("TODO: implement this method")

def get_sel(self) -> List[int]:
return [self.sel]

Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,6 @@ def test_jit(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE)
torch.jit.script(md0)
md0 = torch.jit.script(md0)
self.assertEqual(md0.get_rcut(), self.rcut)
self.assertEqual(md0.get_type_map(), type_map)
8 changes: 6 additions & 2 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def test_jit(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE)
torch.jit.script(md0)
md0 = torch.jit.script(md0)
md0.get_rcut()
md0.get_type_map()


class TestDPModelFormatNlist(unittest.TestCase):
Expand Down Expand Up @@ -521,4 +523,6 @@ def test_jit(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)
torch.jit.script(md0)
md0 = torch.jit.script(md0)
self.assertEqual(md0.get_rcut(), self.rcut)
self.assertEqual(md0.get_type_map(), type_map)
10 changes: 8 additions & 2 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ def test_self_consistency(self):
)

def test_jit(self):
torch.jit.script(self.md1)
torch.jit.script(self.md3)
md1 = torch.jit.script(self.md1)
self.assertEqual(md1.get_rcut(), self.rcut)
with self.assertRaises(torch.jit.Error):
self.assertEqual(md1.get_type_map(), ["foo", "bar"])
md3 = torch.jit.script(self.md3)
self.assertEqual(md3.get_rcut(), self.rcut)
with self.assertRaises(torch.jit.Error):
self.assertEqual(md3.get_type_map(), ["foo", "bar"])


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def test_with_mask(self):

def test_jit(self):
model = torch.jit.script(self.model)
self.assertEqual(model.get_rcut(), 0.02)
with self.assertRaises(torch.jit.Error):
self.assertEqual(model.get_type_map(), None)

def test_deserialize(self):
model1 = PairTabModel.deserialize(self.model.serialize())
Expand All @@ -101,6 +104,9 @@ def test_deserialize(self):
)

model1 = torch.jit.script(model1)
self.assertEqual(model1.get_rcut(), 0.02)
with self.assertRaises(torch.jit.Error):
self.assertEqual(model1.get_type_map(), None)

def test_cross_deserialize(self):
model_dict = self.model.serialize() # pytorch model to dict
Expand Down