Skip to content

Commit f49c578

Browse files
authored
Merge pull request #518 from amcadmus/devel
seperate impl of averging errors across systems
2 parents c667d00 + 2186788 commit f49c578

2 files changed

Lines changed: 87 additions & 78 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 53 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Test trained DeePMD model."""
22
import logging
33
from pathlib import Path
4-
from typing import TYPE_CHECKING, List, Optional, Tuple
4+
from typing import TYPE_CHECKING, List, Dict, Optional, Tuple
55

66
import numpy as np
77
from deepmd import DeepPotential
88
from deepmd.common import expand_sys_str
99
from deepmd.utils.data import DeepmdData
10+
from deepmd.utils.weight_avg import weighted_average
1011

1112
if TYPE_CHECKING:
1213
from deepmd.infer import DeepDipole, DeepPolar, DeepPot, DeepWFC
@@ -77,7 +78,7 @@ def test(
7778
data = DeepmdData(system, set_prefix, shuffle_test=shuffle_test, type_map=tmap)
7879

7980
if dp.model_type == "ener":
80-
err, siz = test_ener(
81+
err = test_ener(
8182
dp,
8283
data,
8384
system,
@@ -87,18 +88,15 @@ def test(
8788
append_detail=(cc != 0),
8889
)
8990
elif dp.model_type == "dipole":
90-
err, siz = test_dipole(dp, data, numb_test, detail_file, atomic)
91+
err = test_dipole(dp, data, numb_test, detail_file, atomic)
9192
elif dp.model_type == "polar":
92-
err, siz = test_polar(dp, data, numb_test, detail_file, global_polar=False)
93+
err = test_polar(dp, data, numb_test, detail_file, global_polar=False)
9394
elif dp.model_type == "global_polar":
94-
err, siz = test_polar(dp, data, numb_test, detail_file, global_polar=True)
95-
elif dp.model_type == "wfc":
96-
err, siz = test_wfc(dp, data, numb_test, detail_file)
95+
err = test_polar(dp, data, numb_test, detail_file, global_polar=True)
9796
log.info("# ----------------------------------------------- ")
9897
err_coll.append(err)
99-
siz_coll.append(siz)
10098

101-
avg_err = weighted_average(err_coll, siz_coll)
99+
avg_err = weighted_average(err_coll)
102100

103101
if len(all_sys) != len(err_coll):
104102
log.warning("Not all systems are tested! Check if the systems are valid")
@@ -119,8 +117,8 @@ def test(
119117
log.info("# ----------------------------------------------- ")
120118

121119

122-
def l2err(diff: np.ndarray) -> np.ndarray:
123-
"""Calculate average l2 norm error.
120+
def rmse(diff: np.ndarray) -> np.ndarray:
121+
"""Calculate average root mean square error.
124122
125123
Parameters
126124
----------
@@ -135,39 +133,6 @@ def l2err(diff: np.ndarray) -> np.ndarray:
135133
return np.sqrt(np.average(diff * diff))
136134

137135

138-
def weighted_average(
139-
err_coll: List[List[np.ndarray]], siz_coll: List[List[int]]
140-
) -> np.ndarray:
141-
"""Compute wighted average of prediction errors for model.
142-
143-
Parameters
144-
----------
145-
err_coll : List[List[np.ndarray]]
146-
each item in list represents erros for one model
147-
siz_coll : List[List[int]]
148-
weight for each model errors
149-
150-
Returns
151-
-------
152-
np.ndarray
153-
weighted averages
154-
"""
155-
assert len(err_coll) == len(siz_coll)
156-
157-
nitems = len(err_coll[0])
158-
sum_err = np.zeros(nitems)
159-
sum_siz = np.zeros(nitems)
160-
for sys_error, sys_size in zip(err_coll, siz_coll):
161-
for ii in range(nitems):
162-
ee = sys_error[ii]
163-
ss = sys_size[ii]
164-
sum_err[ii] += ee * ee * ss
165-
sum_siz[ii] += ss
166-
for ii in range(nitems):
167-
sum_err[ii] = np.sqrt(sum_err[ii] / sum_siz[ii])
168-
return sum_err
169-
170-
171136
def save_txt_file(
172137
fname: Path, data: np.ndarray, header: str = "", append: bool = False
173138
):
@@ -280,25 +245,25 @@ def test_ener(
280245
ae = ae.reshape([numb_test, -1])
281246
av = av.reshape([numb_test, -1])
282247

283-
l2e = l2err(energy - test_data["energy"][:numb_test].reshape([-1, 1]))
284-
l2f = l2err(force - test_data["force"][:numb_test])
285-
l2v = l2err(virial - test_data["virial"][:numb_test])
286-
l2ea = l2e / natoms
287-
l2va = l2v / natoms
248+
rmse_e = rmse(energy - test_data["energy"][:numb_test].reshape([-1, 1]))
249+
rmse_f = rmse(force - test_data["force"][:numb_test])
250+
rmse_v = rmse(virial - test_data["virial"][:numb_test])
251+
rmse_ea = rmse_e / natoms
252+
rmse_va = rmse_v / natoms
288253
if has_atom_ener:
289-
l2ae = l2err(
254+
rmse_ae = rmse(
290255
test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
291256
)
292257

293258
# print ("# energies: %s" % energy)
294259
log.info(f"# number of test data : {numb_test:d} ")
295-
log.info(f"Energy RMSE : {l2e:e} eV")
296-
log.info(f"Energy RMSE/Natoms : {l2ea:e} eV")
297-
log.info(f"Force RMSE : {l2f:e} eV/A")
298-
log.info(f"Virial RMSE : {l2v:e} eV")
299-
log.info(f"Virial RMSE/Natoms : {l2va:e} eV")
260+
log.info(f"Energy RMSE : {rmse_e:e} eV")
261+
log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV")
262+
log.info(f"Force RMSE : {rmse_f:e} eV/A")
263+
log.info(f"Virial RMSE : {rmse_v:e} eV")
264+
log.info(f"Virial RMSE/Natoms : {rmse_va:e} eV")
300265
if has_atom_ener:
301-
log.info(f"Atomic ener RMSE : {l2ae:e} eV")
266+
log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")
302267

303268
if detail_file is not None:
304269
detail_path = Path(detail_file)
@@ -344,20 +309,24 @@ def test_ener(
344309
"pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
345310
append=append_detail,
346311
)
347-
return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]
312+
return {
313+
"rmse_ea" : (rmse_ea, energy.size),
314+
"rmse_f" : (rmse_f, force.size),
315+
"rmse_va" : (rmse_va, virial.size),
316+
}
348317

349318

350-
def print_ener_sys_avg(avg: np.ndarray):
319+
def print_ener_sys_avg(avg: Dict[str,float]):
351320
"""Print errors summary for energy type potential.
352321
353322
Parameters
354323
----------
355324
avg : np.ndarray
356325
array with summaries
357326
"""
358-
log.info(f"Energy RMSE/Natoms : {avg[0]:e} eV")
359-
log.info(f"Force RMSE : {avg[1]:e} eV/A")
360-
log.info(f"Virial RMSE/Natoms : {avg[2]:e} eV")
327+
log.info(f"Energy RMSE/Natoms : {avg['rmse_ea']:e} eV")
328+
log.info(f"Force RMSE : {avg['rmse_f']:e} eV/A")
329+
log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")
361330

362331

363332
def run_test(dp: "DeepTensor", test_data: dict, numb_test: int):
@@ -417,10 +386,10 @@ def test_wfc(
417386
)
418387
test_data = data.get_test()
419388
wfc, numb_test, _ = run_test(dp, test_data, numb_test)
420-
l2f = l2err(wfc - test_data["wfc"][:numb_test])
389+
rmse_f = rmse(wfc - test_data["wfc"][:numb_test])
421390

422391
log.info("# number of test data : {numb_test:d} ")
423-
log.info("WFC RMSE : {l2f:e} eV/A")
392+
log.info("WFC RMSE : {rmse_f:e} eV/A")
424393

425394
if detail_file is not None:
426395
detail_path = Path(detail_file)
@@ -436,7 +405,9 @@ def test_wfc(
436405
pe,
437406
header="ref_wfc(12 dofs) predicted_wfc(12 dofs)",
438407
)
439-
return [l2f], [wfc.size]
408+
return {
409+
'rmse' : (rmse_f, wfc.size)
410+
}
440411

441412

442413
def print_wfc_sys_avg(avg):
@@ -447,7 +418,7 @@ def print_wfc_sys_avg(avg):
447418
avg : np.ndarray
448419
array with summaries
449420
"""
450-
log.info(f"WFC RMSE : {avg[0]:e} eV/A")
421+
log.info(f"WFC RMSE : {avg['rmse']:e} eV/A")
451422

