Skip to content

Commit 5d885c7

Browse files
committed
Merge remote-tracking branch 'origin/devel' into avoid-deepcopy
2 parents ba498f9 + 6c66be9 commit 5d885c7

16 files changed

Lines changed: 413 additions & 63 deletions

File tree

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, NativeOP):
4343
def __init__(
4444
self,
4545
list: list[Union[BaseDescriptor, dict[str, Any]]],
46+
type_map: Optional[list[str]] = None,
47+
ntypes: Optional[int] = None, # to be compat with input
4648
) -> None:
4749
super().__init__()
4850
# warning: list is conflict with built-in list
@@ -56,6 +58,10 @@ def __init__(
5658
if isinstance(ii, BaseDescriptor):
5759
formatted_descript_list.append(ii)
5860
elif isinstance(ii, dict):
61+
ii = ii.copy()
62+
# only pass if not already set
63+
ii.setdefault("type_map", type_map)
64+
ii.setdefault("ntypes", ntypes)
5965
formatted_descript_list.append(BaseDescriptor(**ii))
6066
else:
6167
raise NotImplementedError

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,7 @@ def serialize(self) -> dict:
17921792
"""
17931793
data = {
17941794
"@class": "RepformerLayer",
1795-
"@version": 1,
1795+
"@version": 2,
17961796
"rcut": self.rcut,
17971797
"rcut_smth": self.rcut_smth,
17981798
"sel": self.sel,
@@ -1877,9 +1877,11 @@ def serialize(self) -> dict:
18771877
if self.update_style == "res_residual":
18781878
data.update(
18791879
{
1880-
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
1881-
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
1882-
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
1880+
"@variables": {
1881+
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
1882+
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
1883+
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
1884+
}
18831885
}
18841886
)
18851887
return data
@@ -1894,7 +1896,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
18941896
The dict to deserialize from.
18951897
"""
18961898
data = data.copy()
1897-
check_version_compatibility(data.pop("@version"), 1, 1)
1899+
check_version_compatibility(data.pop("@version"), 2, 1)
18981900
data.pop("@class")
18991901
linear1 = data.pop("linear1")
19001902
update_chnnl_2 = data["update_chnnl_2"]
@@ -1915,9 +1917,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
19151917
attn2_ev_apply = data.pop("attn2_ev_apply", None)
19161918
loc_attn = data.pop("loc_attn", None)
19171919
g1_self_mlp = data.pop("g1_self_mlp", None)
1918-
g1_residual = data.pop("g1_residual", [])
1919-
g2_residual = data.pop("g2_residual", [])
1920-
h2_residual = data.pop("h2_residual", [])
1920+
variables = data.pop("@variables", {})
1921+
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
1922+
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
1923+
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))
19211924

19221925
obj = cls(**data)
19231926
obj.linear1 = NativeLayer.deserialize(linear1)

