|
6 | 6 |
|
7 | 7 | import copy |
8 | 8 | import itertools |
9 | | -import json |
10 | | -from datetime import ( |
11 | | - datetime, |
12 | | -) |
13 | 9 | from typing import ( |
14 | 10 | Callable, |
15 | 11 | ClassVar, |
|
19 | 15 | Union, |
20 | 16 | ) |
21 | 17 |
|
22 | | -import h5py |
23 | 18 | import numpy as np |
24 | 19 |
|
25 | 20 | from deepmd.utils.version import ( |
|
38 | 33 | ) |
39 | 34 |
|
40 | 35 |
|
41 | | -def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False): |
42 | | - """Traverse a model dict and call callback on each variable. |
43 | | -
|
44 | | - Parameters |
45 | | - ---------- |
46 | | - model_obj : object |
47 | | - The model object to traverse. |
48 | | - callback : callable |
49 | | - The callback function to call on each variable. |
50 | | - is_variable : bool, optional |
51 | | - Whether the current node is a variable. |
52 | | -
|
53 | | - Returns |
54 | | - ------- |
55 | | - object |
56 | | - The model object after traversing. |
57 | | - """ |
58 | | - if isinstance(model_obj, dict): |
59 | | - for kk, vv in model_obj.items(): |
60 | | - model_obj[kk] = traverse_model_dict( |
61 | | - vv, callback, is_variable=is_variable or kk == "@variables" |
62 | | - ) |
63 | | - elif isinstance(model_obj, list): |
64 | | - for ii, vv in enumerate(model_obj): |
65 | | - model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable) |
66 | | - elif model_obj is None: |
67 | | - return model_obj |
68 | | - elif is_variable: |
69 | | - model_obj = callback(model_obj) |
70 | | - return model_obj |
71 | | - |
72 | | - |
73 | | -class Counter: |
74 | | - """A callable counter. |
75 | | -
|
76 | | - Examples |
77 | | - -------- |
78 | | - >>> counter = Counter() |
79 | | - >>> counter() |
80 | | - 0 |
81 | | - >>> counter() |
82 | | - 1 |
83 | | - """ |
84 | | - |
85 | | - def __init__(self): |
86 | | - self.count = -1 |
87 | | - |
88 | | - def __call__(self): |
89 | | - self.count += 1 |
90 | | - return self.count |
91 | | - |
92 | | - |
93 | | -# TODO: move save_dp_model and load_dp_model to a seperated module |
94 | | -# should be moved to otherwhere... |
95 | | -def save_dp_model(filename: str, model_dict: dict) -> None: |
96 | | - """Save a DP model to a file in the native format. |
97 | | -
|
98 | | - Parameters |
99 | | - ---------- |
100 | | - filename : str |
101 | | - The filename to save to. |
102 | | - model_dict : dict |
103 | | - The model dict to save. |
104 | | - """ |
105 | | - model_dict = model_dict.copy() |
106 | | - variable_counter = Counter() |
107 | | - with h5py.File(filename, "w") as f: |
108 | | - model_dict = traverse_model_dict( |
109 | | - model_dict, |
110 | | - lambda x: f.create_dataset( |
111 | | - f"variable_{variable_counter():04d}", data=x |
112 | | - ).name, |
113 | | - ) |
114 | | - save_dict = { |
115 | | - "software": "deepmd-kit", |
116 | | - "version": __version__, |
117 | | - # use UTC+0 time |
118 | | - "time": str(datetime.utcnow()), |
119 | | - **model_dict, |
120 | | - } |
121 | | - f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) |
122 | | - |
123 | | - |
124 | | -def load_dp_model(filename: str) -> dict: |
125 | | - """Load a DP model from a file in the native format. |
126 | | -
|
127 | | - Parameters |
128 | | - ---------- |
129 | | - filename : str |
130 | | - The filename to load from. |
131 | | -
|
132 | | - Returns |
133 | | - ------- |
134 | | - dict |
135 | | - The loaded model dict, including meta information. |
136 | | - """ |
137 | | - with h5py.File(filename, "r") as f: |
138 | | - model_dict = json.loads(f.attrs["json"]) |
139 | | - model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) |
140 | | - return model_dict |
141 | | - |
142 | | - |
143 | 36 | class NativeLayer(NativeOP): |
144 | 37 | """Native representation of a layer. |
145 | 38 |
|
|
0 commit comments