Skip to content

Commit 8d35535

Browse files
authored
chore(dpmodel): move save_dp_model and load_dp_model to a seperated module (#3701)
Fix #3526. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 86b0bf8 commit 8d35535

5 files changed

Lines changed: 123 additions & 115 deletions

File tree

deepmd/backend/dpmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def serialize_hook(self) -> Callable[[str], dict]:
100100
Callable[[str], dict]
101101
The serialize hook of the backend.
102102
"""
103-
from deepmd.dpmodel.utils.network import (
103+
from deepmd.dpmodel.utils.serialization import (
104104
load_dp_model,
105105
)
106106

@@ -115,7 +115,7 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
115115
Callable[[str, dict], None]
116116
The deserialize hook of the backend.
117117
"""
118-
from deepmd.dpmodel.utils.network import (
118+
from deepmd.dpmodel.utils.serialization import (
119119
save_dp_model,
120120
)
121121

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from deepmd.dpmodel.utils.batch_size import (
2525
AutoBatchSize,
2626
)
27-
from deepmd.dpmodel.utils.network import (
27+
from deepmd.dpmodel.utils.serialization import (
2828
load_dp_model,
2929
)
3030
from deepmd.infer.deep_dipole import (

deepmd/dpmodel/utils/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,9 @@
1212
NativeLayer,
1313
NativeNet,
1414
NetworkCollection,
15-
load_dp_model,
1615
make_embedding_network,
1716
make_fitting_network,
1817
make_multilayer_network,
19-
save_dp_model,
20-
traverse_model_dict,
2118
)
2219
from .nlist import (
2320
build_multiple_neighbor_list,
@@ -32,6 +29,11 @@
3229
phys2inter,
3330
to_face_distance,
3431
)
32+
from .serialization import (
33+
load_dp_model,
34+
save_dp_model,
35+
traverse_model_dict,
36+
)
3537

3638
__all__ = [
3739
"EnvMat",
@@ -46,8 +48,6 @@
4648
"load_dp_model",
4749
"save_dp_model",
4850
"traverse_model_dict",
49-
"PRECISION_DICT",
50-
"DEFAULT_PRECISION",
5151
"build_neighbor_list",
5252
"nlist_distinguish_types",
5353
"get_multiple_nlist_key",

deepmd/dpmodel/utils/network.py

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@
66

77
import copy
88
import itertools
9-
import json
10-
from datetime import (
11-
datetime,
12-
)
139
from typing import (
1410
Callable,
1511
ClassVar,
@@ -19,7 +15,6 @@
1915
Union,
2016
)
2117

22-
import h5py
2318
import numpy as np
2419

2520
from deepmd.utils.version import (
@@ -38,108 +33,6 @@
3833
)
3934

4035

41-
def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False):
42-
"""Traverse a model dict and call callback on each variable.
43-
44-
Parameters
45-
----------
46-
model_obj : object
47-
The model object to traverse.
48-
callback : callable
49-
The callback function to call on each variable.
50-
is_variable : bool, optional
51-
Whether the current node is a variable.
52-
53-
Returns
54-
-------
55-
object
56-
The model object after traversing.
57-
"""
58-
if isinstance(model_obj, dict):
59-
for kk, vv in model_obj.items():
60-
model_obj[kk] = traverse_model_dict(
61-
vv, callback, is_variable=is_variable or kk == "@variables"
62-
)
63-
elif isinstance(model_obj, list):
64-
for ii, vv in enumerate(model_obj):
65-
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
66-
elif model_obj is None:
67-
return model_obj
68-
elif is_variable:
69-
model_obj = callback(model_obj)
70-
return model_obj
71-
72-
73-
class Counter:
74-
"""A callable counter.
75-
76-
Examples
77-
--------
78-
>>> counter = Counter()
79-
>>> counter()
80-
0
81-
>>> counter()
82-
1
83-
"""
84-
85-
def __init__(self):
86-
self.count = -1
87-
88-
def __call__(self):
89-
self.count += 1
90-
return self.count
91-
92-
93-
# TODO: move save_dp_model and load_dp_model to a seperated module
94-
# should be moved to otherwhere...
95-
def save_dp_model(filename: str, model_dict: dict) -> None:
96-
"""Save a DP model to a file in the native format.
97-
98-
Parameters
99-
----------
100-
filename : str
101-
The filename to save to.
102-
model_dict : dict
103-
The model dict to save.
104-
"""
105-
model_dict = model_dict.copy()
106-
variable_counter = Counter()
107-
with h5py.File(filename, "w") as f:
108-
model_dict = traverse_model_dict(
109-
model_dict,
110-
lambda x: f.create_dataset(
111-
f"variable_{variable_counter():04d}", data=x
112-
).name,
113-
)
114-
save_dict = {
115-
"software": "deepmd-kit",
116-
"version": __version__,
117-
# use UTC+0 time
118-
"time": str(datetime.utcnow()),
119-
**model_dict,
120-
}
121-
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
122-
123-
124-
def load_dp_model(filename: str) -> dict:
125-
"""Load a DP model from a file in the native format.
126-
127-
Parameters
128-
----------
129-
filename : str
130-
The filename to load from.
131-
132-
Returns
133-
-------
134-
dict
135-
The loaded model dict, including meta information.
136-
"""
137-
with h5py.File(filename, "r") as f:
138-
model_dict = json.loads(f.attrs["json"])
139-
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
140-
return model_dict
141-
142-
14336
class NativeLayer(NativeOP):
14437
"""Native representation of a layer.
14538
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
from datetime import (
4+
datetime,
5+
)
6+
from typing import (
7+
Callable,
8+
)
9+
10+
import h5py
11+
12+
try:
13+
from deepmd._version import version as __version__
14+
except ImportError:
15+
__version__ = "unknown"
16+
17+
18+
def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False):
19+
"""Traverse a model dict and call callback on each variable.
20+
21+
Parameters
22+
----------
23+
model_obj : object
24+
The model object to traverse.
25+
callback : callable
26+
The callback function to call on each variable.
27+
is_variable : bool, optional
28+
Whether the current node is a variable.
29+
30+
Returns
31+
-------
32+
object
33+
The model object after traversing.
34+
"""
35+
if isinstance(model_obj, dict):
36+
for kk, vv in model_obj.items():
37+
model_obj[kk] = traverse_model_dict(
38+
vv, callback, is_variable=is_variable or kk == "@variables"
39+
)
40+
elif isinstance(model_obj, list):
41+
for ii, vv in enumerate(model_obj):
42+
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
43+
elif model_obj is None:
44+
return model_obj
45+
elif is_variable:
46+
model_obj = callback(model_obj)
47+
return model_obj
48+
49+
50+
class Counter:
51+
"""A callable counter.
52+
53+
Examples
54+
--------
55+
>>> counter = Counter()
56+
>>> counter()
57+
0
58+
>>> counter()
59+
1
60+
"""
61+
62+
def __init__(self):
63+
self.count = -1
64+
65+
def __call__(self):
66+
self.count += 1
67+
return self.count
68+
69+
70+
def save_dp_model(filename: str, model_dict: dict) -> None:
71+
"""Save a DP model to a file in the native format.
72+
73+
Parameters
74+
----------
75+
filename : str
76+
The filename to save to.
77+
model_dict : dict
78+
The model dict to save.
79+
"""
80+
model_dict = model_dict.copy()
81+
variable_counter = Counter()
82+
with h5py.File(filename, "w") as f:
83+
model_dict = traverse_model_dict(
84+
model_dict,
85+
lambda x: f.create_dataset(
86+
f"variable_{variable_counter():04d}", data=x
87+
).name,
88+
)
89+
save_dict = {
90+
"software": "deepmd-kit",
91+
"version": __version__,
92+
# use UTC+0 time
93+
"time": str(datetime.utcnow()),
94+
**model_dict,
95+
}
96+
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
97+
98+
99+
def load_dp_model(filename: str) -> dict:
100+
"""Load a DP model from a file in the native format.
101+
102+
Parameters
103+
----------
104+
filename : str
105+
The filename to load from.
106+
107+
Returns
108+
-------
109+
dict
110+
The loaded model dict, including meta information.
111+
"""
112+
with h5py.File(filename, "r") as f:
113+
model_dict = json.loads(f.attrs["json"])
114+
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
115+
return model_dict

0 commit comments

Comments
 (0)