Skip to content

Commit 7e1f36d

Browse files
committed
not using atom_pref as mask when calculating raw force metrics
1 parent e74b4bd commit 7e1f36d

2 files changed

Lines changed: 5 additions & 17 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -419,28 +419,18 @@ def test_ener(
419419
mae_e = mae(diff_e)
420420
rmse_e = rmse(diff_e)
421421
diff_f = force - test_data["force"][:numb_test]
422+
mae_f = mae(diff_f)
423+
rmse_f = rmse(diff_f)
424+
size_f = diff_f.size
422425
if find_atom_pref == 1:
423426
atom_weight = test_data["atom_pref"][:numb_test]
424-
mask = atom_weight != 0
425-
masked_diff = diff_f[mask]
426-
if masked_diff.size > 0:
427-
mae_f = np.mean(np.abs(masked_diff))
428-
rmse_f = np.sqrt(np.mean(masked_diff * masked_diff))
429-
else:
430-
mae_f = 0.0
431-
rmse_f = 0.0
432-
size_f = mask.sum()
433427
weight_sum = np.sum(atom_weight)
434428
if weight_sum > 0:
435429
mae_fw = np.sum(np.abs(diff_f) * atom_weight) / weight_sum
436430
rmse_fw = np.sqrt(np.sum(diff_f * diff_f * atom_weight) / weight_sum)
437431
else:
438432
mae_fw = 0.0
439433
rmse_fw = 0.0
440-
else:
441-
mae_f = mae(diff_f)
442-
rmse_f = rmse(diff_f)
443-
size_f = diff_f.size
444434
diff_v = virial - test_data["virial"][:numb_test]
445435
mae_v = mae(diff_v)
446436
rmse_v = rmse(diff_v)

source/tests/pt/test_dp_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,8 @@ def test_force_weight(self) -> None:
223223
force_true = test_data["force"][:1]
224224
weight = test_data["atom_pref"][:1]
225225
diff = force_pred - force_true
226-
mask = weight != 0
227-
masked_diff = diff[mask]
228-
mae_unweighted = np.sum(np.abs(masked_diff)) / mask.sum()
229-
rmse_unweighted = np.sqrt(np.sum(masked_diff * masked_diff) / mask.sum())
226+
mae_unweighted = np.sum(np.abs(diff)) / diff.size
227+
rmse_unweighted = np.sqrt(np.sum(diff * diff) / diff.size)
230228
denom = weight.sum()
231229
mae_weighted = np.sum(np.abs(diff) * weight) / denom
232230
rmse_weighted = np.sqrt(np.sum(diff * diff * weight) / denom)

0 commit comments

Comments
 (0)