Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 24 additions & 37 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DeepmdData:
modifier
Data modifier that has the method `modify_data`
trn_all_set
Use all sets as training dataset. Otherwise, if the number of sets is more than 1, the last set is left for test.
[DEPRECATED] Now all sets are used for both training and testing.
sort_atoms : bool
Sort atoms by atom types. Must be enabled when the data is fed directly to
descriptors except mixed types.
Expand Down Expand Up @@ -109,15 +109,6 @@ def __init__(
# make idx map
self.sort_atoms = sort_atoms
self.idx_map = self._make_idx_map(self.atom_type)
# train dirs
self.test_dir = self.dirs[-1]
if trn_all_set:
self.train_dirs = self.dirs
else:
if len(self.dirs) == 1:
self.train_dirs = self.dirs
else:
self.train_dirs = self.dirs[:-1]
self.data_dict = {}
# add box and coord
self.add("box", 9, must=self.pbc)
Expand Down Expand Up @@ -225,7 +216,7 @@ def get_data_dict(self) -> dict:

def check_batch_size(self, batch_size):
"""Check if the system can get a batch of data with `batch_size` frames."""
for ii in self.train_dirs:
for ii in self.dirs:
if self.data_dict["coord"]["high_prec"]:
tmpe = (
(ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
Expand All @@ -240,24 +231,7 @@ def check_batch_size(self, batch_size):

def check_test_size(self, test_size):
"""Check if the system can get a test dataset with `test_size` frames."""
if self.data_dict["coord"]["high_prec"]:
tmpe = (
(self.test_dir / "coord.npy")
.load_numpy()
.astype(GLOBAL_ENER_FLOAT_PRECISION)
)
else:
tmpe = (
(self.test_dir / "coord.npy")
.load_numpy()
.astype(GLOBAL_NP_FLOAT_PRECISION)
)
if tmpe.ndim == 1:
tmpe = tmpe.reshape([1, -1])
if tmpe.shape[0] < test_size:
return self.test_dir, tmpe.shape[0]
else:
return None
return self.check_batch_size(test_size)

def get_item_torch(self, index: int) -> dict:
"""Get a single frame of data. The frame is picked from the data system by index. The index is coded across all the sets.
Expand Down Expand Up @@ -287,7 +261,7 @@ def get_batch(self, batch_size: int) -> dict:
else:
set_size = 0
if self.iterator + batch_size > set_size:
self._load_batch_set(self.train_dirs[self.set_count % self.get_numb_set()])
self._load_batch_set(self.dirs[self.set_count % self.get_numb_set()])
self.set_count += 1
set_size = self.batch_set["coord"].shape[0]
iterator_1 = self.iterator + batch_size
Expand All @@ -307,7 +281,7 @@ def get_test(self, ntests: int = -1) -> dict:
Size of the test data set. If `ntests` is -1, all test data will be returned.
"""
if not hasattr(self, "test_set"):
self._load_test_set(self.test_dir, self.shuffle_test)
self._load_test_set(self.shuffle_test)
if ntests == -1:
idx = None
else:
Expand Down Expand Up @@ -340,11 +314,11 @@ def get_atom_type(self) -> List[int]:

def get_numb_set(self) -> int:
"""Get number of training sets."""
return len(self.train_dirs)
return len(self.dirs)

def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
"""Get the number of batches in a set."""
data = self._load_set(self.train_dirs[set_idx])
data = self._load_set(self.dirs[set_idx])
ret = data["coord"].shape[0] // batch_size
if ret == 0:
ret = 1
Expand All @@ -353,7 +327,7 @@ def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
def get_sys_numb_batch(self, batch_size: int) -> int:
"""Get the number of batches in the data system."""
ret = 0
for ii in range(len(self.train_dirs)):
for ii in range(len(self.dirs)):
ret += self.get_numb_batch(batch_size, ii)
return ret

Expand Down Expand Up @@ -388,7 +362,7 @@ def avg(self, key):
info = self.data_dict[key]
ndof = info["ndof"]
eners = []
for ii in self.train_dirs:
for ii in self.dirs:
data = self._load_set(ii)
ei = data[key].reshape([-1, ndof])
eners.append(ei)
Expand Down Expand Up @@ -441,8 +415,21 @@ def _load_batch_set(self, set_name: DPPath):
def reset_get_batch(self):
self.iterator = 0

def _load_test_set(self, set_name: DPPath, shuffle_test):
self.test_set = self._load_set(set_name)
def _load_test_set(self, shuffle_test: bool):
test_sets = []
for ii in self.dirs:
test_set = self._load_set(ii)
test_sets.append(test_set)
# merge test sets
self.test_set = {}
assert len(test_sets) > 0
for kk in test_sets[0]:
if "find_" in kk:
self.test_set[kk] = test_sets[0][kk]
else:
self.test_set[kk] = np.concatenate(
[test_set[kk] for test_set in test_sets], axis=0
)
if shuffle_test:
self.test_set, _ = self._shuffle_data(self.test_set)

Expand Down
33 changes: 25 additions & 8 deletions source/tests/tf/test_deepmd_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def setUp(self):
path = os.path.join(self.data_name, "set.bar", "test_frame.npy")
self.test_frame_bar = rng.random([self.nframes, 5])
np.save(path, self.test_frame_bar)
path = os.path.join(self.data_name, "set.tar", "test_frame.npy")
self.test_frame_tar = rng.random([2, 5])
np.save(path, self.test_frame_tar)
# t n
self.test_null = np.zeros([self.nframes, 2 * self.natoms])
# tensor shape
Expand All @@ -162,8 +165,9 @@ def test_init(self):
self.assertEqual(dd.idx_map[0], 1)
self.assertEqual(dd.idx_map[1], 0)
self.assertEqual(dd.type_map, ["foo", "bar"])
self.assertEqual(dd.test_dir, "test_data/set.tar")
self.assertEqual(dd.train_dirs, ["test_data/set.bar", "test_data/set.foo"])
self.assertEqual(
dd.dirs, ["test_data/set.bar", "test_data/set.foo", "test_data/set.tar"]
)

def test_init_type_map(self):
dd = DeepmdData(self.data_name, type_map=["bar", "foo", "tar"])
Expand All @@ -182,7 +186,7 @@ def test_load_set(self):
)
data = dd._load_set(os.path.join(self.data_name, "set.foo"))
nframes = data["coord"].shape[0]
self.assertEqual(dd.get_numb_set(), 2)
self.assertEqual(dd.get_numb_set(), 3)
self.assertEqual(dd.get_type_map(), ["foo", "bar"])
self.assertEqual(dd.get_natoms(), 2)
self.assertEqual(list(dd.get_natoms_vec(3)), [2, 2, 1, 1, 0])
Expand Down Expand Up @@ -257,7 +261,10 @@ def test_avg(self):
dd = DeepmdData(self.data_name).add("test_frame", 5, atomic=False, must=True)
favg = dd.avg("test_frame")
fcmp = np.average(
np.concatenate((self.test_frame, self.test_frame_bar), axis=0), axis=0
np.concatenate(
(self.test_frame, self.test_frame_bar, self.test_frame_tar), axis=0
),
axis=0,
)
np.testing.assert_almost_equal(favg, fcmp, places)

Expand All @@ -266,13 +273,17 @@ def test_check_batch_size(self):
ret = dd.check_batch_size(10)
self.assertEqual(ret, (os.path.join(self.data_name, "set.bar"), 5))
ret = dd.check_batch_size(5)
self.assertEqual(ret, (os.path.join(self.data_name, "set.tar"), 2))
ret = dd.check_batch_size(1)
self.assertEqual(ret, None)

def test_check_test_size(self):
dd = DeepmdData(self.data_name)
ret = dd.check_test_size(10)
self.assertEqual(ret, (os.path.join(self.data_name, "set.bar"), 5))
ret = dd.check_test_size(5)
self.assertEqual(ret, (os.path.join(self.data_name, "set.tar"), 2))
ret = dd.check_test_size(2)
ret = dd.check_test_size(1)
self.assertEqual(ret, None)

def test_get_batch(self):
Expand All @@ -284,6 +295,10 @@ def test_get_batch(self):
data = dd.get_batch(5)
self._comp_np_mat2(np.sort(data["coord"], axis=0), np.sort(self.coord, axis=0))
data = dd.get_batch(5)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_tar, axis=0)
)
data = dd.get_batch(5)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_bar, axis=0)
)
Expand All @@ -293,8 +308,11 @@ def test_get_batch(self):
def test_get_test(self):
dd = DeepmdData(self.data_name)
data = dd.get_test()
expected_coord = np.concatenate(
(self.coord_bar, self.coord, self.coord_tar), axis=0
)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_tar, axis=0)
np.sort(data["coord"], axis=0), np.sort(expected_coord, axis=0)
)

def test_get_nbatch(self):
Expand Down Expand Up @@ -368,8 +386,7 @@ def test_init(self):
dd = DeepmdData(self.data_name)
self.assertEqual(dd.idx_map[0], 0)
self.assertEqual(dd.type_map, ["X"])
self.assertEqual(dd.test_dir, self.data_name + "#/set.000")
self.assertEqual(dd.train_dirs, [self.data_name + "#/set.000"])
self.assertEqual(dd.dirs[0], self.data_name + "#/set.000")

def test_get_batch(self):
dd = DeepmdData(self.data_name)
Expand Down
Loading