File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -419,28 +419,18 @@ def test_ener(
419419 mae_e = mae (diff_e )
420420 rmse_e = rmse (diff_e )
421421 diff_f = force - test_data ["force" ][:numb_test ]
422+ mae_f = mae (diff_f )
423+ rmse_f = rmse (diff_f )
424+ size_f = diff_f .size
422425 if find_atom_pref == 1 :
423426 atom_weight = test_data ["atom_pref" ][:numb_test ]
424- mask = atom_weight != 0
425- masked_diff = diff_f [mask ]
426- if masked_diff .size > 0 :
427- mae_f = np .mean (np .abs (masked_diff ))
428- rmse_f = np .sqrt (np .mean (masked_diff * masked_diff ))
429- else :
430- mae_f = 0.0
431- rmse_f = 0.0
432- size_f = mask .sum ()
433427 weight_sum = np .sum (atom_weight )
434428 if weight_sum > 0 :
435429 mae_fw = np .sum (np .abs (diff_f ) * atom_weight ) / weight_sum
436430 rmse_fw = np .sqrt (np .sum (diff_f * diff_f * atom_weight ) / weight_sum )
437431 else :
438432 mae_fw = 0.0
439433 rmse_fw = 0.0
440- else :
441- mae_f = mae (diff_f )
442- rmse_f = rmse (diff_f )
443- size_f = diff_f .size
444434 diff_v = virial - test_data ["virial" ][:numb_test ]
445435 mae_v = mae (diff_v )
446436 rmse_v = rmse (diff_v )
Original file line number Diff line number Diff line change @@ -223,10 +223,8 @@ def test_force_weight(self) -> None:
223223 force_true = test_data ["force" ][:1 ]
224224 weight = test_data ["atom_pref" ][:1 ]
225225 diff = force_pred - force_true
226- mask = weight != 0
227- masked_diff = diff [mask ]
228- mae_unweighted = np .sum (np .abs (masked_diff )) / mask .sum ()
229- rmse_unweighted = np .sqrt (np .sum (masked_diff * masked_diff ) / mask .sum ())
226+ mae_unweighted = np .sum (np .abs (diff )) / diff .size
227+ rmse_unweighted = np .sqrt (np .sum (diff * diff ) / diff .size )
230228 denom = weight .sum ()
231229 mae_weighted = np .sum (np .abs (diff ) * weight ) / denom
232230 rmse_weighted = np .sqrt (np .sum (diff * diff * weight ) / denom )
You can’t perform that action at this time.
0 commit comments