Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,27 @@
)

backend = "CINN" if CINN else None
# NOTE: This is a trick to decide the right input_spec for wrapper.forward
_, label_dict, _ = self.get_data(is_train=True)

Check warning on line 611 in deepmd/pd/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/train/training.py#L611

Added line #L611 was not covered by tests

# Define specification templates
spec_templates = {

Check warning on line 614 in deepmd/pd/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/train/training.py#L614

Added line #L614 was not covered by tests
"find_box": np.float32(1.0),
"find_coord": np.float32(1.0),
"find_numb_copy": np.float32(0.0),
"numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"),
"find_energy": np.float32(1.0),
"energy": static.InputSpec([1, 1], "float64", name="energy"),
"find_force": np.float32(1.0),
"force": static.InputSpec([1, -1, 3], "float64", name="force"),
"find_virial": np.float32(0.0),
"virial": static.InputSpec([1, 9], "float64", name="virial"),
"natoms": static.InputSpec([1, -1], "int32", name="natoms"),
}
# Build spec only for keys present in sample data
label_dict_spec = {

Check warning on line 628 in deepmd/pd/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/train/training.py#L628

Added line #L628 was not covered by tests
k: spec_templates[k] for k in label_dict.keys() if k in spec_templates
}
self.wrapper.forward = jit.to_static(
backend=backend,
input_spec=[
Expand All @@ -615,19 +636,7 @@
None, # spin
static.InputSpec([1, 9], "float64", name="box"), # box
static.InputSpec([], "float64", name="cur_lr"), # cur_lr
{
"find_box": np.float32(1.0),
"find_coord": np.float32(1.0),
"find_numb_copy": np.float32(0.0),
"numb_copy": static.InputSpec(
[1, 1], "int64", name="numb_copy"
),
"find_energy": np.float32(1.0),
"energy": static.InputSpec([1, 1], "float64", name="energy"),
"find_force": np.float32(1.0),
"force": static.InputSpec([1, -1, 3], "float64", name="force"),
"natoms": static.InputSpec([1, -1], "int32", name="natoms"),
}, # label,
label_dict_spec, # label,
# None, # task_key
# False, # inference_only
# False, # do_atomic_virial
Expand Down