452423

453424
def test_polar(
@@ -504,15 +475,15 @@ def test_polar(
504475
for ii in sel_type:
505476
sel_natoms += sum(atype == ii)
506477

507-
l2f = l2err(polar - test_data["polarizability"][:numb_test])
508-
l2fs = l2f / np.sqrt(sel_natoms)
509-
l2fa = l2f / sel_natoms
478+
rmse_f = rmse(polar - test_data["polarizability"][:numb_test])
479+
rmse_fs = rmse_f / np.sqrt(sel_natoms)
480+
rmse_fa = rmse_f / sel_natoms
510481

511482
log.info(f"# number of test data : {numb_test:d} ")
512-
log.info(f"Polarizability RMSE : {l2f:e} eV/A")
483+
log.info(f"Polarizability RMSE : {rmse_f:e} eV/A")
513484
if global_polar:
514-
log.info(f"Polarizability RMSE/sqrtN : {l2fs:e} eV/A")
515-
log.info(f"Polarizability RMSE/N : {l2fa:e} eV/A")
485+
log.info(f"Polarizability RMSE/sqrtN : {rmse_fs:e} eV/A")
486+
log.info(f"Polarizability RMSE/N : {rmse_fa:e} eV/A")
516487

517488
if detail_file is not None:
518489
detail_path = Path(detail_file)
@@ -531,7 +502,9 @@ def test_polar(
531502
"data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz "
532503
"pred_pzx pred_pzy pred_pzz",
533504
)
534-
return [l2f], [polar.size]
505+
return {
506+
"rmse" : (rmse_f, polar.size)
507+
}
535508

536509

537510
def print_polar_sys_avg(avg):
@@ -542,7 +515,7 @@ def print_polar_sys_avg(avg):
542515
avg : np.ndarray
543516
array with summaries
544517
"""
545-
log.info(f"Polarizability RMSE : {avg[0]:e} eV/A")
518+
log.info(f"Polarizability RMSE : {avg['rmse']:e} eV/A")
546519

547520

548521
def test_dipole(
@@ -584,13 +557,13 @@ def test_dipole(
584557
atoms = dipole.shape[1]
585558
dipole = np.sum(dipole,axis=1)
586559

587-
l2f = l2err(dipole - test_data["dipole"][:numb_test])
560+
rmse_f = rmse(dipole - test_data["dipole"][:numb_test])
588561

589562
if has_atom_dipole == False:
590-
l2f = l2f / atoms
563+
rmse_f = rmse_f / atoms
591564

592565
log.info(f"# number of test data : {numb_test:d}")
593-
log.info(f"Dipole RMSE : {l2f:e} eV/A")
566+
log.info(f"Dipole RMSE : {rmse_f:e} eV/A")
594567

595568
if detail_file is not None:
596569
detail_path = Path(detail_file)
@@ -607,7 +580,9 @@ def test_dipole(
607580
pe,
608581
header="data_x data_y data_z pred_x pred_y pred_z",
609582
)
610-
return [l2f], [dipole.size]
583+
return {
584+
'rmse' : (rmse_f, dipole.size)
585+
}
611586

612587

613588
def print_dipole_sys_avg(avg):
@@ -618,4 +593,4 @@ def print_dipole_sys_avg(avg):
618593
avg : np.ndarray
619594
array with summaries
620595
"""
621-
log.info(f"Dipole RMSE : {avg[0]:e} eV/A")
596+
log.info(f"Dipole RMSE : {avg['rmse']:e} eV/A")

deepmd/utils/weight_avg.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import TYPE_CHECKING, List, Dict, Optional, Tuple
2+
import numpy as np
3+
4+
5+
def weighted_average(
6+
errors: List[Dict[str, Tuple[float, float]]]
7+
) -> Dict:
8+
"""Compute wighted average of prediction errors for model.
9+
10+
Parameters
11+
----------
12+
errors : List[Dict[str, Tuple[float, float]]]
13+
List: the error of systems
14+
Dict: the error of quantities, name given by the key
15+
Tuple: (error, weight)
16+
17+
Returns
18+
-------
19+
Dict
20+
weighted averages
21+
"""
22+
sum_err = {}
23+
sum_siz = {}
24+
for err in errors:
25+
for kk, (ee, ss) in err.items():
26+
if kk in sum_err:
27+
sum_err[kk] += ee * ee * ss
28+
sum_siz[kk] += ss
29+
else :
30+
sum_err[kk] = ee * ee * ss
31+
sum_siz[kk] = ss
32+
for kk in sum_err.keys():
33+
sum_err[kk] = np.sqrt(sum_err[kk] / sum_siz[kk])
34+
return sum_err

0 commit comments

Comments
 (0)