@@ -894,17 +894,9 @@ def test_eval_typeebd(self):
894894 np .testing .assert_almost_equal (eval_typeebd , expected_typeebd , default_places )
895895
896896
897- class TestFparamAparam ( unittest . TestCase ) :
897+ class FparamAparamCommonTest :
898898 """Test fparam and aparam."""
899899
900- @classmethod
901- def setUpClass (cls ):
902- convert_pbtxt_to_pb (
903- str (infer_path / os .path .join ("fparam_aparam.pbtxt" )),
904- "fparam_aparam.pb" ,
905- )
906- cls .dp = DeepPot ("fparam_aparam.pb" )
907-
908900 def setUp (self ):
909901 self .coords = np .array (
910902 [
@@ -1022,15 +1014,11 @@ def setUp(self):
10221014 2.875323131744185121e-02 ,
10231015 ]
10241016 )
1025-
1026- @classmethod
1027- def tearDownClass (cls ):
1028- os .remove ("fparam_aparam.pb" )
1029- cls .dp = None
1017+ self .places = default_places
10301018
10311019 def test_attrs (self ):
10321020 self .assertEqual (self .dp .get_ntypes (), 1 )
1033- self .assertAlmostEqual (self .dp .get_rcut (), 6.0 , places = default_places )
1021+ self .assertAlmostEqual (self .dp .get_rcut (), 6.0 , places = self . places )
10341022 self .assertEqual (self .dp .get_dim_fparam (), 1 )
10351023 self .assertEqual (self .dp .get_dim_aparam (), 1 )
10361024
@@ -1050,13 +1038,11 @@ def test_1frame(self):
10501038 self .assertEqual (ff .shape , (nframes , natoms , 3 ))
10511039 self .assertEqual (vv .shape , (nframes , 9 ))
10521040 # check values
1053- np .testing .assert_almost_equal (
1054- ff .ravel (), self .expected_f .ravel (), default_places
1055- )
1041+ np .testing .assert_almost_equal (ff .ravel (), self .expected_f .ravel (), self .places )
10561042 expected_se = np .sum (self .expected_e .reshape ([nframes , - 1 ]), axis = 1 )
1057- np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), default_places )
1043+ np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), self . places )
10581044 expected_sv = np .sum (self .expected_v .reshape ([nframes , - 1 , 9 ]), axis = 1 )
1059- np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), default_places )
1045+ np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), self . places )
10601046
10611047 def test_1frame_atm (self ):
10621048 ee , ff , vv , ae , av = self .dp .eval (
@@ -1076,19 +1062,13 @@ def test_1frame_atm(self):
10761062 self .assertEqual (ae .shape , (nframes , natoms , 1 ))
10771063 self .assertEqual (av .shape , (nframes , natoms , 9 ))
10781064 # check values
1079- np .testing .assert_almost_equal (
1080- ff .ravel (), self .expected_f .ravel (), default_places
1081- )
1082- np .testing .assert_almost_equal (
1083- ae .ravel (), self .expected_e .ravel (), default_places
1084- )
1085- np .testing .assert_almost_equal (
1086- av .ravel (), self .expected_v .ravel (), default_places
1087- )
1065+ np .testing .assert_almost_equal (ff .ravel (), self .expected_f .ravel (), self .places )
1066+ np .testing .assert_almost_equal (ae .ravel (), self .expected_e .ravel (), self .places )
1067+ np .testing .assert_almost_equal (av .ravel (), self .expected_v .ravel (), self .places )
10881068 expected_se = np .sum (self .expected_e .reshape ([nframes , - 1 ]), axis = 1 )
1089- np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), default_places )
1069+ np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), self . places )
10901070 expected_sv = np .sum (self .expected_v .reshape ([nframes , - 1 , 9 ]), axis = 1 )
1091- np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), default_places )
1071+ np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), self . places )
10921072
10931073 def test_2frame_atm_single_param (self ):
10941074 coords2 = np .concatenate ((self .coords , self .coords ))
@@ -1113,13 +1093,13 @@ def test_2frame_atm_single_param(self):
11131093 expected_f = np .concatenate ((self .expected_f , self .expected_f ), axis = 0 )
11141094 expected_e = np .concatenate ((self .expected_e , self .expected_e ), axis = 0 )
11151095 expected_v = np .concatenate ((self .expected_v , self .expected_v ), axis = 0 )
1116- np .testing .assert_almost_equal (ff .ravel (), expected_f .ravel (), default_places )
1117- np .testing .assert_almost_equal (ae .ravel (), expected_e .ravel (), default_places )
1118- np .testing .assert_almost_equal (av .ravel (), expected_v .ravel (), default_places )
1096+ np .testing .assert_almost_equal (ff .ravel (), expected_f .ravel (), self . places )
1097+ np .testing .assert_almost_equal (ae .ravel (), expected_e .ravel (), self . places )
1098+ np .testing .assert_almost_equal (av .ravel (), expected_v .ravel (), self . places )
11191099 expected_se = np .sum (expected_e .reshape ([nframes , - 1 ]), axis = 1 )
1120- np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), default_places )
1100+ np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), self . places )
11211101 expected_sv = np .sum (expected_v .reshape ([nframes , - 1 , 9 ]), axis = 1 )
1122- np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), default_places )
1102+ np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), self . places )
11231103
11241104 def test_2frame_atm_all_param (self ):
11251105 coords2 = np .concatenate ((self .coords , self .coords ))
@@ -1144,13 +1124,28 @@ def test_2frame_atm_all_param(self):
11441124 expected_f = np .concatenate ((self .expected_f , self .expected_f ), axis = 0 )
11451125 expected_e = np .concatenate ((self .expected_e , self .expected_e ), axis = 0 )
11461126 expected_v = np .concatenate ((self .expected_v , self .expected_v ), axis = 0 )
1147- np .testing .assert_almost_equal (ff .ravel (), expected_f .ravel (), default_places )
1148- np .testing .assert_almost_equal (ae .ravel (), expected_e .ravel (), default_places )
1149- np .testing .assert_almost_equal (av .ravel (), expected_v .ravel (), default_places )
1127+ np .testing .assert_almost_equal (ff .ravel (), expected_f .ravel (), self . places )
1128+ np .testing .assert_almost_equal (ae .ravel (), expected_e .ravel (), self . places )
1129+ np .testing .assert_almost_equal (av .ravel (), expected_v .ravel (), self . places )
11501130 expected_se = np .sum (expected_e .reshape ([nframes , - 1 ]), axis = 1 )
1151- np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), default_places )
1131+ np .testing .assert_almost_equal (ee .ravel (), expected_se .ravel (), self . places )
11521132 expected_sv = np .sum (expected_v .reshape ([nframes , - 1 , 9 ]), axis = 1 )
1153- np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), default_places )
1133+ np .testing .assert_almost_equal (vv .ravel (), expected_sv .ravel (), self .places )
1134+
1135+
1136+ class TestFparamAparam (FparamAparamCommonTest , unittest .TestCase ):
1137+ @classmethod
1138+ def setUpClass (cls ):
1139+ convert_pbtxt_to_pb (
1140+ str (infer_path / os .path .join ("fparam_aparam.pbtxt" )),
1141+ "fparam_aparam.pb" ,
1142+ )
1143+ cls .dp = DeepPot ("fparam_aparam.pb" )
1144+
1145+ @classmethod
1146+ def tearDownClass (cls ):
1147+ os .remove ("fparam_aparam.pb" )
1148+ cls .dp = None
11541149
11551150
11561151class TestDeepPotAPBCNeighborList (TestDeepPotAPBC ):
0 commit comments