Skip to content

Commit 6d303ee

Browse files
committed
fix(pt): fix seed in dpmodel fitting
1 parent 17cdcb0 commit 6d303ee

5 files changed

Lines changed: 7 additions & 7 deletions

File tree

deepmd/dpmodel/fitting/dipole_fitting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,8 @@ def __init__(
108108
c_differentiable: bool = True,
109109
type_map: Optional[List[str]] = None,
110110
old_impl=False,
111-
# not used
112111
seed: Optional[Union[int, List[int]]] = None,
113112
):
114-
# seed, uniform_seed are not included
115113
if tot_ener_zero:
116114
raise NotImplementedError("tot_ener_zero is not implemented")
117115
if spin is not None:
@@ -143,6 +141,7 @@ def __init__(
143141
mixed_types=mixed_types,
144142
exclude_types=exclude_types,
145143
type_map=type_map,
144+
seed=seed,
146145
)
147146
self.old_impl = False
148147

deepmd/dpmodel/fitting/dos_fitting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def __init__(
4545
mixed_types: bool = False,
4646
exclude_types: List[int] = [],
4747
type_map: Optional[List[str]] = None,
48-
# not used
4948
seed: Optional[Union[int, List[int]]] = None,
5049
):
5150
if bias_dos is not None:
@@ -69,6 +68,7 @@ def __init__(
6968
mixed_types=mixed_types,
7069
exclude_types=exclude_types,
7170
type_map=type_map,
71+
seed=seed,
7272
)
7373

7474
@classmethod

deepmd/dpmodel/fitting/ener_fitting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def __init__(
4646
mixed_types: bool = False,
4747
exclude_types: List[int] = [],
4848
type_map: Optional[List[str]] = None,
49-
# not used
5049
seed: Optional[Union[int, List[int]]] = None,
5150
):
5251
super().__init__(
@@ -70,6 +69,7 @@ def __init__(
7069
mixed_types=mixed_types,
7170
exclude_types=exclude_types,
7271
type_map=type_map,
72+
seed=seed,
7373
)
7474

7575
@classmethod

deepmd/dpmodel/fitting/invar_fitting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Dict,
66
List,
77
Optional,
8+
Union,
89
)
910

1011
import numpy as np
@@ -134,8 +135,8 @@ def __init__(
134135
mixed_types: bool = True,
135136
exclude_types: List[int] = [],
136137
type_map: Optional[List[str]] = None,
138+
seed: Optional[Union[int, List[int]]] = None,
137139
):
138-
# seed, uniform_seed are not included
139140
if tot_ener_zero:
140141
raise NotImplementedError("tot_ener_zero is not implemented")
141142
if spin is not None:
@@ -172,6 +173,7 @@ def __init__(
172173
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
173174
else [x is not None for x in atom_ener],
174175
type_map=type_map,
176+
seed=seed,
175177
)
176178

177179
def serialize(self) -> dict:

deepmd/dpmodel/fitting/polarizability_fitting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,8 @@ def __init__(
114114
scale: Optional[List[float]] = None,
115115
shift_diag: bool = True,
116116
type_map: Optional[List[str]] = None,
117-
# not used
118117
seed: Optional[Union[int, List[int]]] = None,
119118
):
120-
# seed, uniform_seed are not included
121119
if tot_ener_zero:
122120
raise NotImplementedError("tot_ener_zero is not implemented")
123121
if spin is not None:
@@ -167,6 +165,7 @@ def __init__(
167165
mixed_types=mixed_types,
168166
exclude_types=exclude_types,
169167
type_map=type_map,
168+
seed=seed,
170169
)
171170
self.old_impl = False
172171

0 commit comments

Comments
 (0)