Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
PRECISION_DICT,
NativeOP,
)
from .descriptor import (
DescrptSeA,
)
from .model import (
DPAtomicModel,
DPModel,
Expand All @@ -17,13 +20,26 @@
get_reduce_name,
model_check_output,
)
from .utils import (
EmbeddingNet,
EnvMat,
FittingNet,
NativeLayer,
NativeNet,
)

__all__ = [
"DPModel",
"DPAtomicModel",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"NativeOP",
"EnvMat",
"NativeLayer",
"NativeNet",
"EmbeddingNet",
"FittingNet",
"DescrptSeA",
Comment on lines +37 to +42
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason of providing all the classes here?

"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
assert not self.multi_task, "multitask mode currently not supported!"
self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"]
self.type_map = self.input_param["type_map"]
self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE))
self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE))
self.dp.load_state_dict(state_dict)
self.rcut = self.dp.model["Default"].descriptor.get_rcut()
self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel())
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,19 @@ class SomeDescript(Descriptor):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)

@classmethod
def get_data_stat_key(cls, config):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a good idea to pass a dict at interface. clearly write what does the method need.

"""Get the keys for the data statistic of the descriptor."""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one would not understand the method from such a doc str.

descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one would not understand the method from such a doc str.

descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)
Comment on lines 58 to 74
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a subclass doesn't implement these subclasses, the program will stuck!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding the error message for this case:

Suggested change
@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)
@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)
@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
if cls is not Descriptor:
raise NotImplementedError("get_stat_name is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
if cls is not Descriptor:
raise NotImplementedError("get_data_stat_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)
@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
if cls is not Descriptor:
raise NotImplementedError("get_data_process_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)


Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,19 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return f'stat_file_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return {"sel": config["sel"], "rcut": config["rcut"]}
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,22 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return (
f'stat_file_dpa2_repinit_rcut{config["repinit_rcut"]:.2f}_smth{config["repinit_rcut_smth"]:.2f}_sel{config["repinit_nsel"]}'
f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}.npz'
)

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return {
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,19 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return f'stat_file_sea_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return {"sel": config["sel"], "rcut": config["rcut"]}
Expand Down
13 changes: 2 additions & 11 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


def get_model(model_params, sampled=None):
def get_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
# descriptor
Expand All @@ -35,16 +35,7 @@ def get_model(model_params, sampled=None):
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)

return EnergyModel(
descriptor,
fitting,
type_map=model_params["type_map"],
type_embedding=model_params.get("type_embedding", None),
resuming=model_params.get("resuming", False),
stat_file_dir=model_params.get("stat_file_dir", None),
stat_file_path=model_params.get("stat_file_path", None),
sampled=sampled,
)
return EnergyModel(descriptor, fitting, type_map=model_params["type_map"])


__all__ = [
Expand Down
43 changes: 2 additions & 41 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,9 @@ class DPAtomicModel(BaseModel, BaseAtomicModel):
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
type_embedding
Type embedding net
resuming
Whether to resume/fine-tune from checkpoint or not.
stat_file_dir
The directory to the state files.
stat_file_path
The path to the state files.
sampled
Sampled frames to compute the statistics.
"""

# I am enough with the shit interface!
def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]],
type_embedding: Optional[dict] = None,
resuming: bool = False,
stat_file_dir=None,
stat_file_path=None,
sampled=None,
**kwargs,
):
def __init__(self, descriptor, fitting, type_map: Optional[List[str]]):
super().__init__()
ntypes = len(type_map)
self.type_map = type_map
Expand All @@ -72,17 +50,6 @@ def __init__(
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
# Statistics
fitting_net = None # TODO: hack!!! not sure if it is correct.
self.compute_or_load_stat(
fitting_net,
ntypes,
resuming=resuming,
type_map=type_map,
stat_file_dir=stat_file_dir,
stat_file_path=stat_file_path,
sampled=sampled,
)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -122,13 +89,7 @@ def deserialize(cls, data) -> "DPAtomicModel":
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
# TODO: dirty hack to provide type_map and avoid data stat!!!
obj = cls(
descriptor_obj,
fitting_obj,
type_map=data["type_map"],
resuming=True,
)
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

def forward_atomic(
Expand Down
Loading