Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def __init__(
# enforce type_map if necessary
self.enforce_type_map = False
if type_map is not None and self.type_map is not None and len(type_map):
missing_elements = [elem for elem in self.type_map if elem not in type_map]
if missing_elements:
raise ValueError(
f"Elements {missing_elements} are not present in the provided `type_map`."
)
if not self.mixed_type:
atom_type_ = [
type_map.index(self.type_map[ii]) for ii in self.atom_type
Expand Down
82 changes: 82 additions & 0 deletions source/tests/tf/test_deepmd_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def test_init_type_map(self) -> None:
self.assertEqual(dd.atom_type[1], 1)
self.assertEqual(dd.type_map, ["bar", "foo", "tar"])

def test_init_type_map_error(self) -> None:
with self.assertRaises(ValueError):
DeepmdData(self.data_name, type_map=["bar"])

def test_load_set(self) -> None:
dd = (
DeepmdData(self.data_name)
Expand Down Expand Up @@ -378,6 +382,84 @@ def _comp_np_mat2(self, first, second) -> None:
np.testing.assert_almost_equal(first, second, places)


class TestDataMixType(unittest.TestCase):
def setUp(self) -> None:
rng = np.random.default_rng(GLOBAL_SEED)
self.data_name = "test_data"
os.makedirs(self.data_name, exist_ok=True)
os.makedirs(os.path.join(self.data_name, "set.foo"), exist_ok=True)
os.makedirs(os.path.join(self.data_name, "set.bar"), exist_ok=True)
os.makedirs(os.path.join(self.data_name, "set.tar"), exist_ok=True)
np.savetxt(os.path.join(self.data_name, "type.raw"), np.array([0, 0]), fmt="%d")
np.savetxt(
os.path.join(self.data_name, "type_map.raw"),
np.array(["foo", "bar"]),
fmt="%s",
)
self.nframes = 5
self.natoms = 2
# coord
path = os.path.join(self.data_name, "set.foo", "coord.npy")
self.coord = rng.random([self.nframes, self.natoms, 3])
np.save(path, np.reshape(self.coord, [self.nframes, -1]))
self.coord = self.coord[:, [1, 0], :]
self.coord = self.coord.reshape([self.nframes, -1])
# coord bar
path = os.path.join(self.data_name, "set.bar", "coord.npy")
self.coord_bar = rng.random([self.nframes, 3 * self.natoms])
np.save(path, self.coord_bar)
self.coord_bar = self.coord_bar.reshape([self.nframes, self.natoms, 3])
self.coord_bar = self.coord_bar[:, [1, 0], :]
self.coord_bar = self.coord_bar.reshape([self.nframes, -1])
# coord tar
path = os.path.join(self.data_name, "set.tar", "coord.npy")
self.coord_tar = rng.random([2, 3 * self.natoms])
np.save(path, self.coord_tar)
self.coord_tar = self.coord_tar.reshape([2, self.natoms, 3])
self.coord_tar = self.coord_tar[:, [1, 0], :]
self.coord_tar = self.coord_tar.reshape([2, -1])
# box
path = os.path.join(self.data_name, "set.foo", "box.npy")
self.box = rng.random([self.nframes, 9])
np.save(path, self.box)
# box bar
path = os.path.join(self.data_name, "set.bar", "box.npy")
self.box_bar = rng.random([self.nframes, 9])
np.save(path, self.box_bar)
# box tar
path = os.path.join(self.data_name, "set.tar", "box.npy")
self.box_tar = rng.random([2, 9])
np.save(path, self.box_tar)
# real_atom_types
path = os.path.join(self.data_name, "set.foo", "real_atom_types.npy")
self.real_atom_types = rng.integers(0, 2, size=[self.nframes, self.natoms])
np.save(path, self.real_atom_types)
# real_atom_types bar
path = os.path.join(self.data_name, "set.bar", "real_atom_types.npy")
self.real_atom_types_bar = rng.integers(0, 2, size=[self.nframes, self.natoms])
np.save(path, self.real_atom_types_bar)
# real_atom_types tar
path = os.path.join(self.data_name, "set.tar", "real_atom_types.npy")
self.real_atom_types_tar = rng.integers(0, 2, size=[2, self.natoms])
np.save(path, self.real_atom_types_tar)

def test_init_type_map(self) -> None:
dd = DeepmdData(self.data_name, type_map=["bar", "foo", "tar"])
self.assertEqual(dd.enforce_type_map, True)
self.assertEqual(dd.type_map, ["bar", "foo", "tar"])
self.assertEqual(dd.mixed_type, True)
self.assertEqual(dd.type_idx_map[0], 1)
self.assertEqual(dd.type_idx_map[1], 0)
self.assertEqual(dd.type_idx_map[2], -1)

def test_init_type_map_error(self) -> None:
with self.assertRaises(ValueError):
DeepmdData(self.data_name, type_map=["foo"])

def tearDown(self) -> None:
shutil.rmtree(self.data_name)


class TestH5Data(unittest.TestCase):
def setUp(self) -> None:
self.data_name = str(tests_path / "test.hdf5")
Expand Down