11"""Test trained DeePMD model."""
22import logging
33from pathlib import Path
4- from typing import TYPE_CHECKING , List , Optional , Tuple
4+ from typing import TYPE_CHECKING , List , Dict , Optional , Tuple
55
66import numpy as np
77from deepmd import DeepPotential
88from deepmd .common import expand_sys_str
99from deepmd .utils .data import DeepmdData
10+ from deepmd .utils .weight_avg import weighted_average
1011
1112if TYPE_CHECKING :
1213 from deepmd .infer import DeepDipole , DeepPolar , DeepPot , DeepWFC
@@ -77,7 +78,7 @@ def test(
7778 data = DeepmdData (system , set_prefix , shuffle_test = shuffle_test , type_map = tmap )
7879
7980 if dp .model_type == "ener" :
80- err , siz = test_ener (
81+ err = test_ener (
8182 dp ,
8283 data ,
8384 system ,
@@ -87,18 +88,15 @@ def test(
8788 append_detail = (cc != 0 ),
8889 )
8990 elif dp .model_type == "dipole" :
90- err , siz = test_dipole (dp , data , numb_test , detail_file , atomic )
91+ err = test_dipole (dp , data , numb_test , detail_file , atomic )
9192 elif dp .model_type == "polar" :
92- err , siz = test_polar (dp , data , numb_test , detail_file , global_polar = False )
93+ err = test_polar (dp , data , numb_test , detail_file , global_polar = False )
9394 elif dp .model_type == "global_polar" :
94- err , siz = test_polar (dp , data , numb_test , detail_file , global_polar = True )
95- elif dp .model_type == "wfc" :
96- err , siz = test_wfc (dp , data , numb_test , detail_file )
95+ err = test_polar (dp , data , numb_test , detail_file , global_polar = True )
9796 log .info ("# ----------------------------------------------- " )
9897 err_coll .append (err )
99- siz_coll .append (siz )
10098
101- avg_err = weighted_average (err_coll , siz_coll )
99+ avg_err = weighted_average (err_coll )
102100
103101 if len (all_sys ) != len (err_coll ):
104102 log .warning ("Not all systems are tested! Check if the systems are valid" )
@@ -119,8 +117,8 @@ def test(
119117 log .info ("# ----------------------------------------------- " )
120118
121119
122- def l2err (diff : np .ndarray ) -> np .ndarray :
123- """Calculate average l2 norm error.
120+ def rmse (diff : np .ndarray ) -> np .ndarray :
121+ """Calculate average root mean square error.
124122
125123 Parameters
126124 ----------
@@ -135,39 +133,6 @@ def l2err(diff: np.ndarray) -> np.ndarray:
135133 return np .sqrt (np .average (diff * diff ))
136134
137135
138- def weighted_average (
139- err_coll : List [List [np .ndarray ]], siz_coll : List [List [int ]]
140- ) -> np .ndarray :
141- """Compute wighted average of prediction errors for model.
142-
143- Parameters
144- ----------
145- err_coll : List[List[np.ndarray]]
146- each item in list represents erros for one model
147- siz_coll : List[List[int]]
148- weight for each model errors
149-
150- Returns
151- -------
152- np.ndarray
153- weighted averages
154- """
155- assert len (err_coll ) == len (siz_coll )
156-
157- nitems = len (err_coll [0 ])
158- sum_err = np .zeros (nitems )
159- sum_siz = np .zeros (nitems )
160- for sys_error , sys_size in zip (err_coll , siz_coll ):
161- for ii in range (nitems ):
162- ee = sys_error [ii ]
163- ss = sys_size [ii ]
164- sum_err [ii ] += ee * ee * ss
165- sum_siz [ii ] += ss
166- for ii in range (nitems ):
167- sum_err [ii ] = np .sqrt (sum_err [ii ] / sum_siz [ii ])
168- return sum_err
169-
170-
171136def save_txt_file (
172137 fname : Path , data : np .ndarray , header : str = "" , append : bool = False
173138):
@@ -280,25 +245,25 @@ def test_ener(
280245 ae = ae .reshape ([numb_test , - 1 ])
281246 av = av .reshape ([numb_test , - 1 ])
282247
283- l2e = l2err (energy - test_data ["energy" ][:numb_test ].reshape ([- 1 , 1 ]))
284- l2f = l2err (force - test_data ["force" ][:numb_test ])
285- l2v = l2err (virial - test_data ["virial" ][:numb_test ])
286- l2ea = l2e / natoms
287- l2va = l2v / natoms
248+ rmse_e = rmse (energy - test_data ["energy" ][:numb_test ].reshape ([- 1 , 1 ]))
249+ rmse_f = rmse (force - test_data ["force" ][:numb_test ])
250+ rmse_v = rmse (virial - test_data ["virial" ][:numb_test ])
251+ rmse_ea = rmse_e / natoms
252+ rmse_va = rmse_v / natoms
288253 if has_atom_ener :
289- l2ae = l2err (
254+ rmse_ae = rmse (
290255 test_data ["atom_ener" ][:numb_test ].reshape ([- 1 ]) - ae .reshape ([- 1 ])
291256 )
292257
293258 # print ("# energies: %s" % energy)
294259 log .info (f"# number of test data : { numb_test :d} " )
295- log .info (f"Energy RMSE : { l2e :e} eV" )
296- log .info (f"Energy RMSE/Natoms : { l2ea :e} eV" )
297- log .info (f"Force RMSE : { l2f :e} eV/A" )
298- log .info (f"Virial RMSE : { l2v :e} eV" )
299- log .info (f"Virial RMSE/Natoms : { l2va :e} eV" )
260+ log .info (f"Energy RMSE : { rmse_e :e} eV" )
261+ log .info (f"Energy RMSE/Natoms : { rmse_ea :e} eV" )
262+ log .info (f"Force RMSE : { rmse_f :e} eV/A" )
263+ log .info (f"Virial RMSE : { rmse_v :e} eV" )
264+ log .info (f"Virial RMSE/Natoms : { rmse_va :e} eV" )
300265 if has_atom_ener :
301- log .info (f"Atomic ener RMSE : { l2ae :e} eV" )
266+ log .info (f"Atomic ener RMSE : { rmse_ae :e} eV" )
302267
303268 if detail_file is not None :
304269 detail_path = Path (detail_file )
@@ -344,20 +309,24 @@ def test_ener(
344309 "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz" ,
345310 append = append_detail ,
346311 )
347- return [l2ea , l2f , l2va ], [energy .size , force .size , virial .size ]
312+ return {
313+ "rmse_ea" : (rmse_ea , energy .size ),
314+ "rmse_f" : (rmse_f , force .size ),
315+ "rmse_va" : (rmse_va , virial .size ),
316+ }
348317
349318
350- def print_ener_sys_avg (avg : np . ndarray ):
319+ def print_ener_sys_avg (avg : Dict [ str , float ] ):
351320 """Print errors summary for energy type potential.
352321
353322 Parameters
354323 ----------
355324 avg : np.ndarray
356325 array with summaries
357326 """
358- log .info (f"Energy RMSE/Natoms : { avg [0 ]:e} eV" )
359- log .info (f"Force RMSE : { avg [1 ]:e} eV/A" )
360- log .info (f"Virial RMSE/Natoms : { avg [2 ]:e} eV" )
327+ log .info (f"Energy RMSE/Natoms : { avg ['rmse_ea' ]:e} eV" )
328+ log .info (f"Force RMSE : { avg ['rmse_f' ]:e} eV/A" )
329+ log .info (f"Virial RMSE/Natoms : { avg ['rmse_va' ]:e} eV" )
361330
362331
363332def run_test (dp : "DeepTensor" , test_data : dict , numb_test : int ):
@@ -417,10 +386,10 @@ def test_wfc(
417386 )
418387 test_data = data .get_test ()
419388 wfc , numb_test , _ = run_test (dp , test_data , numb_test )
420- l2f = l2err (wfc - test_data ["wfc" ][:numb_test ])
389+ rmse_f = rmse (wfc - test_data ["wfc" ][:numb_test ])
421390
422391 log .info ("# number of test data : {numb_test:d} " )
423- log .info ("WFC RMSE : {l2f :e} eV/A" )
392+ log .info ("WFC RMSE : {rmse_f :e} eV/A" )
424393
425394 if detail_file is not None :
426395 detail_path = Path (detail_file )
@@ -436,7 +405,9 @@ def test_wfc(
436405 pe ,
437406 header = "ref_wfc(12 dofs) predicted_wfc(12 dofs)" ,
438407 )
439- return [l2f ], [wfc .size ]
408+ return {
409+ 'rmse' : (rmse_f , wfc .size )
410+ }
440411
441412
442413def print_wfc_sys_avg (avg ):
@@ -447,7 +418,7 @@ def print_wfc_sys_avg(avg):
447418 avg : np.ndarray
448419 array with summaries
449420 """
450- log .info (f"WFC RMSE : { avg [0 ]:e} eV/A" )
421+ log .info (f"WFC RMSE : { avg ['rmse' ]:e} eV/A" )
451422
452423
453424def test_polar (
@@ -504,15 +475,15 @@ def test_polar(
504475 for ii in sel_type :
505476 sel_natoms += sum (atype == ii )
506477
507- l2f = l2err (polar - test_data ["polarizability" ][:numb_test ])
508- l2fs = l2f / np .sqrt (sel_natoms )
509- l2fa = l2f / sel_natoms
478+ rmse_f = rmse (polar - test_data ["polarizability" ][:numb_test ])
479+ rmse_fs = rmse_f / np .sqrt (sel_natoms )
480+ rmse_fa = rmse_f / sel_natoms
510481
511482 log .info (f"# number of test data : { numb_test :d} " )
512- log .info (f"Polarizability RMSE : { l2f :e} eV/A" )
483+ log .info (f"Polarizability RMSE : { rmse_f :e} eV/A" )
513484 if global_polar :
514- log .info (f"Polarizability RMSE/sqrtN : { l2fs :e} eV/A" )
515- log .info (f"Polarizability RMSE/N : { l2fa :e} eV/A" )
485+ log .info (f"Polarizability RMSE/sqrtN : { rmse_fs :e} eV/A" )
486+ log .info (f"Polarizability RMSE/N : { rmse_fa :e} eV/A" )
516487
517488 if detail_file is not None :
518489 detail_path = Path (detail_file )
@@ -531,7 +502,9 @@ def test_polar(
531502 "data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz "
532503 "pred_pzx pred_pzy pred_pzz" ,
533504 )
534- return [l2f ], [polar .size ]
505+ return {
506+ "rmse" : (rmse_f , polar .size )
507+ }
535508
536509
537510def print_polar_sys_avg (avg ):
@@ -542,7 +515,7 @@ def print_polar_sys_avg(avg):
542515 avg : np.ndarray
543516 array with summaries
544517 """
545- log .info (f"Polarizability RMSE : { avg [0 ]:e} eV/A" )
518+ log .info (f"Polarizability RMSE : { avg ['rmse' ]:e} eV/A" )
546519
547520
548521def test_dipole (
@@ -584,13 +557,13 @@ def test_dipole(
584557 atoms = dipole .shape [1 ]
585558 dipole = np .sum (dipole ,axis = 1 )
586559
587- l2f = l2err (dipole - test_data ["dipole" ][:numb_test ])
560+ rmse_f = rmse (dipole - test_data ["dipole" ][:numb_test ])
588561
589562 if has_atom_dipole == False :
590- l2f = l2f / atoms
563+ rmse_f = rmse_f / atoms
591564
592565 log .info (f"# number of test data : { numb_test :d} " )
593- log .info (f"Dipole RMSE : { l2f :e} eV/A" )
566+ log .info (f"Dipole RMSE : { rmse_f :e} eV/A" )
594567
595568 if detail_file is not None :
596569 detail_path = Path (detail_file )
@@ -607,7 +580,9 @@ def test_dipole(
607580 pe ,
608581 header = "data_x data_y data_z pred_x pred_y pred_z" ,
609582 )
610- return [l2f ], [dipole .size ]
583+ return {
584+ 'rmse' : (rmse_f , dipole .size )
585+ }
611586
612587
613588def print_dipole_sys_avg (avg ):
@@ -618,4 +593,4 @@ def print_dipole_sys_avg(avg):
618593 avg : np.ndarray
619594 array with summaries
620595 """
621- log .info (f"Dipole RMSE : { avg [0 ]:e} eV/A" )
596+ log .info (f"Dipole RMSE : { avg ['rmse' ]:e} eV/A" )
0 commit comments