diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 48aacacdee..a9e57d6270 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -320,14 +320,16 @@ def _compute_weight( ), axis=-1, ) # handle masked nnei. - sigma = numerator / denominator + with np.errstate(divide="ignore", invalid="ignore"): + sigma = numerator / denominator u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin) coef = np.zeros_like(u) left_mask = sigma < self.sw_rmin mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax) right_mask = sigma >= self.sw_rmax coef[left_mask] = 1 - smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 + with np.errstate(invalid="ignore"): + smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 coef[mid_mask] = smooth[mid_mask] coef[right_mask] = 0 self.zbl_weight = coef diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 0b21033d31..180e5a5211 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -9,7 +9,7 @@ DenoiseNet, ) from .dipole import ( - DipoleFittingNetType, + DipoleFittingNet, ) from .ener import ( EnergyFittingNet, @@ -25,7 +25,7 @@ __all__ = [ "FittingNetAttenLcc", "DenoiseNet", - "DipoleFittingNetType", + "DipoleFittingNet", "EnergyFittingNet", "EnergyFittingNetDirect", "Fitting", diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index d911613a5b..aa518d2cd3 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -13,7 +13,7 @@ log = logging.getLogger(__name__) -class DipoleFittingNetType(Fitting): +class DipoleFittingNet(Fitting): def __init__( self, ntypes, embedding_width, neuron, out_dim, resnet_dt=True, **kwargs ): diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 1a883a50a2..f1dad4c58d 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ 
-1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import copy import logging from typing import ( List, @@ -15,30 +14,22 @@ OutputVariableDef, fitting_check_output, ) -from deepmd.pt.model.network.mlp import ( - FittingNet, - NetworkCollection, -) from deepmd.pt.model.network.network import ( ResidualDeep, ) from deepmd.pt.model.task.fitting import ( Fitting, + GeneralFitting, ) from deepmd.pt.utils import ( env, ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, - PRECISION_DICT, ) from deepmd.pt.utils.stat import ( compute_output_bias, ) -from deepmd.pt.utils.utils import ( - to_numpy_array, - to_torch_tensor, -) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -47,7 +38,41 @@ @fitting_check_output -class InvarFitting(Fitting): +class InvarFitting(GeneralFitting): + """Construct a fitting net for energy. + + Parameters + ---------- + var_name : str + The atomic property to fit, 'energy', 'dipole', and 'polar'. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out : int + The output dimension of the fitting net. + neuron : List[int] + Number of neurons in each hidden layer of the fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + activation_function : str + Activation function. + precision : str + Numerical precision. + distinguish_types : bool + Neighbor list that distinguishes different atomic types or not. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + """ + def __init__( self, var_name: str, @@ -63,118 +88,31 @@ def __init__( precision: str = DEFAULT_PRECISION, distinguish_types: bool = False, rcond: Optional[float] = None, + seed: Optional[int] = None, **kwargs, ): - """Construct a fitting net for energy. 
- - Args: - - ntypes: Element count. - - embedding_width: Embedding width per atom. - - neuron: Number of neurons in each hidden layers of the fitting net. - - bias_atom_e: Average enery per atom for each element. - - resnet_dt: Using time-step in the ResNet construction. - """ - super().__init__() - self.var_name = var_name - self.ntypes = ntypes - self.dim_descrpt = dim_descrpt - self.dim_out = dim_out - self.neuron = neuron - self.distinguish_types = distinguish_types - self.use_tebd = not self.distinguish_types - self.resnet_dt = resnet_dt - self.numb_fparam = numb_fparam - self.numb_aparam = numb_aparam - self.activation_function = activation_function - self.precision = precision - self.prec = PRECISION_DICT[self.precision] - self.rcond = rcond - if bias_atom_e is None: - bias_atom_e = np.zeros([self.ntypes, self.dim_out]) - bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device) - bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out]) - if not self.use_tebd: - assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" 
- self.register_buffer("bias_atom_e", bias_atom_e) - # init constants - if self.numb_fparam > 0: - self.register_buffer( - "fparam_avg", - torch.zeros(self.numb_fparam, dtype=self.prec, device=device), - ) - self.register_buffer( - "fparam_inv_std", - torch.ones(self.numb_fparam, dtype=self.prec, device=device), - ) - else: - self.fparam_avg, self.fparam_inv_std = None, None - if self.numb_aparam > 0: - self.register_buffer( - "aparam_avg", - torch.zeros(self.numb_aparam, dtype=self.prec, device=device), - ) - self.register_buffer( - "aparam_inv_std", - torch.ones(self.numb_aparam, dtype=self.prec, device=device), - ) - else: - self.aparam_avg, self.aparam_inv_std = None, None - - in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam - - self.old_impl = kwargs.get("old_impl", False) - if self.old_impl: - filter_layers = [] - for type_i in range(self.ntypes): - bias_type = 0.0 - one = ResidualDeep( - type_i, - self.dim_descrpt, - self.neuron, - bias_type, - resnet_dt=self.resnet_dt, - ) - filter_layers.append(one) - self.filter_layers_old = torch.nn.ModuleList(filter_layers) - self.filter_layers = None - else: - self.filter_layers = NetworkCollection( - 1 if self.distinguish_types else 0, - self.ntypes, - network_type="fitting_network", - networks=[ - FittingNet( - in_dim, - self.dim_out, - self.neuron, - self.activation_function, - self.resnet_dt, - self.precision, - bias_out=True, - ) - for ii in range(self.ntypes if self.distinguish_types else 1) - ], - ) - self.filter_layers_old = None - - # very bad design... 
- if "seed" in kwargs: - log.info("Set seed to %d in fitting net.", kwargs["seed"]) - torch.manual_seed(kwargs["seed"]) - - def output_def(self) -> FittingOutputDef: - return FittingOutputDef( - [ - OutputVariableDef( - self.var_name, - [self.dim_out], - reduciable=True, - r_differentiable=True, - c_differentiable=True, - ), - ] + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out=dim_out, + neuron=neuron, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + activation_function=activation_function, + precision=precision, + distinguish_types=distinguish_types, + rcond=rcond, + seed=seed, + **kwargs, ) + def _net_out_dim(self): + """Set the FittingNet output dim.""" + return self.dim_out + def __setitem__(self, key, value): if key in ["bias_atom_e"]: value = value.view([self.ntypes, self.dim_out]) @@ -230,62 +168,6 @@ def init_fitting_stat(self, bias_atom_e=None, **kwargs): ) ) - def serialize(self) -> dict: - """Serialize the fitting to dict.""" - return { - "var_name": self.var_name, - "ntypes": self.ntypes, - "dim_descrpt": self.dim_descrpt, - "dim_out": self.dim_out, - "neuron": self.neuron, - "resnet_dt": self.resnet_dt, - "numb_fparam": self.numb_fparam, - "numb_aparam": self.numb_aparam, - "activation_function": self.activation_function, - "precision": self.precision, - "distinguish_types": self.distinguish_types, - "nets": self.filter_layers.serialize(), - "rcond": self.rcond, - "@variables": { - "bias_atom_e": to_numpy_array(self.bias_atom_e), - "fparam_avg": to_numpy_array(self.fparam_avg), - "fparam_inv_std": to_numpy_array(self.fparam_inv_std), - "aparam_avg": to_numpy_array(self.aparam_avg), - "aparam_inv_std": to_numpy_array(self.aparam_inv_std), - }, - # "rcond": self.rcond , - # "tot_ener_zero": self.tot_ener_zero , - # "trainable": self.trainable , - # "atom_ener": self.atom_ener , - # "layer_name": self.layer_name , - # "use_aparam_as_mask": 
self.use_aparam_as_mask , - # "spin": self.spin , - ## NOTICE: not supported by far - "tot_ener_zero": False, - "trainable": True, - "atom_ener": None, - "layer_name": None, - "use_aparam_as_mask": False, - "spin": None, - } - - @classmethod - def deserialize(cls, data: dict) -> "InvarFitting": - data = copy.deepcopy(data) - variables = data.pop("@variables") - nets = data.pop("nets") - obj = cls(**data) - for kk in variables.keys(): - obj[kk] = to_torch_tensor(variables[kk]) - obj.filter_layers = NetworkCollection.deserialize(nets) - return obj - - def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: - return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) - - def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: - return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) - def forward( self, descriptor: torch.Tensor, @@ -306,90 +188,7 @@ def forward( ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ - xx = descriptor - nf, nloc, nd = xx.shape - # NOTICE in tests/pt/test_model.py - # it happens that the user directly access the data memeber self.bias_atom_e - # and set it to a wrong shape! - self.bias_atom_e = self.bias_atom_e.view([self.ntypes, self.dim_out]) - # check input dim - if nd != self.dim_descrpt: - raise ValueError( - "get an input descriptor of dim {nd}," - "which is not consistent with {self.dim_descrpt}." 
- ) - # check fparam dim, concate to input descriptor - if self.numb_fparam > 0: - assert fparam is not None, "fparam should not be None" - assert self.fparam_avg is not None - assert self.fparam_inv_std is not None - if fparam.shape[-1] != self.numb_fparam: - raise ValueError( - "get an input fparam of dim {fparam.shape[-1]}, ", - "which is not consistent with {self.numb_fparam}.", - ) - fparam = fparam.view([nf, self.numb_fparam]) - nb, _ = fparam.shape - t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) - t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) - fparam = (fparam - t_fparam_avg) * t_fparam_inv_std - fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) - xx = torch.cat( - [xx, fparam], - dim=-1, - ) - # check aparam dim, concate to input descriptor - if self.numb_aparam > 0: - assert aparam is not None, "aparam should not be None" - assert self.aparam_avg is not None - assert self.aparam_inv_std is not None - if aparam.shape[-1] != self.numb_aparam: - raise ValueError( - "get an input aparam of dim {aparam.shape[-1]}, ", - "which is not consistent with {self.numb_aparam}.", - ) - aparam = aparam.view([nf, nloc, self.numb_aparam]) - nb, nloc, _ = aparam.shape - t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) - t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) - aparam = (aparam - t_aparam_avg) * t_aparam_inv_std - xx = torch.cat( - [xx, aparam], - dim=-1, - ) - - outs = torch.zeros(nf, nloc, self.dim_out) # jit assertion - if self.old_impl: - outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion - assert self.filter_layers_old is not None - if self.use_tebd: - atom_energy = self.filter_layers_old[0](xx) + self.bias_atom_e[ - atype - ].unsqueeze(-1) - outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] - else: - for type_i, filter_layer in enumerate(self.filter_layers_old): - mask = atype == type_i - atom_energy = filter_layer(xx) - atom_energy = atom_energy + 
self.bias_atom_e[type_i] - atom_energy = atom_energy * mask.unsqueeze(-1) - outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] - return {"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} - else: - if self.use_tebd: - atom_energy = ( - self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] - ) - outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] - else: - for type_i, ll in enumerate(self.filter_layers.networks): - mask = (atype == type_i).unsqueeze(-1) - mask = torch.tile(mask, (1, 1, self.dim_out)) - atom_energy = ll(xx) - atom_energy = atom_energy + self.bias_atom_e[type_i] - atom_energy = atom_energy * mask - outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) @Fitting.register("ener") diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 360f545975..b2d8c875ce 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging +from abc import ( + abstractmethod, +) from typing import ( Callable, List, @@ -10,14 +14,30 @@ import numpy as np import torch +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.network.mlp import ( + FittingNet, + NetworkCollection, +) +from deepmd.pt.model.network.network import ( + ResidualDeep, +) from deepmd.pt.model.task.base_fitting import ( BaseFitting, ) +from deepmd.pt.utils import ( + env, +) from deepmd.pt.utils.dataloader import ( DpLoaderSet, ) from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, DEVICE, + PRECISION_DICT, ) from deepmd.pt.utils.plugin import ( Plugin, @@ -25,6 +45,13 @@ from deepmd.pt.utils.stat import ( make_stat_input, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = 
env.DEVICE log = logging.getLogger(__name__) @@ -316,3 +343,334 @@ def change_energy_bias( ) ) return None + + +class GeneralFitting(Fitting): + """Construct a general fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit, 'energy', 'dipole', and 'polar'. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out : int + The output dimension of the fitting net. + neuron : List[int] + Number of neurons in each hidden layer of the fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + activation_function : str + Activation function. + precision : str + Numerical precision. + distinguish_types : bool + Neighbor list that distinguishes different atomic types or not. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. 
+ """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out: int, + neuron: List[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + distinguish_types: bool = False, + rcond: Optional[float] = None, + seed: Optional[int] = None, + **kwargs, + ): + super().__init__() + self.var_name = var_name + self.ntypes = ntypes + self.dim_descrpt = dim_descrpt + self.dim_out = dim_out + self.neuron = neuron + self.distinguish_types = distinguish_types + self.use_tebd = not self.distinguish_types + self.resnet_dt = resnet_dt + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.rcond = rcond + + # init constants + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes, self.dim_out]) + bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device) + bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out]) + if not self.use_tebd: + assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" 
+ self.register_buffer("bias_atom_e", bias_atom_e) + + if self.numb_fparam > 0: + self.register_buffer( + "fparam_avg", + torch.zeros(self.numb_fparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "fparam_inv_std", + torch.ones(self.numb_fparam, dtype=self.prec, device=device), + ) + else: + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.register_buffer( + "aparam_avg", + torch.zeros(self.numb_aparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "aparam_inv_std", + torch.ones(self.numb_aparam, dtype=self.prec, device=device), + ) + else: + self.aparam_avg, self.aparam_inv_std = None, None + + in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam + + self.old_impl = kwargs.get("old_impl", False) + net_dim_out = self._net_out_dim() + if self.old_impl: + filter_layers = [] + for type_i in range(self.ntypes): + bias_type = 0.0 + one = ResidualDeep( + type_i, + self.dim_descrpt, + self.neuron, + bias_type, + resnet_dt=self.resnet_dt, + ) + filter_layers.append(one) + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + self.filter_layers = None + else: + self.filter_layers = NetworkCollection( + 1 if self.distinguish_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + net_dim_out, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + ) + for ii in range(self.ntypes if self.distinguish_types else 1) + ], + ) + self.filter_layers_old = None + + if seed is not None: + log.info("Set seed to %d in fitting net.", seed) + torch.manual_seed(seed) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "var_name": self.var_name, + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "dim_out": self.dim_out, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "activation_function": 
self.activation_function, + "precision": self.precision, + "distinguish_types": self.distinguish_types, + "nets": self.filter_layers.serialize(), + "rcond": self.rcond, + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), + }, + # "rcond": self.rcond , + # "tot_ener_zero": self.tot_ener_zero , + # "trainable": self.trainable , + # "atom_ener": self.atom_ener , + # "layer_name": self.layer_name , + # "use_aparam_as_mask": self.use_aparam_as_mask , + # "spin": self.spin , + ## NOTICE: not supported by far + "tot_ener_zero": False, + "trainable": True, + "atom_ener": None, + "layer_name": None, + "use_aparam_as_mask": False, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + variables = data.pop("@variables") + nets = data.pop("nets") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = to_torch_tensor(variables[kk]) + obj.filter_layers = NetworkCollection.deserialize(nets) + return obj + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.numb_fparam + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.numb_aparam + + def get_sel_type(self) -> List[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. 
+ """ + return [] + + @abstractmethod + def _net_out_dim(self): + """Set the FittingNet output dim.""" + pass + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, + [self.dim_out], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: + return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) + + def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: + return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) + + def _forward_common( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + xx = descriptor + nf, nloc, nd = xx.shape + + if nd != self.dim_descrpt: + raise ValueError( + "get an input descriptor of dim {nd}," + "which is not consistent with {self.dim_descrpt}." 
+ ) + # check fparam dim, concate to input descriptor + if self.numb_fparam > 0: + assert fparam is not None, "fparam should not be None" + assert self.fparam_avg is not None + assert self.fparam_inv_std is not None + if fparam.shape[-1] != self.numb_fparam: + raise ValueError( + "get an input fparam of dim {fparam.shape[-1]}, ", + "which is not consistent with {self.numb_fparam}.", + ) + fparam = fparam.view([nf, self.numb_fparam]) + nb, _ = fparam.shape + t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) + t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) + fparam = (fparam - t_fparam_avg) * t_fparam_inv_std + fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) + xx = torch.cat( + [xx, fparam], + dim=-1, + ) + # check aparam dim, concate to input descriptor + if self.numb_aparam > 0: + assert aparam is not None, "aparam should not be None" + assert self.aparam_avg is not None + assert self.aparam_inv_std is not None + if aparam.shape[-1] != self.numb_aparam: + raise ValueError( + "get an input aparam of dim {aparam.shape[-1]}, ", + "which is not consistent with {self.numb_aparam}.", + ) + aparam = aparam.view([nf, nloc, self.numb_aparam]) + nb, nloc, _ = aparam.shape + t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) + t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) + aparam = (aparam - t_aparam_avg) * t_aparam_inv_std + xx = torch.cat( + [xx, aparam], + dim=-1, + ) + + outs = torch.zeros( + (nf, nloc, self.dim_out), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) # jit assertion + if self.old_impl: + outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion + assert self.filter_layers_old is not None + if self.use_tebd: + atom_property = self.filter_layers_old[0](xx) + self.bias_atom_e[ + atype + ].unsqueeze(-1) + outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + else: + for type_i, filter_layer in enumerate(self.filter_layers_old): + mask = atype == 
type_i + atom_property = filter_layer(xx) + atom_property = atom_property + self.bias_atom_e[type_i] + atom_property = atom_property * mask.unsqueeze(-1) + outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + else: + if self.use_tebd: + atom_property = ( + self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] + ) + outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + else: + net_dim_out = self._net_out_dim() + for type_i, ll in enumerate(self.filter_layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, net_dim_out)) + atom_property = ll(xx) + atom_property = atom_property + self.bias_atom_e[type_i] + atom_property = atom_property * mask + outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/source/tests/pt/model/test_make_hessian_model.py b/source/tests/pt/model/test_make_hessian_model.py index 81aee758bf..6f321b6478 100644 --- a/source/tests/pt/model/test_make_hessian_model.py +++ b/source/tests/pt/model/test_make_hessian_model.py @@ -68,24 +68,28 @@ def test( natoms = self.nloc nf = self.nf nv = self.nv - cell0 = torch.rand([3, 3], dtype=dtype) - cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * torch.eye(3) - cell1 = torch.rand([3, 3], dtype=dtype) - cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * torch.eye(3) + cell0 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) + cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * torch.eye(3, device=env.DEVICE) + cell1 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) + cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * torch.eye(3, device=env.DEVICE) cell = torch.stack([cell0, cell1]) - coord = torch.rand([nf, natoms, 3], dtype=dtype) + coord = torch.rand([nf, natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) cell = cell.view([nf, 9]) coord = coord.view([nf, natoms * 3]) - atype = torch.stack( - [ - 
torch.IntTensor([0, 0, 1]), - torch.IntTensor([1, 0, 1]), - ] - ).view([nf, natoms]) + atype = ( + torch.stack( + [ + torch.IntTensor([0, 0, 1]), + torch.IntTensor([1, 0, 1]), + ] + ) + .view([nf, natoms]) + .to(env.DEVICE) + ) nfp, nap = 2, 3 - fparam = torch.rand([nf, nfp], dtype=dtype) - aparam = torch.rand([nf, natoms * nap], dtype=dtype) + fparam = torch.rand([nf, nfp], dtype=dtype, device=env.DEVICE) + aparam = torch.rand([nf, natoms * nap], dtype=dtype, device=env.DEVICE) # forward hess and valu models ret_dict0 = self.model_hess.forward_common( coord, atype, box=cell, fparam=fparam, aparam=aparam