1313 BaseModel ,
1414 get_model ,
1515)
16- from deepmd .jax .utils .network import (
17- ArrayAPIParam ,
18- )
1916
2017
2118def deserialize_to_file (model_file : str , data : dict ) -> None :
@@ -31,14 +28,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
3128 if model_file .endswith (".jax" ):
3229 model = BaseModel .deserialize (data ["model" ])
3330 model_def_script = data ["model_def_script" ]
34- state = nnx .state (model , ArrayAPIParam )
31+ _ , state = nnx .split (model )
3532 with ocp .Checkpointer (
3633 ocp .CompositeCheckpointHandler ("state" , "model_def_script" )
3734 ) as checkpointer :
3835 checkpointer .save (
3936 Path (model_file ).absolute (),
4037 ocp .args .Composite (
41- state = ocp .args .StandardSave (state ),
38+ state = ocp .args .StandardSave (state . to_pure_dict () ),
4239 model_def_script = ocp .args .JsonSave (model_def_script ),
4340 ),
4441 )
@@ -71,9 +68,22 @@ def serialize_from_file(model_file: str) -> dict:
7168 ),
7269 )
7370 state = data .state
71+
72+ # convert str "1" to int 1 key
73+ def convert_str_to_int_key (item : dict ):
74+ for key , value in item .copy ().items ():
75+ if isinstance (value , dict ):
76+ convert_str_to_int_key (value )
77+ if key .isdigit ():
78+ item [int (key )] = item .pop (key )
79+
80+ convert_str_to_int_key (state )
81+
7482 model_def_script = data .model_def_script
75- model = get_model (model_def_script )
76- nnx .update (model , state )
83+ abstract_model = get_model (model_def_script )
84+ graphdef , abstract_state = nnx .split (abstract_model )
85+ abstract_state .replace_by_pure_dict (state )
86+ model = nnx .merge (graphdef , abstract_state )
7787 model_dict = model .serialize ()
7888 data = {
7989 "backend" : "JAX" ,
0 commit comments