deepmd/dpmodel/model/model.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
from deepmd.dpmodel.descriptor.base_descriptor import (
99
BaseDescriptor,
1010
)
11-
from deepmd.dpmodel.descriptor.se_e2_a import (
12-
DescrptSeA,
13-
)
1411
from deepmd.dpmodel.fitting.ener_fitting import (
1512
EnergyFittingNet,
1613
)
@@ -39,16 +36,13 @@ def get_standard_model(data: dict) -> EnergyModel:
3936
data : dict
4037
The data to construct the model.
4138
"""
42-
descriptor_type = data["descriptor"].pop("type")
4339
data["descriptor"]["type_map"] = data["type_map"]
40+
data["descriptor"]["ntypes"] = len(data["type_map"])
4441
fitting_type = data["fitting_net"].pop("type")
4542
data["fitting_net"]["type_map"] = data["type_map"]
46-
if descriptor_type == "se_e2_a":
47-
descriptor = DescrptSeA(
48-
**data["descriptor"],
49-
)
50-
else:
51-
raise ValueError(f"Unknown descriptor type {descriptor_type}")
43+
descriptor = BaseDescriptor(
44+
**data["descriptor"],
45+
)
5246
if fitting_type == "ener":
5347
fitting = EnergyFittingNet(
5448
ntypes=descriptor.get_ntypes(),

deepmd/jax/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from deepmd.jax.descriptor.hybrid import (
99
DescrptHybrid,
1010
)
11+
from deepmd.jax.descriptor.se_atten_v2 import (
12+
DescrptSeAttenV2,
13+
)
1114
from deepmd.jax.descriptor.se_e2_a import (
1215
DescrptSeA,
1316
)
@@ -27,6 +30,7 @@
2730
"DescrptSeT",
2831
"DescrptSeTTebd",
2932
"DescrptDPA1",
33+
"DescrptSeAttenV2",
3034
"DescrptDPA2",
3135
"DescrptHybrid",
3236
]

deepmd/jax/model/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def get_standard_model(data: dict):
3737
data = deepcopy(data)
3838
descriptor_type = data["descriptor"].pop("type")
3939
data["descriptor"]["type_map"] = data["type_map"]
40+
data["descriptor"]["ntypes"] = len(data["type_map"])
4041
fitting_type = data["fitting_net"].pop("type")
4142
data["fitting_net"]["type_map"] = data["type_map"]
4243
descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(

deepmd/pt/model/descriptor/repformer_layer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ def serialize(self) -> dict:
12951295
"""
12961296
data = {
12971297
"@class": "RepformerLayer",
1298-
"@version": 1,
1298+
"@version": 2,
12991299
"rcut": self.rcut,
13001300
"rcut_smth": self.rcut_smth,
13011301
"sel": self.sel,
@@ -1380,9 +1380,11 @@ def serialize(self) -> dict:
13801380
if self.update_style == "res_residual":
13811381
data.update(
13821382
{
1383-
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
1384-
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
1385-
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
1383+
"@variables": {
1384+
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
1385+
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
1386+
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
1387+
}
13861388
}
13871389
)
13881390
return data
@@ -1397,7 +1399,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
13971399
The dict to deserialize from.
13981400
"""
13991401
data = data.copy()
1400-
check_version_compatibility(data.pop("@version"), 1, 1)
1402+
check_version_compatibility(data.pop("@version"), 2, 1)
14011403
data.pop("@class")
14021404
linear1 = data.pop("linear1")
14031405
update_chnnl_2 = data["update_chnnl_2"]
@@ -1418,9 +1420,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
14181420
attn2_ev_apply = data.pop("attn2_ev_apply", None)
14191421
loc_attn = data.pop("loc_attn", None)
14201422
g1_self_mlp = data.pop("g1_self_mlp", None)
1421-
g1_residual = data.pop("g1_residual", [])
1422-
g2_residual = data.pop("g2_residual", [])
1423-
h2_residual = data.pop("h2_residual", [])
1423+
variables = data.pop("@variables", {})
1424+
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
1425+
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
1426+
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))
14241427

14251428
obj = cls(**data)
14261429
obj.linear1 = MLPLayer.deserialize(linear1)

deepmd/pt/model/task/fitting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,8 @@ def _forward_common(
418418

419419
if nd != self.dim_descrpt:
420420
raise ValueError(
421-
"get an input descriptor of dim {nd},"
422-
"which is not consistent with {self.dim_descrpt}."
421+
f"get an input descriptor of dim {nd},"
422+
f"which is not consistent with {self.dim_descrpt}."
423423
)
424424
# check fparam dim, concate to input descriptor
425425
if self.numb_fparam > 0:

deepmd/tf/descriptor/hybrid.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,8 @@ def update_sel(
461461
return local_jdata_cpy, min_nbor_dist
462462

463463
def serialize(self, suffix: str = "") -> dict:
464+
if hasattr(self, "type_embedding"):
465+
raise NotImplementedError("hybrid + type embedding is not supported")
464466
return {
465467
"@class": "Descriptor",
466468
"type": "hybrid",
@@ -485,4 +487,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
485487
for idx, ii in enumerate(data["list"])
486488
],
487489
)
490+
# search for type embedding
491+
for ii in obj.descrpt_list:
492+
if hasattr(ii, "type_embedding"):
493+
raise NotImplementedError("hybrid + type embedding is not supported")
488494
return obj

0 commit comments

Comments
 (0)