Skip to content

Commit cfde37b

Browse files
committed
clean compat v1 -> v2 implemented
1 parent c6ae183 commit cfde37b

2 files changed

Lines changed: 66 additions & 26 deletions

File tree

deepmd/entrypoints/train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
1616
from deepmd.train.trainer import DPTrainer
1717
from deepmd.utils.argcheck import normalize
18-
from deepmd.utils.compat import convert_input_v0_v1
18+
from deepmd.utils.compat import updata_deepmd_input
1919
from deepmd.utils.data_system import DeepmdDataSystem
2020

2121
if TYPE_CHECKING:
@@ -168,8 +168,7 @@ def train(
168168
# load json database
169169
jdata = j_loader(INPUT)
170170

171-
if "model" not in jdata.keys():
172-
jdata = convert_input_v0_v1(jdata, warning=True, dump="input_v1_compat.json")
171+
jdata = updata_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
173172

174173
jdata = normalize(jdata)
175174
with open(output, "w") as fp:

deepmd/utils/compat.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ def convert_input_v0_v1(
1212
jdata: Dict[str, Any], warning: bool = True, dump: Optional[Union[str, Path]] = None
1313
) -> Dict[str, Any]:
1414
"""Convert input from v0 format to v1.
15-
1615
Parameters
1716
----------
1817
jdata : Dict[str, Any]
@@ -21,12 +20,12 @@ def convert_input_v0_v1(
2120
whether to show deprecation warning, by default True
2221
dump : Optional[Union[str, Path]], optional
2322
whether to dump converted file, by default None
24-
2523
Returns
2624
-------
2725
Dict[str, Any]
2826
converted output
2927
"""
28+
3029
output = {}
3130
if "with_distrib" in jdata:
3231
output["with_distrib"] = jdata["with_distrib"]
@@ -35,33 +34,29 @@ def convert_input_v0_v1(
3534
output["loss"] = _loss(jdata)
3635
output["training"] = _training(jdata)
3736
if warning:
38-
_warnning_input_v0_v1(dump)
37+
_warning_input_v0_v1(dump)
3938
if dump is not None:
4039
with open(dump, "w") as fp:
4140
json.dump(output, fp, indent=4)
4241
return output
4342

4443

45-
def _warnning_input_v0_v1(fname: Optional[Union[str, Path]]):
46-
msg = (
47-
"It seems that you are using a deepmd-kit input of version 0.x.x, "
48-
"which is deprecated. we have converted the input to >1.0.0 compatible"
49-
)
44+
def _warning_input_v0_v1(fname: Optional[Union[str, Path]]):
45+
msg = "It seems that you are using a deepmd-kit input of version 0.x.x, " \
46+
"which is deprecated. we have converted the input to >2.0.0 compatible"
5047
if fname is not None:
5148
msg += f", and output it to file {fname}"
5249
warnings.warn(msg)
5350

5451

5552
def _model(jdata: Dict[str, Any], smooth: bool) -> Dict[str, Dict[str, Any]]:
5653
"""Convert data to v1 input for non-smooth model.
57-
5854
Parameters
5955
----------
6056
jdata : Dict[str, Any]
6157
parsed input json/yaml data
6258
smooth : bool
6359
whether to use smooth or non-smooth descriptor version
64-
6560
Returns
6661
-------
6762
Dict[str, Dict[str, Any]]
@@ -78,12 +73,10 @@ def _model(jdata: Dict[str, Any], smooth: bool) -> Dict[str, Dict[str, Any]]:
7873

7974
def _nonsmth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
8075
"""Convert data to v1 input for non-smooth descriptor.
81-
8276
Parameters
8377
----------
8478
jdata : Dict[str, Any]
8579
parsed input json/yaml data
86-
8780
Returns
8881
-------
8982
Dict[str, Any]
@@ -97,12 +90,10 @@ def _nonsmth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
9790

9891
def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
9992
"""Convert data to v1 input for smooth descriptor.
100-
10193
Parameters
10294
----------
10395
jdata : Dict[str, Any]
10496
parsed input json/yaml data
105-
10697
Returns
10798
-------
10899
Dict[str, Any]
@@ -127,12 +118,10 @@ def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
127118

128119
def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:
129120
"""Convert data to v1 input for fitting net.
130-
131121
Parameters
132122
----------
133123
jdata : Dict[str, Any]
134124
parsed input json/yaml data
135-
136125
Returns
137126
-------
138127
Dict[str, Any]
@@ -154,12 +143,10 @@ def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:
154143

155144
def _learning_rate(jdata: Dict[str, Any]) -> Dict[str, Any]:
156145
"""Convert data to v1 input for learning rate section.
157-
158146
Parameters
159147
----------
160148
jdata : Dict[str, Any]
161149
parsed input json/yaml data
162-
163150
Returns
164151
-------
165152
Dict[str, Any]
@@ -173,12 +160,10 @@ def _learning_rate(jdata: Dict[str, Any]) -> Dict[str, Any]:
173160

174161
def _loss(jdata: Dict[str, Any]) -> Dict[str, Any]:
175162
"""Convert data to v1 input for loss function.
176-
177163
Parameters
178164
----------
179165
jdata : Dict[str, Any]
180166
parsed input json/yaml data
181-
182167
Returns
183168
-------
184169
Dict[str, Any]
@@ -206,12 +191,10 @@ def _loss(jdata: Dict[str, Any]) -> Dict[str, Any]:
206191

207192
def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:
208193
"""Convert data to v1 input for training.
209-
210194
Parameters
211195
----------
212196
jdata : Dict[str, Any]
213197
parsed input json/yaml data
214-
215198
Returns
216199
-------
217200
Dict[str, Any]
@@ -241,7 +224,6 @@ def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:
241224

242225
def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
243226
"""Copy specified keys from one dict to another.
244-
245227
Parameters
246228
----------
247229
src : Dict[str, Any]
@@ -255,3 +237,62 @@ def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
255237
"""
256238
for k in keys:
257239
dst[k] = src[k]
240+
241+
242+
def convert_input_v1_v2(jdata: Dict[str, Any],
243+
warning: bool = True,
244+
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
245+
246+
tr_cfg = jdata["training"]
247+
tr_data_keys = {
248+
"systems",
249+
"set_prefix",
250+
"batch_size",
251+
"sys_prob",
252+
"auto_prob",
253+
# alias included
254+
"sys_weights",
255+
"auto_prob_style"
256+
}
257+
258+
tr_data_cfg = {k: v for k, v in tr_cfg.items() if k in tr_data_keys}
259+
new_tr_cfg = {k: v for k, v in tr_cfg.items() if k not in tr_data_keys}
260+
new_tr_cfg["training_data"] = tr_data_cfg
261+
262+
jdata["training"] = new_tr_cfg
263+
264+
if warning:
265+
_warning_input_v1_v2(dump)
266+
if dump is not None:
267+
with open(dump, "w") as fp:
268+
json.dump(jdata, fp, indent=4)
269+
270+
return jdata
271+
272+
273+
def _warning_input_v1_v2(fname: Optional[Union[str, Path]]):
274+
msg = "It seems that you are using a deepmd-kit input of version 1.x.x, " \
275+
"which is deprecated. we have converted the input to >2.0.0 compatible"
276+
if fname is not None:
277+
msg += f", and output it to file {fname}"
278+
warnings.warn(msg)
279+
280+
281+
def updata_deepmd_input(jdata: Dict[str, Any],
282+
warning: bool = True,
283+
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
284+
def is_deepmd_v0_input(jdata):
285+
return "model" not in jdata.keys()
286+
287+
def is_deepmd_v1_input(jdata):
288+
return "systems" in j_must_have(jdata, "training").keys()
289+
290+
if is_deepmd_v0_input(jdata):
291+
jdata = convert_input_v0_v1(jdata, warning, None)
292+
jdata = convert_input_v1_v2(jdata, False, dump)
293+
elif is_deepmd_v1_input(jdata):
294+
jdata = convert_input_v1_v2(jdata, warning, dump)
295+
else:
296+
pass
297+
298+
return jdata

0 commit comments

Comments
 (0)