Skip to content

Commit 71ec631

Browse files
authored
pt: refactor loss (#3569)
This PR updates the loss interface to allow for a more flexible design. It enables processing input tensors before feeding them into the model, such as denoising operations (fyi @Chengqian-Zhang ). Previously, this was done in the data loader, which was less intuitive and more confusing. Now, users can easily handle these tasks within the loss function itself, as demonstrated in similar implementations in uni-mol: https://github.com/dptech-corp/Uni-Mol/blob/main/unimol/unimol/losses/unimol.py#L20.
1 parent 9c861c2 commit 71ec631

8 files changed

Lines changed: 92 additions & 53 deletions

File tree

deepmd/pt/loss/ener.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,30 @@ def __init__(
9090
self.use_l1_all = use_l1_all
9191
self.inference = inference
9292

93-
def forward(self, model_pred, label, natoms, learning_rate, mae=False):
94-
"""Return loss on loss and force.
93+
def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
94+
"""Return loss on energy and force.
9595
96-
Args:
97-
- natoms: Tell atom count.
98-
- p_energy: Predicted energy of all atoms.
99-
- p_force: Predicted force per atom.
100-
- l_energy: Actual energy of all atoms.
101-
- l_force: Actual force per atom.
96+
Parameters
97+
----------
98+
input_dict : dict[str, torch.Tensor]
99+
Model inputs.
100+
model : torch.nn.Module
101+
Model to be used to output the predictions.
102+
label : dict[str, torch.Tensor]
103+
Labels.
104+
natoms : int
105+
The local atom number.
102106
103107
Returns
104108
-------
105-
- loss: Loss to minimize.
109+
model_pred: dict[str, torch.Tensor]
110+
Model predictions.
111+
loss: torch.Tensor
112+
Loss for model to minimize.
113+
more_loss: dict[str, torch.Tensor]
114+
Other losses for display.
106115
"""
116+
model_pred = model(**input_dict)
107117
coef = learning_rate / self.starter_learning_rate
108118
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
109119
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
@@ -200,7 +210,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
200210
more_loss["mae_v"] = mae_v.detach()
201211
if not self.inference:
202212
more_loss["rmse"] = torch.sqrt(loss.detach())
203-
return loss, more_loss
213+
return model_pred, loss, more_loss
204214

205215
@property
206216
def label_requirement(self) -> List[DataRequirementItem]:

deepmd/pt/loss/ener_spin.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,30 @@ def __init__(
6363
self.use_l1_all = use_l1_all
6464
self.inference = inference
6565

66-
def forward(self, model_pred, label, natoms, learning_rate, mae=False):
66+
def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
6767
"""Return energy loss with magnetic labels.
6868
6969
Parameters
7070
----------
71-
model_pred : dict[str, torch.Tensor]
72-
Model predictions.
71+
input_dict : dict[str, torch.Tensor]
72+
Model inputs.
73+
model : torch.nn.Module
74+
Model to be used to output the predictions.
7375
label : dict[str, torch.Tensor]
7476
Labels.
7577
natoms : int
7678
The local atom number.
7779
7880
Returns
7981
-------
82+
model_pred: dict[str, torch.Tensor]
83+
Model predictions.
8084
loss: torch.Tensor
8185
Loss for model to minimize.
8286
more_loss: dict[str, torch.Tensor]
8387
Other losses for display.
8488
"""
89+
model_pred = model(**input_dict)
8590
coef = learning_rate / self.starter_learning_rate
8691
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
8792
pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef
@@ -175,7 +180,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
175180

176181
if not self.inference:
177182
more_loss["rmse"] = torch.sqrt(loss.detach())
178-
return loss, more_loss
183+
return model_pred, loss, more_loss
179184

180185
@property
181186
def label_requirement(self) -> List[DataRequirementItem]:

deepmd/pt/loss/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, **kwargs):
1919
"""Construct loss."""
2020
super().__init__()
2121

22-
def forward(self, model_pred, label, natoms, learning_rate):
22+
def forward(self, input_dict, model, label, natoms, learning_rate):
2323
"""Return loss ."""
2424
raise NotImplementedError
2525

deepmd/pt/loss/tensor.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,30 @@ def __init__(
6363
"Can not assian zero weight both to `pref` and `pref_atomic`"
6464
)
6565

66-
def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
66+
def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False):
6767
"""Return loss on local and global tensors.
6868
6969
Parameters
7070
----------
71-
model_pred : dict[str, torch.Tensor]
72-
Model predictions.
71+
input_dict : dict[str, torch.Tensor]
72+
Model inputs.
73+
model : torch.nn.Module
74+
Model to be used to output the predictions.
7375
label : dict[str, torch.Tensor]
7476
Labels.
7577
natoms : int
7678
The local atom number.
7779
7880
Returns
7981
-------
82+
model_pred: dict[str, torch.Tensor]
83+
Model predictions.
8084
loss: torch.Tensor
8185
Loss for model to minimize.
8286
more_loss: dict[str, torch.Tensor]
8387
Other losses for display.
8488
"""
89+
model_pred = model(**input_dict)
8590
del learning_rate, mae
8691
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
8792
more_loss = {}
@@ -133,7 +138,7 @@ def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
133138
loss += self.global_weight * l2_global_loss
134139
rmse_global = l2_global_loss.sqrt() / atom_num
135140
more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
136-
return loss, more_loss
141+
return model_pred, loss, more_loss
137142

138143
@property
139144
def label_requirement(self) -> List[DataRequirementItem]:

deepmd/pt/train/training.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,8 +696,13 @@ def step(_step_id, task_key="Default"):
696696
module = (
697697
self.wrapper.module if dist.is_initialized() else self.wrapper
698698
)
699-
loss, more_loss = module.loss[task_key](
700-
model_pred,
699+
700+
def fake_model():
701+
return model_pred
702+
703+
_, loss, more_loss = module.loss[task_key](
704+
{},
705+
fake_model,
701706
label_dict,
702707
int(input_dict["atype"].shape[-1]),
703708
learning_rate=pref_lr,

deepmd/pt/train/wrapper.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,20 @@ def forward(
168168
has_spin = has_spin()
169169
if has_spin:
170170
input_dict["spin"] = spin
171-
model_pred = self.model[task_key](**input_dict)
172-
natoms = atype.shape[-1]
173-
if not self.inference_only and not inference_only:
174-
loss, more_loss = self.loss[task_key](
175-
model_pred, label, natoms=natoms, learning_rate=cur_lr
171+
172+
if self.inference_only or inference_only:
173+
model_pred = self.model[task_key](**input_dict)
174+
return model_pred, None, None
175+
else:
176+
natoms = atype.shape[-1]
177+
model_pred, loss, more_loss = self.loss[task_key](
178+
input_dict,
179+
self.model[task_key],
180+
label,
181+
natoms=natoms,
182+
learning_rate=cur_lr,
176183
)
177184
return model_pred, loss, more_loss
178-
else:
179-
return model_pred, None, None
180185

181186
def set_extra_state(self, state: Dict):
182187
self.model_params = state["model_params"]

source/tests/pt/model/test_model.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -338,34 +338,33 @@ def test_consistency(self):
338338
batch["natoms"] = torch.tensor(
339339
batch["natoms_vec"], device=batch["coord"].device
340340
).unsqueeze(0)
341-
model_predict = my_model(
342-
batch["coord"].to(env.DEVICE),
343-
batch["atype"].to(env.DEVICE),
344-
batch["box"].to(env.DEVICE),
345-
do_atomic_virial=True,
346-
)
347-
model_predict_1 = my_model(
348-
batch["coord"].to(env.DEVICE),
349-
batch["atype"].to(env.DEVICE),
350-
batch["box"].to(env.DEVICE),
351-
do_atomic_virial=False,
341+
model_input = {
342+
"coord": batch["coord"].to(env.DEVICE),
343+
"atype": batch["atype"].to(env.DEVICE),
344+
"box": batch["box"].to(env.DEVICE),
345+
"do_atomic_virial": True,
346+
}
347+
model_input_1 = {
348+
"coord": batch["coord"].to(env.DEVICE),
349+
"atype": batch["atype"].to(env.DEVICE),
350+
"box": batch["box"].to(env.DEVICE),
351+
"do_atomic_virial": False,
352+
}
353+
label = {
354+
"energy": batch["energy"].to(env.DEVICE),
355+
"force": batch["force"].to(env.DEVICE),
356+
}
357+
cur_lr = my_lr.value(self.wanted_step)
358+
model_predict, loss, _ = my_loss(
359+
model_input, my_model, label, int(batch["natoms"][0, 0]), cur_lr
352360
)
361+
model_predict_1 = my_model(**model_input_1)
353362
p_energy, p_force, p_virial, p_atomic_virial = (
354363
model_predict["energy"],
355364
model_predict["force"],
356365
model_predict["virial"],
357366
model_predict["atom_virial"],
358367
)
359-
cur_lr = my_lr.value(self.wanted_step)
360-
model_pred = {
361-
"energy": p_energy,
362-
"force": p_force,
363-
}
364-
label = {
365-
"energy": batch["energy"].to(env.DEVICE),
366-
"force": batch["force"].to(env.DEVICE),
367-
}
368-
loss, _ = my_loss(model_pred, label, int(batch["natoms"][0, 0]), cur_lr)
369368
np.testing.assert_allclose(
370369
head_dict["energy"], p_energy.view(-1).cpu().detach().numpy()
371370
)

source/tests/pt/test_loss.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,13 @@ def test_consistency(self):
171171
self.start_pref_v,
172172
self.limit_pref_v,
173173
)
174-
my_loss, my_more_loss = mine(
175-
self.model_pred,
174+
175+
def fake_model():
176+
return self.model_pred
177+
178+
_, my_loss, my_more_loss = mine(
179+
{},
180+
fake_model,
176181
self.label,
177182
self.nloc,
178183
self.cur_lr,
@@ -345,8 +350,13 @@ def test_consistency(self):
345350
self.start_pref_fm,
346351
self.limit_pref_fm,
347352
)
348-
my_loss, my_more_loss = mine(
349-
self.model_pred,
353+
354+
def fake_model():
355+
return self.model_pred
356+
357+
_, my_loss, my_more_loss = mine(
358+
{},
359+
fake_model,
350360
self.label,
351361
self.nloc_tf, # use tf natoms pref
352362
self.cur_lr,

0 commit comments

Comments
 (0)