@@ -12,7 +12,6 @@ def convert_input_v0_v1(
1212 jdata : Dict [str , Any ], warning : bool = True , dump : Optional [Union [str , Path ]] = None
1313) -> Dict [str , Any ]:
1414 """Convert input from v0 format to v1.
15-
1615 Parameters
1716 ----------
1817 jdata : Dict[str, Any]
@@ -21,12 +20,12 @@ def convert_input_v0_v1(
2120 whether to show deprecation warning, by default True
2221 dump : Optional[Union[str, Path]], optional
2322 whether to dump converted file, by default None
24-
2523 Returns
2624 -------
2725 Dict[str, Any]
2826 converted output
2927 """
28+
3029 output = {}
3130 if "with_distrib" in jdata :
3231 output ["with_distrib" ] = jdata ["with_distrib" ]
@@ -35,33 +34,29 @@ def convert_input_v0_v1(
3534 output ["loss" ] = _loss (jdata )
3635 output ["training" ] = _training (jdata )
3736 if warning :
38- _warnning_input_v0_v1 (dump )
37+ _warning_input_v0_v1 (dump )
3938 if dump is not None :
4039 with open (dump , "w" ) as fp :
4140 json .dump (output , fp , indent = 4 )
4241 return output
4342
4443
45- def _warnning_input_v0_v1 (fname : Optional [Union [str , Path ]]):
46- msg = (
47- "It seems that you are using a deepmd-kit input of version 0.x.x, "
48- "which is deprecated. we have converted the input to >1.0.0 compatible"
49- )
44+ def _warning_input_v0_v1 (fname : Optional [Union [str , Path ]]):
45+ msg = "It seems that you are using a deepmd-kit input of version 0.x.x, " \
46+ "which is deprecated. we have converted the input to >2.0.0 compatible"
5047 if fname is not None :
5148 msg += f", and output it to file { fname } "
5249 warnings .warn (msg )
5350
5451
5552def _model (jdata : Dict [str , Any ], smooth : bool ) -> Dict [str , Dict [str , Any ]]:
5653 """Convert data to v1 input for non-smooth model.
57-
5854 Parameters
5955 ----------
6056 jdata : Dict[str, Any]
6157 parsed input json/yaml data
6258 smooth : bool
6359 whether to use smooth or non-smooth descriptor version
64-
6560 Returns
6661 -------
6762 Dict[str, Dict[str, Any]]
@@ -78,12 +73,10 @@ def _model(jdata: Dict[str, Any], smooth: bool) -> Dict[str, Dict[str, Any]]:
7873
7974def _nonsmth_descriptor (jdata : Dict [str , Any ]) -> Dict [str , Any ]:
8075 """Convert data to v1 input for non-smooth descriptor.
81-
8276 Parameters
8377 ----------
8478 jdata : Dict[str, Any]
8579 parsed input json/yaml data
86-
8780 Returns
8881 -------
8982 Dict[str, Any]
@@ -97,12 +90,10 @@ def _nonsmth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
9790
9891def _smth_descriptor (jdata : Dict [str , Any ]) -> Dict [str , Any ]:
9992 """Convert data to v1 input for smooth descriptor.
100-
10193 Parameters
10294 ----------
10395 jdata : Dict[str, Any]
10496 parsed input json/yaml data
105-
10697 Returns
10798 -------
10899 Dict[str, Any]
@@ -127,12 +118,10 @@ def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
127118
128119def _fitting_net (jdata : Dict [str , Any ]) -> Dict [str , Any ]:
129120 """Convert data to v1 input for fitting net.
130-
131121 Parameters
132122 ----------
133123 jdata : Dict[str, Any]
134124 parsed input json/yaml data
135-
136125 Returns
137126 -------
138127 Dict[str, Any]
@@ -154,12 +143,10 @@ def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:
154143
155144def _learning_rate (jdata : Dict [str , Any ]) -> Dict [str , Any ]:
156145 """Convert data to v1 input for learning rate section.
157-
158146 Parameters
159147 ----------
160148 jdata : Dict[str, Any]
161149 parsed input json/yaml data
162-
163150 Returns
164151 -------
165152 Dict[str, Any]
@@ -173,12 +160,10 @@ def _learning_rate(jdata: Dict[str, Any]) -> Dict[str, Any]:
173160
174161def _loss (jdata : Dict [str , Any ]) -> Dict [str , Any ]:
175162 """Convert data to v1 input for loss function.
176-
177163 Parameters
178164 ----------
179165 jdata : Dict[str, Any]
180166 parsed input json/yaml data
181-
182167 Returns
183168 -------
184169 Dict[str, Any]
@@ -206,12 +191,10 @@ def _loss(jdata: Dict[str, Any]) -> Dict[str, Any]:
206191
207192def _training (jdata : Dict [str , Any ]) -> Dict [str , Any ]:
208193 """Convert data to v1 input for training.
209-
210194 Parameters
211195 ----------
212196 jdata : Dict[str, Any]
213197 parsed input json/yaml data
214-
215198 Returns
216199 -------
217200 Dict[str, Any]
@@ -241,7 +224,6 @@ def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:
241224
242225def _jcopy (src : Dict [str , Any ], dst : Dict [str , Any ], keys : Sequence [str ]):
243226 """Copy specified keys from one dict to another.
244-
245227 Parameters
246228 ----------
247229 src : Dict[str, Any]
@@ -255,3 +237,62 @@ def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
255237 """
256238 for k in keys :
257239 dst [k ] = src [k ]
240+
241+
242+ def convert_input_v1_v2 (jdata : Dict [str , Any ],
243+ warning : bool = True ,
244+ dump : Optional [Union [str , Path ]] = None ) -> Dict [str , Any ]:
245+
246+ tr_cfg = jdata ["training" ]
247+ tr_data_keys = {
248+ "systems" ,
249+ "set_prefix" ,
250+ "batch_size" ,
251+ "sys_prob" ,
252+ "auto_prob" ,
253+ # alias included
254+ "sys_weights" ,
255+ "auto_prob_style"
256+ }
257+
258+ tr_data_cfg = {k : v for k , v in tr_cfg .items () if k in tr_data_keys }
259+ new_tr_cfg = {k : v for k , v in tr_cfg .items () if k not in tr_data_keys }
260+ new_tr_cfg ["training_data" ] = tr_data_cfg
261+
262+ jdata ["training" ] = new_tr_cfg
263+
264+ if warning :
265+ _warning_input_v1_v2 (dump )
266+ if dump is not None :
267+ with open (dump , "w" ) as fp :
268+ json .dump (jdata , fp , indent = 4 )
269+
270+ return jdata
271+
272+
273+ def _warning_input_v1_v2 (fname : Optional [Union [str , Path ]]):
274+ msg = "It seems that you are using a deepmd-kit input of version 1.x.x, " \
275+ "which is deprecated. we have converted the input to >2.0.0 compatible"
276+ if fname is not None :
277+ msg += f", and output it to file { fname } "
278+ warnings .warn (msg )
279+
280+
281+ def updata_deepmd_input (jdata : Dict [str , Any ],
282+ warning : bool = True ,
283+ dump : Optional [Union [str , Path ]] = None ) -> Dict [str , Any ]:
284+ def is_deepmd_v0_input (jdata ):
285+ return "model" not in jdata .keys ()
286+
287+ def is_deepmd_v1_input (jdata ):
288+ return "systems" in j_must_have (jdata , "training" ).keys ()
289+
290+ if is_deepmd_v0_input (jdata ):
291+ jdata = convert_input_v0_v1 (jdata , warning , None )
292+ jdata = convert_input_v1_v2 (jdata , False , dump )
293+ elif is_deepmd_v1_input (jdata ):
294+ jdata = convert_input_v1_v2 (jdata , warning , dump )
295+ else :
296+ pass
297+
298+ return jdata
0 commit comments