diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py
index 978fd958fb..f98d52c7bd 100644
--- a/deepmd/tf/fit/dipole.py
+++ b/deepmd/tf/fit/dipole.py
@@ -62,6 +62,9 @@ class DipoleFittingSeA(Fitting):
         The precision of the embedding net parameters. Supported options are |PRECISION|
     uniform_seed
         Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
+    mixed_types : bool
+        If true, use a uniform fitting net for all atom types, otherwise use
+        different fitting nets for different atom types.
     """
 
     def __init__(
@@ -76,6 +79,7 @@ def __init__(
         activation_function: str = "tanh",
         precision: str = "default",
         uniform_seed: bool = False,
+        mixed_types: bool = False,
         **kwargs,
     ) -> None:
         """Constructor."""
@@ -100,6 +104,7 @@ def __init__(
         self.useBN = False
         self.fitting_net_variables = None
         self.mixed_prec = None
+        self.mixed_types = mixed_types
 
     def get_sel_type(self) -> int:
         """Get selected type."""
@@ -109,6 +114,7 @@ def get_out_size(self) -> int:
         """Get the output size. Should be 3."""
         return 3
 
+    @cast_precision
     def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None):
         # cut-out inputs
         inputs_i = tf.slice(inputs, [0, start_index, 0], [-1, natoms, -1])
@@ -172,7 +178,6 @@ def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=No
         final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms, 3])
         return final_layer
 
-    @cast_precision
     def build(
         self,
         input_d: tf.Tensor,
@@ -215,8 +220,12 @@ def build(
         start_index = 0
         inputs = tf.reshape(input_d, [-1, natoms[0], self.dim_descrpt])
         rot_mat = tf.reshape(rot_mat, [-1, natoms[0], self.dim_rot_mat])
+        if nframes is None:
+            nframes = tf.shape(inputs)[0]
 
-        if type_embedding is not None:
+        if self.mixed_types or type_embedding is not None:
+            # keep old behavior
+            self.mixed_types = True
             nloc_mask = tf.reshape(
                 tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
             )
@@ -228,13 +237,30 @@ def build(
             self.nloc_masked = tf.shape(
                 tf.reshape(self.atype_nloc_masked, [nframes, -1])
             )[1]
+
+        if type_embedding is not None:
             atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc_masked)
         else:
             atype_embed = None
 
         self.atype_embed = atype_embed
+        if atype_embed is not None:
+            inputs = tf.reshape(
+                tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
+                [-1, self.dim_descrpt],
+            )
+            rot_mat = tf.reshape(
+                tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[
+                    nloc_mask
+                ],
+                [-1, self.dim_rot_mat_1, 3],
+            )
+            atype_embed = tf.cast(atype_embed, self.fitting_precision)
+            type_shape = atype_embed.get_shape().as_list()
+            inputs = tf.concat([inputs, atype_embed], axis=1)
+            self.dim_descrpt = self.dim_descrpt + type_shape[1]
 
-        if atype_embed is None:
+        if not self.mixed_types:
             count = 0
             outs_list = []
             for type_i in range(self.ntypes):
@@ -255,20 +281,6 @@ def build(
                 count += 1
             outs = tf.concat(outs_list, axis=1)
         else:
-            inputs = tf.reshape(
-                tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
-                [-1, self.dim_descrpt],
-            )
-            rot_mat = tf.reshape(
-                tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[
-                    nloc_mask
-                ],
-                [-1, self.dim_rot_mat_1, 3],
-            )
-            atype_embed = tf.cast(atype_embed, self.fitting_precision)
-            type_shape = atype_embed.get_shape().as_list()
-            inputs = tf.concat([inputs, atype_embed], axis=1)
-            self.dim_descrpt = self.dim_descrpt + type_shape[1]
             inputs = tf.reshape(inputs, [nframes, self.nloc_masked, self.dim_descrpt])
             rot_mat = tf.reshape(
                 rot_mat, [nframes, self.nloc_masked, self.dim_rot_mat_1 * 3]
@@ -354,9 +366,7 @@ def serialize(self, suffix: str) -> dict:
             "ntypes": self.ntypes,
             "dim_descrpt": self.dim_descrpt,
             "embedding_width": self.dim_rot_mat_1,
-            # very bad design: type embedding is not passed to the class
-            # TODO: refactor the class for type embedding and dipole fitting
-            "mixed_types": False,
+            "mixed_types": self.mixed_types,
             "dim_out": 3,
             "neuron": self.n_neuron,
             "resnet_dt": self.resnet_dt,
@@ -365,8 +375,7 @@ def serialize(self, suffix: str) -> dict:
             "exclude_types": [],
             "nets": self.serialize_network(
                 ntypes=self.ntypes,
-                # TODO: consider type embeddings in dipole fitting
-                ndim=1,
+                ndim=0 if self.mixed_types else 1,
                 in_dim=self.dim_descrpt,
                 out_dim=self.dim_rot_mat_1,
                 neuron=self.n_neuron,
diff --git a/deepmd/tf/fit/dos.py b/deepmd/tf/fit/dos.py
index 292db8d5b4..7989752e5a 100644
--- a/deepmd/tf/fit/dos.py
+++ b/deepmd/tf/fit/dos.py
@@ -95,6 +95,9 @@ class DOSFitting(Fitting):
     use_aparam_as_mask: bool, optional
         If True, the atomic parameters will be used as a mask that determines the atom is real/virtual.
         And the aparam will not be used as the atomic parameters for embedding.
+    mixed_types : bool
+        If true, use a uniform fitting net for all atom types, otherwise use
+        different fitting nets for different atom types.
     """
 
     def __init__(
@@ -114,6 +117,7 @@ def __init__(
         uniform_seed: bool = False,
         layer_name: Optional[List[Optional[str]]] = None,
         use_aparam_as_mask: bool = False,
+        mixed_types: bool = False,
         **kwargs,
     ) -> None:
         """Constructor."""
@@ -171,6 +175,7 @@ def __init__(
         assert (
             len(self.layer_name) == len(self.n_neuron) + 1
         ), "length of layer_name should be that of n_neuron + 1"
+        self.mixed_types = mixed_types
 
     def get_numb_fparam(self) -> int:
         """Get the number of frame parameters."""
@@ -504,13 +509,22 @@ def build(
             tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1]
         )  ## lammps will make error
         if type_embedding is not None:
+            # keep old behavior
+            self.mixed_types = True
             atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
         else:
             atype_embed = None
 
         self.atype_embed = atype_embed
+        if atype_embed is not None:
+            atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
+            type_shape = atype_embed.get_shape().as_list()
+            inputs = tf.concat(
+                [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
+            )
+            self.dim_descrpt = self.dim_descrpt + type_shape[1]
 
-        if atype_embed is None:
+        if not self.mixed_types:
             start_index = 0
             outs_list = []
             for type_i in range(self.ntypes):
@@ -541,13 +555,6 @@ def build(
             outs = tf.concat(outs_list, axis=1)
         # with type embedding
         else:
-            atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
-            type_shape = atype_embed.get_shape().as_list()
-            inputs = tf.concat(
-                [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
-            )
-            original_dim_descrpt = self.dim_descrpt
-            self.dim_descrpt = self.dim_descrpt + type_shape[1]
             inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
             final_layer = self._build_lower(
                 0,
@@ -700,9 +707,7 @@ def serialize(self, suffix: str = "") -> dict:
             "var_name": "dos",
             "ntypes": self.ntypes,
             "dim_descrpt": self.dim_descrpt,
-            # very bad design: type embedding is not passed to the class
-            # TODO: refactor the class for DOSFitting and type embedding
-            "mixed_types": False,
+            "mixed_types": self.mixed_types,
             "dim_out": self.numb_dos,
             "neuron": self.n_neuron,
             "resnet_dt": self.resnet_dt,
@@ -715,8 +720,7 @@ def serialize(self, suffix: str = "") -> dict:
             "exclude_types": [],
             "nets": self.serialize_network(
                 ntypes=self.ntypes,
-                # TODO: consider type embeddings for DOSFitting
-                ndim=1,
+                ndim=0 if self.mixed_types else 1,
                 in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
                 out_dim=self.numb_dos,
                 neuron=self.n_neuron,
diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py
index b391b00052..7cb5fc62cf 100644
--- a/deepmd/tf/fit/ener.py
+++ b/deepmd/tf/fit/ener.py
@@ -141,6 +141,9 @@ class EnerFitting(Fitting):
     use_aparam_as_mask: bool, optional
         If True, the atomic parameters will be used as a mask that determines the atom is real/virtual.
         And the aparam will not be used as the atomic parameters for embedding.
+    mixed_types : bool
+        If true, use a uniform fitting net for all atom types, otherwise use
+        different fitting nets for different atom types.
     """
 
     def __init__(
@@ -162,6 +165,7 @@ def __init__(
         layer_name: Optional[List[Optional[str]]] = None,
         use_aparam_as_mask: bool = False,
         spin: Optional[Spin] = None,
+        mixed_types: bool = False,
         **kwargs,
     ) -> None:
         """Constructor."""
@@ -238,6 +242,7 @@ def __init__(
         assert (
             len(self.layer_name) == len(self.n_neuron) + 1
         ), "length of layer_name should be that of n_neuron + 1"
+        self.mixed_types = mixed_types
 
     def get_numb_fparam(self) -> int:
         """Get the number of frame parameters."""
@@ -585,6 +590,8 @@ def build(
                 )
             else:
                 inputs_zero = tf.zeros_like(inputs, dtype=GLOBAL_TF_FLOAT_PRECISION)
+        else:
+            inputs_zero = None
 
         if bias_atom_e is not None:
             assert len(bias_atom_e) == self.ntypes
@@ -628,13 +635,29 @@ def build(
         ):
             type_embedding = nvnmd_cfg.map["t_ebd"]
         if type_embedding is not None:
+            # keep old behavior
+            self.mixed_types = True
             atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
         else:
             atype_embed = None
 
         self.atype_embed = atype_embed
+        original_dim_descrpt = self.dim_descrpt
+        if atype_embed is not None:
+            atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
+            type_shape = atype_embed.get_shape().as_list()
+            inputs = tf.concat(
+                [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
+            )
+            self.dim_descrpt = self.dim_descrpt + type_shape[1]
+            if len(self.atom_ener):
+                assert inputs_zero is not None
+                inputs_zero = tf.concat(
+                    [tf.reshape(inputs_zero, [-1, original_dim_descrpt]), atype_embed],
+                    axis=1,
+                )
 
-        if atype_embed is None:
+        if not self.mixed_types:
             start_index = 0
             outs_list = []
             for type_i in range(ntypes_atom):
@@ -673,13 +696,6 @@ def build(
             outs = tf.concat(outs_list, axis=1)
         # with type embedding
         else:
-            atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
-            type_shape = atype_embed.get_shape().as_list()
-            inputs = tf.concat(
-                [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
-            )
-            original_dim_descrpt = self.dim_descrpt
-            self.dim_descrpt = self.dim_descrpt + type_shape[1]
             inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
             final_layer = self._build_lower(
                 0,
@@ -693,10 +709,6 @@ def build(
             )
             if len(self.atom_ener):
                 # remove contribution in vacuum
-                inputs_zero = tf.concat(
-                    [tf.reshape(inputs_zero, [-1, original_dim_descrpt]), atype_embed],
-                    axis=1,
-                )
                 inputs_zero = tf.reshape(inputs_zero, [-1, natoms[0], self.dim_descrpt])
                 zero_layer = self._build_lower(
                     0,
@@ -892,9 +904,7 @@ def serialize(self, suffix: str = "") -> dict:
             "var_name": "energy",
             "ntypes": self.ntypes,
             "dim_descrpt": self.dim_descrpt,
-            # very bad design: type embedding is not passed to the class
-            # TODO: refactor the class for energy fitting and type embedding
-            "mixed_types": False,
+            "mixed_types": self.mixed_types,
             "dim_out": 1,
             "neuron": self.n_neuron,
             "resnet_dt": self.resnet_dt,
@@ -912,8 +922,7 @@ def serialize(self, suffix: str = "") -> dict:
             "exclude_types": [],
             "nets": self.serialize_network(
                 ntypes=self.ntypes,
-                # TODO: consider type embeddings for type embedding
-                ndim=1,
+                ndim=0 if self.mixed_types else 1,
                 in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
                 neuron=self.n_neuron,
                 activation_function=self.activation_function_name,
diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py
index 21b9587b88..54221ce315 100644
--- a/deepmd/tf/fit/polar.py
+++ b/deepmd/tf/fit/polar.py
@@ -16,6 +16,7 @@
     DescrptSeA,
 )
 from deepmd.tf.env import (
+    GLOBAL_TF_FLOAT_PRECISION,
     tf,
 )
 from deepmd.tf.fit.fitting import (
@@ -72,6 +73,9 @@ class PolarFittingSeA(Fitting):
         The precision of the embedding net parameters. Supported options are |PRECISION|
     uniform_seed
         Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
+    mixed_types : bool
+        If true, use a uniform fitting net for all atom types, otherwise use
+        different fitting nets for different atom types.
     """
 
     def __init__(
@@ -90,6 +94,7 @@ def __init__(
         activation_function: str = "tanh",
         precision: str = "default",
         uniform_seed: bool = False,
+        mixed_types: bool = False,
         **kwargs,
     ) -> None:
         """Constructor."""
@@ -142,6 +147,7 @@ def __init__(
         self.useBN = False
         self.fitting_net_variables = None
         self.mixed_prec = None
+        self.mixed_types = mixed_types
 
     def get_sel_type(self) -> List[int]:
         """Get selected atom types."""
@@ -242,6 +248,7 @@ def compute_output_stats(self, all_stat):
                 np.diagonal(atom_polar[itype].reshape((3, 3)))
             )
 
+    @cast_precision
    def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None):
         # cut-out inputs
         inputs_i = tf.slice(
@@ -350,7 +357,6 @@ def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=No
         final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms, 3, 3])
         return final_layer
 
-    @cast_precision
     def build(
         self,
         input_d: tf.Tensor,
@@ -393,8 +399,12 @@ def build(
         start_index = 0
         inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
         rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])
+        if nframes is None:
+            nframes = tf.shape(inputs)[0]
-        if type_embedding is not None:
+        if self.mixed_types or type_embedding is not None:
+            # keep old behavior
+            self.mixed_types = True
             # nframes x nloc
             nloc_mask = tf.reshape(
                 tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
             )
@@ -423,13 +433,28 @@ def build(
             self.nloc_masked = tf.shape(
                 tf.reshape(self.atype_nloc_masked, [nframes, -1])
             )[1]
+
+        if type_embedding is not None:
             atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc_masked)
         else:
             atype_embed = None
 
         self.atype_embed = atype_embed
+        if atype_embed is not None:
+            inputs = tf.reshape(
+                tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
+                [-1, self.dim_descrpt],
+            )
+            rot_mat = tf.reshape(
+                tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat])[nloc_mask],
+                [-1, self.dim_rot_mat * self.nloc_masked],
+            )
+            atype_embed = tf.cast(atype_embed, self.fitting_precision)
+            type_shape = atype_embed.get_shape().as_list()
+            inputs = tf.concat([inputs, atype_embed], axis=1)
+            self.dim_descrpt = self.dim_descrpt + type_shape[1]
 
-        if atype_embed is None:
+        if not self.mixed_types:
             count = 0
             outs_list = []
             for type_i in range(self.ntypes):
@@ -450,7 +475,7 @@ def build(
                 final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(
                     3,
                     batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]],
-                    dtype=self.fitting_precision,
+                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                 )
                 start_index += natoms[2 + type_i]
 
@@ -459,18 +484,6 @@ def build(
                 count += 1
             outs = tf.concat(outs_list, axis=1)
         else:
-            inputs = tf.reshape(
-                tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
-                [-1, self.dim_descrpt],
-            )
-            rot_mat = tf.reshape(
-                tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat])[nloc_mask],
-                [-1, self.dim_rot_mat * self.nloc_masked],
-            )
-            atype_embed = tf.cast(atype_embed, self.fitting_precision)
-            type_shape = atype_embed.get_shape().as_list()
-            inputs = tf.concat([inputs, atype_embed], axis=1)
-            self.dim_descrpt = self.dim_descrpt + type_shape[1]
             inputs = tf.reshape(inputs, [-1, self.dim_descrpt * self.nloc_masked])
             final_layer = self._build_lower(
                 0, self.nloc_masked, inputs, rot_mat, suffix=suffix, reuse=reuse
@@ -480,7 +493,7 @@ def build(
             if self.shift_diag:
                 final_layer += tf.expand_dims(
                     tf.expand_dims(constant_matrix, -1), -1
-                ) * tf.eye(3, batch_shape=[1, 1], dtype=self.fitting_precision)
+                ) * tf.eye(3, batch_shape=[1, 1], dtype=GLOBAL_TF_FLOAT_PRECISION)
             outs = final_layer
 
         tf.summary.histogram("fitting_net_output", outs)
@@ -544,9 +557,7 @@ def serialize(self, suffix: str) -> dict:
             "ntypes": self.ntypes,
             "dim_descrpt": self.dim_descrpt,
             "embedding_width": self.dim_rot_mat_1,
-            # very bad design: type embedding is not passed to the class
-            # TODO: refactor the class for polar fitting and type embedding
-            "mixed_types": False,
+            "mixed_types": self.mixed_types,
             "dim_out": 3,
             "neuron": self.n_neuron,
             "resnet_dt": self.resnet_dt,
@@ -558,8 +569,7 @@ def serialize(self, suffix: str) -> dict:
             "shift_diag": self.shift_diag,
             "nets": self.serialize_network(
                 ntypes=self.ntypes,
-                # TODO: consider type embeddings for polar fitting
-                ndim=1,
+                ndim=0 if self.mixed_types else 1,
                 in_dim=self.dim_descrpt,
                 out_dim=self.dim_rot_mat_1,
                 neuron=self.n_neuron,
diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py
index a0e234a547..83b2a24528 100644
--- a/deepmd/tf/model/model.py
+++ b/deepmd/tf/model/model.py
@@ -668,6 +668,7 @@ def __init__(
             spin=self.spin,
             ntypes=self.descrpt.get_ntypes(),
             dim_descrpt=self.descrpt.get_dim_out(),
+            mixed_types=type_embedding is not None or self.descrpt.explicit_ntypes,
         )
         self.rcut = self.descrpt.get_rcut()
         self.ntypes = self.descrpt.get_ntypes()
diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py
index 7b5d4d59e8..18a29934ca 100644
--- a/source/tests/consistent/fitting/test_dipole.py
+++ b/source/tests/consistent/fitting/test_dipole.py
@@ -58,16 +58,6 @@ def data(self) -> dict:
             "seed": 20240217,
         }
 
-    @property
-    def skip_tf(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-        ) = self.param
-        # TODO: mixed_types
-        return mixed_types or CommonTest.skip_pt
-
     @property
     def skip_pt(self) -> bool:
         (
diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py
index 2832d67641..bfdf76c8ff 100644
--- a/source/tests/consistent/fitting/test_dos.py
+++ b/source/tests/consistent/fitting/test_dos.py
@@ -64,18 +64,6 @@ def data(self) -> dict:
             "numb_dos": numb_dos,
         }
 
-    @property
-    def skip_tf(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-            numb_fparam,
-            numb_dos,
-        ) = self.param
-        # TODO: mixed_types
-        return mixed_types or CommonTest.skip_pt
-
     @property
     def skip_pt(self) -> bool:
         (
diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py
index a22bcdb65f..ab314cb9af 100644
--- a/source/tests/consistent/fitting/test_ener.py
+++ b/source/tests/consistent/fitting/test_ener.py
@@ -64,18 +64,6 @@ def data(self) -> dict:
             "atom_ener": atom_ener,
         }
 
-    @property
-    def skip_tf(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-            numb_fparam,
-            atom_ener,
-        ) = self.param
-        # TODO: mixed_types
-        return mixed_types or CommonTest.skip_pt
-
     @property
     def skip_pt(self) -> bool:
         (
diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py
index 7bc11961eb..5b55c6d333 100644
--- a/source/tests/consistent/fitting/test_polar.py
+++ b/source/tests/consistent/fitting/test_polar.py
@@ -58,16 +58,6 @@ def data(self) -> dict:
             "seed": 20240217,
         }
 
-    @property
-    def skip_tf(self) -> bool:
-        (
-            resnet_dt,
-            precision,
-            mixed_types,
-        ) = self.param
-        # TODO: mixed_types
-        return mixed_types or CommonTest.skip_pt
-
     @property
     def skip_pt(self) -> bool:
         (
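
The same dispatch recurs in all four fitting classes touched above: with `mixed_types` off, `build()` slices the descriptor per atom type and sends each slice through its own per-type network (serialized as `ndim=1` nets); with it on, the type embedding (when present) is concatenated onto every atom's descriptor and one uniform network handles all atoms (`ndim=0`). A minimal NumPy sketch of that control flow, using hypothetical stand-ins (`fit_atoms`, `per_type_nets`, `shared_net`) rather than the real deepmd-kit API:

```python
import numpy as np


def fit_atoms(descrpt, atype, type_embedding, per_type_nets, shared_net, mixed_types):
    """Dispatch between per-type and shared fitting nets (sketch only).

    descrpt: (nloc, dim_descrpt) atomic descriptors; atype: (nloc,) type indices;
    type_embedding: (ntypes, dim_embd) array or None.
    """
    if not mixed_types:
        # one fitting net per atom type -- serialized with ndim=1
        out = np.empty(len(atype))
        for type_i, net in enumerate(per_type_nets):
            mask = atype == type_i
            out[mask] = net(descrpt[mask])
        return out
    if type_embedding is not None:
        # concatenate each atom's type embedding onto its descriptor,
        # mirroring tf.concat([inputs, atype_embed], axis=1) in the diff
        descrpt = np.concatenate([descrpt, type_embedding[atype]], axis=1)
    # one uniform fitting net for all atom types -- serialized with ndim=0
    return shared_net(descrpt)


# toy usage: two atom types, trivial "networks"
rng = np.random.default_rng(20240217)
descrpt = rng.normal(size=(5, 4))
atype = np.array([0, 1, 0, 1, 1])
per_type_nets = [lambda d: d.sum(axis=1), lambda d: d.mean(axis=1)]
print(fit_atoms(descrpt, atype, None, per_type_nets, None, mixed_types=False))
t_ebd = rng.normal(size=(2, 3))
print(fit_atoms(descrpt, atype, t_ebd, None, lambda d: d.sum(axis=1), mixed_types=True))
```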
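Note how the flag is wired rather than user-facing: model.py infers it as `type_embedding is not None or self.descrpt.explicit_ntypes`, and the `# keep old behavior` branches force `self.mixed_types = True` whenever a `type_embedding` tensor reaches `build()`, so graphs built from older inputs keep working. A hedged construction sketch (assumes a working TF build of deepmd-kit; the keyword names match the constructor signature in the diff, the values are made up):

```python
from deepmd.tf.fit.ener import EnerFitting

# classic behavior: one fitting net per atom type (serializes with ndim=1)
fitting = EnerFitting(ntypes=2, dim_descrpt=64, neuron=[240, 240, 240])

# uniform fitting net across types (serializes with ndim=0); build() also
# forces this mode whenever a type_embedding tensor is passed in
fitting_mixed = EnerFitting(
    ntypes=2, dim_descrpt=64, neuron=[240, 240, 240], mixed_types=True
)
```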