@@ -1792,7 +1792,7 @@ def serialize(self) -> dict:
17921792 """
17931793 data = {
17941794 "@class" : "RepformerLayer" ,
1795- "@version" : 1 ,
1795+ "@version" : 2 ,
17961796 "rcut" : self .rcut ,
17971797 "rcut_smth" : self .rcut_smth ,
17981798 "sel" : self .sel ,
@@ -1877,9 +1877,11 @@ def serialize(self) -> dict:
18771877 if self .update_style == "res_residual" :
18781878 data .update (
18791879 {
1880- "g1_residual" : [to_numpy_array (aa ) for aa in self .g1_residual ],
1881- "g2_residual" : [to_numpy_array (aa ) for aa in self .g2_residual ],
1882- "h2_residual" : [to_numpy_array (aa ) for aa in self .h2_residual ],
1880+ "@variables" : {
1881+ "g1_residual" : [to_numpy_array (aa ) for aa in self .g1_residual ],
1882+ "g2_residual" : [to_numpy_array (aa ) for aa in self .g2_residual ],
1883+ "h2_residual" : [to_numpy_array (aa ) for aa in self .h2_residual ],
1884+ }
18831885 }
18841886 )
18851887 return data
@@ -1894,7 +1896,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
18941896 The dict to deserialize from.
18951897 """
18961898 data = data .copy ()
1897- check_version_compatibility (data .pop ("@version" ), 1 , 1 )
1899+ check_version_compatibility (data .pop ("@version" ), 2 , 1 )
18981900 data .pop ("@class" )
18991901 linear1 = data .pop ("linear1" )
19001902 update_chnnl_2 = data ["update_chnnl_2" ]
@@ -1915,9 +1917,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
19151917 attn2_ev_apply = data .pop ("attn2_ev_apply" , None )
19161918 loc_attn = data .pop ("loc_attn" , None )
19171919 g1_self_mlp = data .pop ("g1_self_mlp" , None )
1918- g1_residual = data .pop ("g1_residual" , [])
1919- g2_residual = data .pop ("g2_residual" , [])
1920- h2_residual = data .pop ("h2_residual" , [])
1920+ variables = data .pop ("@variables" , {})
1921+ g1_residual = variables .get ("g1_residual" , data .pop ("g1_residual" , []))
1922+ g2_residual = variables .get ("g2_residual" , data .pop ("g2_residual" , []))
1923+ h2_residual = variables .get ("h2_residual" , data .pop ("h2_residual" , []))
19211924
19221925 obj = cls (** data )
19231926 obj .linear1 = NativeLayer .deserialize (linear1 )
0 commit comments