Skip to content

Commit da9b526

Browse files
authored
feat(pt): consistent "frozen" model (#3450)
This PR is based on #3449, as the test needs #3449 to pass. Add a consistent `frozen` model in pt. Both TF and PT now support using models in any format. --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 9bcae14 commit da9b526

9 files changed

Lines changed: 387 additions & 5 deletions

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ def deserialize(cls, data: dict) -> "NativeLayer":
230230
variables.get("b", None),
231231
variables.get("idt", None),
232232
)
233+
if obj.b is not None:
234+
obj.b = obj.b.ravel()
235+
if obj.idt is not None:
236+
obj.idt = obj.idt.ravel()
233237
obj.check_shape_consistency()
234238
return obj
235239

deepmd/pt/model/model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from .ener_model import (
3838
EnergyModel,
3939
)
40+
from .frozen import (
41+
FrozenModel,
42+
)
4043
from .make_hessian_model import (
4144
make_hessian_model,
4245
)
@@ -173,6 +176,7 @@ def get_model(model_params):
173176
"get_model",
174177
"DPModel",
175178
"EnergyModel",
179+
"FrozenModel",
176180
"SpinModel",
177181
"SpinEnergyModel",
178182
"DPZBLModel",

deepmd/pt/model/model/frozen.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
import tempfile
4+
from typing import (
5+
Dict,
6+
List,
7+
Optional,
8+
)
9+
10+
import torch
11+
12+
from deepmd.dpmodel.output_def import (
13+
FittingOutputDef,
14+
)
15+
from deepmd.entrypoints.convert_backend import (
16+
convert_backend,
17+
)
18+
from deepmd.pt.model.model.model import (
19+
BaseModel,
20+
)
21+
22+
23+
@BaseModel.register("frozen")
24+
class FrozenModel(BaseModel):
25+
"""Load model from a frozen model, which cannot be trained.
26+
27+
Parameters
28+
----------
29+
model_file : str
30+
The path to the frozen model
31+
"""
32+
33+
def __init__(self, model_file: str, **kwargs):
34+
super().__init__(**kwargs)
35+
self.model_file = model_file
36+
if model_file.endswith(".pth"):
37+
self.model = torch.jit.load(model_file)
38+
else:
39+
# try to convert from other formats
40+
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
41+
convert_backend(INPUT=model_file, OUTPUT=f.name)
42+
self.model = torch.jit.load(f.name)
43+
44+
@torch.jit.export
45+
def fitting_output_def(self) -> FittingOutputDef:
46+
"""Get the output def of developer implemented atomic models."""
47+
return self.model.fitting_output_def()
48+
49+
@torch.jit.export
50+
def get_rcut(self) -> float:
51+
"""Get the cut-off radius."""
52+
return self.model.get_rcut()
53+
54+
@torch.jit.export
55+
def get_type_map(self) -> List[str]:
56+
"""Get the type map."""
57+
return self.model.get_type_map()
58+
59+
@torch.jit.export
60+
def get_sel(self) -> List[int]:
61+
"""Returns the number of selected atoms for each type."""
62+
return self.model.get_sel()
63+
64+
@torch.jit.export
65+
def get_dim_fparam(self) -> int:
66+
"""Get the number (dimension) of frame parameters of this atomic model."""
67+
return self.model.get_dim_fparam()
68+
69+
@torch.jit.export
70+
def get_dim_aparam(self) -> int:
71+
"""Get the number (dimension) of atomic parameters of this atomic model."""
72+
return self.model.get_dim_aparam()
73+
74+
@torch.jit.export
75+
def get_sel_type(self) -> List[int]:
76+
"""Get the selected atom types of this model.
77+
78+
Only atoms with selected atom types have atomic contribution
79+
to the result of the model.
80+
If returning an empty list, all atom types are selected.
81+
"""
82+
return self.model.get_sel_type()
83+
84+
@torch.jit.export
85+
def is_aparam_nall(self) -> bool:
86+
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
87+
88+
If False, the shape is (nframes, nloc, ndim).
89+
"""
90+
return self.model.is_aparam_nall()
91+
92+
@torch.jit.export
93+
def mixed_types(self) -> bool:
94+
"""If true, the model
95+
1. assumes total number of atoms aligned across frames;
96+
2. uses a neighbor list that does not distinguish different atomic types.
97+
98+
If false, the model
99+
1. assumes total number of atoms of each atom type aligned across frames;
100+
2. uses a neighbor list that distinguishes different atomic types.
101+
102+
"""
103+
return self.model.mixed_types()
104+
105+
@torch.jit.export
106+
def forward(
107+
self,
108+
coord,
109+
atype,
110+
box: Optional[torch.Tensor] = None,
111+
fparam: Optional[torch.Tensor] = None,
112+
aparam: Optional[torch.Tensor] = None,
113+
do_atomic_virial: bool = False,
114+
) -> Dict[str, torch.Tensor]:
115+
return self.model.forward(
116+
coord,
117+
atype,
118+
box=box,
119+
fparam=fparam,
120+
aparam=aparam,
121+
do_atomic_virial=do_atomic_virial,
122+
)
123+
124+
@torch.jit.export
125+
def get_model_def_script(self) -> str:
126+
"""Get the model definition script."""
127+
# try to use the original script instead of "frozen model"
128+
# Note: this cannot change the script of the parent model
129+
# it may still try to load a hard-coded filename, which might
130+
# be a problem
131+
return self.model.get_model_def_script()
132+
133+
def serialize(self) -> dict:
134+
from deepmd.pt.model.model import (
135+
get_model,
136+
)
137+
138+
# try to recover the original model
139+
model_def_script = json.loads(self.get_model_def_script())
140+
model = get_model(model_def_script)
141+
model.load_state_dict(self.model.state_dict())
142+
return model.serialize()
143+
144+
@classmethod
145+
def deserialize(cls, data: dict):
146+
raise RuntimeError("Should not touch here.")
147+
148+
@torch.jit.export
149+
def get_nnei(self) -> int:
150+
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
151+
return self.model.get_nnei()
152+
153+
@torch.jit.export
154+
def get_nsel(self) -> int:
155+
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
156+
return self.model.get_nsel()
157+
158+
@classmethod
159+
def update_sel(cls, global_jdata: dict, local_jdata: dict):
160+
"""Update the selection and perform neighbor statistics.
161+
162+
Parameters
163+
----------
164+
global_jdata : dict
165+
The global data, containing the training section
166+
local_jdata : dict
167+
The local data referring to the current class
168+
"""
169+
return local_jdata
170+
171+
@torch.jit.export
172+
def model_output_type(self) -> str:
173+
"""Get the output type for the model."""
174+
return self.model.model_output_type()

deepmd/tf/fit/ener.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
868868
data["nets"],
869869
suffix=suffix,
870870
)
871-
fitting.bias_atom_e = data["@variables"]["bias_atom_e"]
871+
fitting.bias_atom_e = data["@variables"]["bias_atom_e"].ravel()
872872
if fitting.numb_fparam > 0:
873873
fitting.fparam_avg = data["@variables"]["fparam_avg"]
874874
fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
@@ -922,7 +922,7 @@ def serialize(self, suffix: str = "") -> dict:
922922
suffix=suffix,
923923
),
924924
"@variables": {
925-
"bias_atom_e": self.bias_atom_e,
925+
"bias_atom_e": self.bias_atom_e.reshape(-1, 1),
926926
"fparam_avg": self.fparam_avg,
927927
"fparam_inv_std": self.fparam_inv_std,
928928
"aparam_avg": self.aparam_avg,

deepmd/tf/model/frozen.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
import os
4+
import tempfile
25
from enum import (
36
Enum,
47
)
@@ -7,6 +10,9 @@
710
Union,
811
)
912

13+
from deepmd.entrypoints.convert_backend import (
14+
convert_backend,
15+
)
1016
from deepmd.infer.deep_pot import (
1117
DeepPot,
1218
)
@@ -24,6 +30,10 @@
2430
from deepmd.tf.loss.loss import (
2531
Loss,
2632
)
33+
from deepmd.tf.utils.graph import (
34+
get_tensor_by_name_from_graph,
35+
load_graph_def,
36+
)
2737

2838
from .model import (
2939
Model,
@@ -43,7 +53,14 @@ class FrozenModel(Model):
4353
def __init__(self, model_file: str, **kwargs):
4454
super().__init__(**kwargs)
4555
self.model_file = model_file
46-
self.model = DeepPotential(model_file)
56+
if not model_file.endswith(".pb"):
57+
# try to convert from other formats
58+
with tempfile.NamedTemporaryFile(
59+
suffix=".pb", dir=os.curdir, delete=False
60+
) as f:
61+
convert_backend(INPUT=model_file, OUTPUT=f.name)
62+
self.model_file = f.name
63+
self.model = DeepPotential(self.model_file)
4764
if isinstance(self.model, DeepPot):
4865
self.model_type = "ener"
4966
else:
@@ -228,3 +245,19 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
228245
"""
229246
# we don't know how to compress it, so no neighbor statistics here
230247
return local_jdata
248+
249+
def serialize(self, suffix: str = "") -> dict:
250+
# try to recover the original model
251+
# the current graph contains a prefix "load",
252+
# so it cannot be used to recover the original model
253+
graph, graph_def = load_graph_def(self.model_file)
254+
t_jdata = get_tensor_by_name_from_graph(graph, "train_attr/training_script")
255+
jdata = json.loads(t_jdata)
256+
model = Model(**jdata["model"])
257+
# important! must be called before serialize
258+
model.init_variables(graph=graph, graph_def=graph_def)
259+
return model.serialize()
260+
261+
@classmethod
262+
def deserialize(cls, data: dict, suffix: str = ""):
263+
raise RuntimeError("Should not touch here.")

deepmd/tf/model/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Model":
566566
"""
567567
if cls is Model:
568568
return Model.get_class_by_type(data.get("type", "standard")).deserialize(
569-
data
569+
data,
570+
suffix=suffix,
570571
)
571572
raise NotImplementedError("Not implemented in class %s" % cls.__name__)
572573

deepmd/utils/argcheck.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,6 @@ def frozen_model_args() -> Argument:
14611461
[
14621462
Argument("model_file", str, optional=False, doc=doc_model_file),
14631463
],
1464-
doc=doc_only_tf_supported,
14651464
)
14661465
return ca
14671466

0 commit comments

Comments
 (0)