
Commit 16c6db6

pt: refact training code (#3359)

Authored by iProzd, pre-commit-ci[bot], and wanghan-iapcm

This PR:
- adds `data_requirement` for the dataloader
- reformats `make_stat_input` and related training code
- supports single-task & multi-task training

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Signed-off-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>

1 parent 54efc03 · commit 16c6db6

42 files changed: 1409 additions & 331 deletions
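For context on the multi-task support this commit adds: the training entry point keys models, losses, and data by model name, judging from the lookups in the `main.py` diff below (`config["model"]["model_dict"]`, `config["loss_dict"]`, `config["training"]["data_dict"]`). A skeleton of such a config might look like the following sketch; all leaf values are illustrative placeholders, not a validated deepmd-kit input file.

```python
# Skeleton inferred from the config lookups in the diff below; values are placeholders.
config = {
    "model": {
        "model_dict": {
            "model_1": {"descriptor": {"type": "se_e2_a"}},  # per-model params
            "model_2": {"descriptor": {"type": "se_e2_a"}},
        },
    },
    "loss_dict": {
        "model_1": {"type": "ener"},
        "model_2": {"type": "ener"},
    },
    "training": {
        "data_dict": {
            "model_1": {
                "training_data": {"systems": ["data/sys1"], "batch_size": "auto"},
                "stat_file": "stat_model_1.hdf5",  # optional cached statistics
            },
            "model_2": {
                "training_data": {"systems": ["data/sys2"], "batch_size": "auto"},
            },
        },
    },
}
```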

(Large commit: only a subset of the 42 changed files is shown below.)
deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -127,6 +127,14 @@ def mixed_types(self):
         """
         return any(descrpt.mixed_types() for descrpt in self.descrpt_list)
 
+    def share_params(self, base_class, shared_level, resume=False):
+        """
+        Share the parameters of self to the base_class with shared_level during multitask training.
+        If not starting from a checkpoint (resume is False),
+        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
+        """
+        raise NotImplementedError
+
     def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
         """Update mean and stddev for descriptor elements."""
         for descrpt in self.descrpt_list:
```
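The dpmodel (NumPy) descriptors only stub this interface out; the sharing logic itself lives in the PyTorch backend. As a rough, hypothetical sketch of what a concrete `share_params` could do at the fullest sharing level (the attribute-aliasing approach here is an assumption for illustration, not the deepmd-kit implementation):

```python
def share_params(self, base_class, shared_level, resume=False):
    # Hypothetical sketch only; real descriptors implement this in the
    # PyTorch backend with framework-specific parameter linking.
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level == 0:
        # Fullest sharing: alias every attribute (parameters and statistics
        # such as mean/stddev) of this instance to the base instance.
        self.__dict__ = base_class.__dict__
    else:
        raise NotImplementedError(
            f"shared_level {shared_level} is not supported in this sketch"
        )
```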

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 14 additions & 1 deletion

```diff
@@ -4,8 +4,10 @@
     abstractmethod,
 )
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )
 
 from deepmd.common import (
@@ -84,8 +86,19 @@ def mixed_types(self) -> bool:
         """
         pass
 
+    @abstractmethod
+    def share_params(self, base_class, shared_level, resume=False):
+        """
+        Share the parameters of self to the base_class with shared_level during multitask training.
+        If not starting from a checkpoint (resume is False),
+        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
+        """
+        pass
+
     def compute_input_stats(
-        self, merged: List[dict], path: Optional[DPPath] = None
+        self,
+        merged: Union[Callable[[], List[dict]], List[dict]],
+        path: Optional[DPPath] = None,
     ):
         """Update mean and stddev for descriptor elements."""
         raise NotImplementedError
```
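The widened `merged` signature lets callers pass either a pre-sampled list of stat batches or a zero-argument callable that produces one, so the potentially expensive sampling can be deferred until statistics are actually needed, for instance when no cached stat file is available. A minimal sketch of how an implementation might consume it; the helper and the reduction function named here are hypothetical:

```python
from typing import Callable, List, Union


def _resolve_merged(
    merged: Union[Callable[[], List[dict]], List[dict]],
) -> List[dict]:
    # Hypothetical helper: trigger the sampling only when 'merged' is lazy.
    return merged() if callable(merged) else merged


# Inside a compute_input_stats implementation (sketch):
#     sampled = _resolve_merged(merged)        # list of {key: batch} dicts
#     mean, stddev = reduce_stats(sampled)     # hypothetical reduction step
```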

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -243,6 +243,14 @@ def mixed_types(self):
         """
         return False
 
+    def share_params(self, base_class, shared_level, resume=False):
+        """
+        Share the parameters of self to the base_class with shared_level during multitask training.
+        If not starting from a checkpoint (resume is False),
+        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
+        """
+        raise NotImplementedError
+
     def get_ntypes(self) -> int:
         """Returns the number of element types."""
         return self.ntypes
```

deepmd/dpmodel/descriptor/se_r.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -203,6 +203,14 @@ def mixed_types(self):
         """
         return False
 
+    def share_params(self, base_class, shared_level, resume=False):
+        """
+        Share the parameters of self to the base_class with shared_level during multitask training.
+        If not starting from a checkpoint (resume is False),
+        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
+        """
+        raise NotImplementedError
+
     def get_ntypes(self) -> int:
         """Returns the number of element types."""
         return self.ntypes
```

deepmd/dpmodel/model/dp_model.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+
 from deepmd.dpmodel.atomic_model import (
     DPAtomicModel,
 )
```

deepmd/pt/entrypoints/main.py

Lines changed: 23 additions & 47 deletions

```diff
@@ -53,9 +53,6 @@
 from deepmd.pt.utils.multi_task import (
     preprocess_shared_params,
 )
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 from deepmd.utils.argcheck import (
     normalize,
 )
@@ -104,36 +101,23 @@ def get_trainer(
     config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
 
     def prepare_trainer_input_single(
-        model_params_single, data_dict_single, loss_dict_single, suffix=""
+        model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
     ):
         training_dataset_params = data_dict_single["training_data"]
         type_split = False
         if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
             type_split = True
-        validation_dataset_params = data_dict_single["validation_data"]
+        validation_dataset_params = data_dict_single.get("validation_data", None)
+        validation_systems = (
+            validation_dataset_params["systems"] if validation_dataset_params else None
+        )
         training_systems = training_dataset_params["systems"]
-        validation_systems = validation_dataset_params["systems"]
-
-        # noise params
-        noise_settings = None
-        if loss_dict_single.get("type", "ener") == "denoise":
-            noise_settings = {
-                "noise_type": loss_dict_single.pop("noise_type", "uniform"),
-                "noise": loss_dict_single.pop("noise", 1.0),
-                "noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
-                "mask_num": loss_dict_single.pop("mask_num", 8),
-                "mask_prob": loss_dict_single.pop("mask_prob", 0.15),
-                "same_mask": loss_dict_single.pop("same_mask", False),
-                "mask_coord": loss_dict_single.pop("mask_coord", False),
-                "mask_type": loss_dict_single.pop("mask_type", False),
-                "max_fail_num": loss_dict_single.pop("max_fail_num", 10),
-                "mask_type_idx": len(model_params_single["type_map"]) - 1,
-            }
-        # noise_settings = None
 
         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
-        if stat_file_path_single is not None:
+        if rank != 0:
+            stat_file_path_single = None
+        elif stat_file_path_single is not None:
             if Path(stat_file_path_single).is_dir():
                 raise ValueError(
                     f"stat_file should be a file, not a directory: {stat_file_path_single}"
@@ -144,71 +128,63 @@ def prepare_trainer_input_single(
             stat_file_path_single = DPPath(stat_file_path_single, "a")
 
         # validation and training data
-        validation_data_single = DpLoaderSet(
-            validation_systems,
-            validation_dataset_params["batch_size"],
-            model_params_single,
+        validation_data_single = (
+            DpLoaderSet(
+                validation_systems,
+                validation_dataset_params["batch_size"],
+                model_params_single,
+            )
+            if validation_systems
+            else None
         )
         if ckpt or finetune_model:
             train_data_single = DpLoaderSet(
                 training_systems,
                 training_dataset_params["batch_size"],
                 model_params_single,
             )
-            sampled_single = None
         else:
             train_data_single = DpLoaderSet(
                 training_systems,
                 training_dataset_params["batch_size"],
                 model_params_single,
             )
-            data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
-            sampled_single = make_stat_input(
-                train_data_single.systems,
-                train_data_single.dataloaders,
-                data_stat_nbatch,
-            )
-            if noise_settings is not None:
-                train_data_single = DpLoaderSet(
-                    training_systems,
-                    training_dataset_params["batch_size"],
-                    model_params_single,
-                )
         return (
             train_data_single,
             validation_data_single,
-            sampled_single,
             stat_file_path_single,
         )
 
+    rank = dist.get_rank() if dist.is_initialized() else 0
     if not multi_task:
         (
             train_data,
             validation_data,
-            sampled,
             stat_file_path,
         ) = prepare_trainer_input_single(
-            config["model"], config["training"], config["loss"]
+            config["model"],
+            config["training"],
+            config["loss"],
+            rank=rank,
         )
     else:
-        train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
+        train_data, validation_data, stat_file_path = {}, {}, {}
         for model_key in config["model"]["model_dict"]:
             (
                 train_data[model_key],
                 validation_data[model_key],
-                sampled[model_key],
                 stat_file_path[model_key],
             ) = prepare_trainer_input_single(
                 config["model"]["model_dict"][model_key],
                 config["training"]["data_dict"][model_key],
                 config["loss_dict"][model_key],
                 suffix=f"_{model_key}",
+                rank=rank,
             )
 
     trainer = training.Trainer(
         config,
         train_data,
-        sampled=sampled,
         stat_file_path=stat_file_path,
         validation_data=validation_data,
         init_model=init_model,
```
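The new `rank` argument makes the stat-file handling distributed-aware: only rank 0 keeps a writable `stat_file_path`, so in a multi-process run a single process reads or serializes the statistics while the other ranks leave the path unset. The pattern in isolation, as a self-contained sketch (the file name is illustrative):

```python
import torch.distributed as dist

# Simplified mirror of the rank gating in the diff above.
rank = dist.get_rank() if dist.is_initialized() else 0

stat_file_path = "stat_files.hdf5"  # illustrative path
if rank != 0:
    # Non-zero ranks neither read nor write the stat file, avoiding
    # concurrent access; only rank 0 touches it.
    stat_file_path = None
```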

deepmd/pt/loss/ener.py

Lines changed: 106 additions & 1 deletion

```diff
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    List,
+)
+
 import torch
 import torch.nn.functional as F
 
@@ -11,6 +15,9 @@
 from deepmd.pt.utils.env import (
     GLOBAL_PT_FLOAT_PRECISION,
 )
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
 
 
 class EnergyStdLoss(TaskLoss):
@@ -23,16 +30,57 @@ def __init__(
         limit_pref_f=0.0,
         start_pref_v=0.0,
         limit_pref_v=0.0,
+        start_pref_ae: float = 0.0,
+        limit_pref_ae: float = 0.0,
+        start_pref_pf: float = 0.0,
+        limit_pref_pf: float = 0.0,
         use_l1_all: bool = False,
         inference=False,
         **kwargs,
     ):
-        """Construct a layer to compute loss on energy, force and virial."""
+        r"""Construct a layer to compute loss on energy, force and virial.
+
+        Parameters
+        ----------
+        starter_learning_rate : float
+            The learning rate at the start of the training.
+        start_pref_e : float
+            The prefactor of energy loss at the start of the training.
+        limit_pref_e : float
+            The prefactor of energy loss at the end of the training.
+        start_pref_f : float
+            The prefactor of force loss at the start of the training.
+        limit_pref_f : float
+            The prefactor of force loss at the end of the training.
+        start_pref_v : float
+            The prefactor of virial loss at the start of the training.
+        limit_pref_v : float
+            The prefactor of virial loss at the end of the training.
+        start_pref_ae : float
+            The prefactor of atomic energy loss at the start of the training.
+        limit_pref_ae : float
+            The prefactor of atomic energy loss at the end of the training.
+        start_pref_pf : float
+            The prefactor of atomic prefactor force loss at the start of the training.
+        limit_pref_pf : float
+            The prefactor of atomic prefactor force loss at the end of the training.
+        use_l1_all : bool
+            Whether to use L1 loss; if False (default), L2 loss is used.
+        inference : bool
+            If true, it will output all losses found in output, ignoring the pre-factors.
+        **kwargs
+            Other keyword arguments.
+        """
         super().__init__()
         self.starter_learning_rate = starter_learning_rate
         self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
         self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference
         self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
+
+        # TODO need support for atomic energy and atomic pref
+        self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
+        self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
+
         self.start_pref_e = start_pref_e
         self.limit_pref_e = limit_pref_e
         self.start_pref_f = start_pref_f
@@ -153,3 +201,60 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
         if not self.inference:
             more_loss["rmse"] = torch.sqrt(loss.detach())
         return loss, more_loss
+
+    @property
+    def label_requirement(self) -> List[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        label_requirement = []
+        if self.has_e:
+            label_requirement.append(
+                DataRequirementItem(
+                    "energy",
+                    ndof=1,
+                    atomic=False,
+                    must=False,
+                    high_prec=True,
+                )
+            )
+        if self.has_f:
+            label_requirement.append(
+                DataRequirementItem(
+                    "force",
+                    ndof=3,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_v:
+            label_requirement.append(
+                DataRequirementItem(
+                    "virial",
+                    ndof=9,
+                    atomic=False,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_ae:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atom_ener",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_pf:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atom_pref",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                    repeat=3,
+                )
+            )
+        return label_requirement
```
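Two usage notes on this class. First, in DeePMD-kit each start/limit prefactor pair is interpolated along the learning-rate decay, schematically pref(step) = limit_pref + (start_pref - limit_pref) * lr(step) / start_lr, so a term's weight slides from its start value toward its limit value over training. Second, the new `label_requirement` property lets the training pipeline ask the loss which labels the dataloader must provide. A sketch of that handshake; the commented registration hook is an assumption, not necessarily the exact deepmd-kit dataloader API:

```python
from deepmd.pt.loss.ener import EnergyStdLoss

loss = EnergyStdLoss(
    starter_learning_rate=1e-3,
    start_pref_e=0.02, limit_pref_e=1.0,
    start_pref_f=1000.0, limit_pref_f=1.0,
)

# The loss declares which labels it needs; a data pipeline can then
# register them with each dataset before training starts.
for req in loss.label_requirement:
    # Each req is a DataRequirementItem carrying the key, ndof, and the
    # atomic/must/high_prec flags shown in the diff above.
    print(req.key, req.ndof, req.atomic)  # e.g. "energy" 1 False, "force" 3 True
    # train_data.add_data_requirement(req)  # hypothetical dataloader hook
```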

deepmd/pt/loss/loss.py

Lines changed: 19 additions & 1 deletion

```diff
@@ -1,12 +1,30 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from abc import (
+    ABC,
+    abstractmethod,
+)
+from typing import (
+    List,
+)
+
 import torch
 
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
 
-class TaskLoss(torch.nn.Module):
+class TaskLoss(torch.nn.Module, ABC):
     def __init__(self, **kwargs):
         """Construct loss."""
         super().__init__()
 
     def forward(self, model_pred, label, natoms, learning_rate):
         """Return loss."""
         raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def label_requirement(self) -> List[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        pass
```
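Making `TaskLoss` an ABC means any new loss must declare its data needs up front: a subclass that forgets `label_requirement` fails at instantiation rather than mid-training. A minimal sketch of a conforming subclass; the class and its loss math are illustrative placeholders, not part of deepmd-kit:

```python
from typing import List

import torch

from deepmd.pt.loss.loss import TaskLoss
from deepmd.utils.data import DataRequirementItem


class ToyEnergyLoss(TaskLoss):
    """Illustrative custom loss; not part of deepmd-kit."""

    def forward(self, model_pred, label, natoms, learning_rate):
        # Placeholder L2 loss on energy only.
        loss = torch.mean((model_pred["energy"] - label["energy"]) ** 2)
        return loss, {"rmse_e": torch.sqrt(loss.detach())}

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        # Declare the single label this loss consumes.
        return [
            DataRequirementItem(
                "energy", ndof=1, atomic=False, must=False, high_prec=True
            )
        ]
```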
