Skip to content

Commit 261c802

Browse files
pt: add necessary jit.export (#3337)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 91049df commit 261c802

5 files changed

Lines changed: 5 additions & 0 deletions

File tree

deepmd/pt/model/model/dipole_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def forward(
5555
model_predict["updated_coord"] += coord
5656
return model_predict
5757

58+
@torch.jit.export
5859
def forward_lower(
5960
self,
6061
extended_coord,

deepmd/pt/model/model/dp_zbl_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def forward(
5858
model_predict["force"] = model_ret["dforce"]
5959
return model_predict
6060

61+
@torch.jit.export
6162
def forward_lower(
6263
self,
6364
extended_coord,

deepmd/pt/model/model/ener_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def forward(
5757
model_predict["updated_coord"] += coord
5858
return model_predict
5959

60+
@torch.jit.export
6061
def forward_lower(
6162
self,
6263
extended_coord,

deepmd/pt/model/model/make_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def forward_common_lower(
202202
)
203203
return model_predict
204204

205+
@torch.jit.export
205206
def format_nlist(
206207
self,
207208
extended_coord: torch.Tensor,

deepmd/pt/model/model/polar_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def forward(
4747
model_predict["updated_coord"] += coord
4848
return model_predict
4949

50+
@torch.jit.export
5051
def forward_lower(
5152
self,
5253
extended_coord,

0 commit comments

Comments
 (0)