Skip to content

Commit fb3df8b

Browse files
committed
bugfix
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 9b571d1 commit fb3df8b

9 files changed

Lines changed: 52 additions & 8 deletions

File tree

deepmd/jax/atomic_model/base_atomic_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from deepmd.jax.common import (
3+
ArrayAPIVariable,
34
to_jax_array,
45
)
56
from deepmd.jax.utils.exclude_mask import (
@@ -11,6 +12,8 @@
1112
def base_atomic_model_set_attr(name, value):
1213
if name in {"out_bias", "out_std"}:
1314
value = to_jax_array(value)
15+
if value is not None:
16+
value = ArrayAPIVariable(value)
1417
elif name == "pair_excl" and value is not None:
1518
value = PairExcludeMask(value.ntypes, value.exclude_types)
1619
elif name == "atom_excl" and value is not None:

deepmd/jax/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None:
8181
return super().__setattr__(name, value)
8282

8383
return FlaxModule
84+
85+
86+
class ArrayAPIVariable(nnx.Variable):
87+
def __array__(self, *args, **kwargs):
88+
return self.value.__array__(*args, **kwargs)
89+
90+
def __array_namespace__(self, *args, **kwargs):
91+
return self.value.__array_namespace__(*args, **kwargs)
92+
93+
def __dlpack__(self, *args, **kwargs):
94+
return self.value.__dlpack__(*args, **kwargs)
95+
96+
def __dlpack_device__(self, *args, **kwargs):
97+
return self.value.__dlpack_device__(*args, **kwargs)

deepmd/jax/descriptor/dpa1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
1414
)
1515
from deepmd.jax.common import (
16+
ArrayAPIVariable,
1617
flax_module,
1718
to_jax_array,
1819
)
@@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
6566
def __setattr__(self, name: str, value: Any) -> None:
6667
if name in {"mean", "stddev"}:
6768
value = to_jax_array(value)
69+
if value is not None:
70+
value = ArrayAPIVariable(value)
6871
elif name in {"embeddings", "embeddings_strip"}:
6972
if value is not None:
7073
value = NetworkCollection.deserialize(value.serialize())

deepmd/jax/descriptor/se_e2_a.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
77
from deepmd.jax.common import (
8+
ArrayAPIVariable,
89
flax_module,
910
to_jax_array,
1011
)
@@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP):
2627
def __setattr__(self, name: str, value: Any) -> None:
2728
if name in {"dstd", "davg"}:
2829
value = to_jax_array(value)
30+
if value is not None:
31+
value = ArrayAPIVariable(value)
2932
elif name in {"embeddings"}:
3033
if value is not None:
3134
value = NetworkCollection.deserialize(value.serialize())

deepmd/jax/fitting/fitting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
77
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
88
from deepmd.jax.common import (
9+
ArrayAPIVariable,
910
flax_module,
1011
to_jax_array,
1112
)
@@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
2930
"aparam_inv_std",
3031
}:
3132
value = to_jax_array(value)
33+
if value is not None:
34+
value = ArrayAPIVariable(value)
3235
elif name == "emask":
3336
value = AtomExcludeMask(value.ntypes, value.exclude_types)
3437
elif name == "nets":

deepmd/jax/utils/exclude_mask.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
77
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
88
from deepmd.jax.common import (
9+
ArrayAPIVariable,
910
flax_module,
1011
to_jax_array,
1112
)
@@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP):
1617
def __setattr__(self, name: str, value: Any) -> None:
1718
if name in {"type_mask"}:
1819
value = to_jax_array(value)
20+
if value is not None:
21+
value = ArrayAPIVariable(value)
1922
return super().__setattr__(name, value)
2023

2124

@@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP):
2427
def __setattr__(self, name: str, value: Any) -> None:
2528
if name in {"type_mask"}:
2629
value = to_jax_array(value)
30+
if value is not None:
31+
value = ArrayAPIVariable(value)
2732
return super().__setattr__(name, value)

deepmd/jax/utils/serialization.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
BaseModel,
1414
get_model,
1515
)
16-
from deepmd.jax.utils.network import (
17-
ArrayAPIParam,
18-
)
1916

2017

2118
def deserialize_to_file(model_file: str, data: dict) -> None:
@@ -31,14 +28,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
3128
if model_file.endswith(".jax"):
3229
model = BaseModel.deserialize(data["model"])
3330
model_def_script = data["model_def_script"]
34-
state = nnx.state(model, ArrayAPIParam)
31+
_, state = nnx.split(model)
3532
with ocp.Checkpointer(
3633
ocp.CompositeCheckpointHandler("state", "model_def_script")
3734
) as checkpointer:
3835
checkpointer.save(
3936
Path(model_file).absolute(),
4037
ocp.args.Composite(
41-
state=ocp.args.StandardSave(state),
38+
state=ocp.args.StandardSave(state.to_pure_dict()),
4239
model_def_script=ocp.args.JsonSave(model_def_script),
4340
),
4441
)
@@ -71,9 +68,22 @@ def serialize_from_file(model_file: str) -> dict:
7168
),
7269
)
7370
state = data.state
71+
72+
# convert str "1" to int 1 key
73+
def convert_str_to_int_key(item: dict):
74+
for key, value in item.copy().items():
75+
if isinstance(value, dict):
76+
convert_str_to_int_key(value)
77+
if key.isdigit():
78+
item[int(key)] = item.pop(key)
79+
80+
convert_str_to_int_key(state)
81+
7482
model_def_script = data.model_def_script
75-
model = get_model(model_def_script)
76-
nnx.update(model, state)
83+
abstract_model = get_model(model_def_script)
84+
graphdef, abstract_state = nnx.split(abstract_model)
85+
abstract_state.replace_by_pure_dict(state)
86+
model = nnx.merge(graphdef, abstract_state)
7787
model_dict = model.serialize()
7888
data = {
7989
"backend": "JAX",

deepmd/jax/utils/type_embed.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
77
from deepmd.jax.common import (
8+
ArrayAPIVariable,
89
flax_module,
910
to_jax_array,
1011
)
@@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP):
1819
def __setattr__(self, name: str, value: Any) -> None:
1920
if name in {"econf_tebd"}:
2021
value = to_jax_array(value)
22+
if value is not None:
23+
value = ArrayAPIVariable(value)
2124
if name in {"embedding_net"}:
2225
value = EmbeddingNet.deserialize(value.serialize())
2326
return super().__setattr__(name, value)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ cu12 = [
137137
]
138138
jax = [
139139
'jax>=0.4.33;python_version>="3.10"',
140-
'flax>=0.8.0;python_version>="3.10"',
140+
'flax>=0.10.0;python_version>="3.10"',
141141
'orbax-checkpoint;python_version>="3.10"',
142142
# The pinning of ml_dtypes may conflict with TF
143143
# 'jax-ai-stack;python_version>="3.10"',

0 commit comments

Comments
 (0)