@@ -172,12 +172,39 @@ class AggregationFunction(str, Enum):
172172
173173@PublicAPI (stability = "alpha" )
174174class AutoscalingPolicy (BaseModel ):
175- name : Union [str , Callable ] = Field (
175+ # Cloudpickled policy definition.
176+ _serialized_policy_def : bytes = PrivateAttr (default = b"" )
177+
178+ policy_function : Union [str , Callable ] = Field (
176179 default = DEFAULT_AUTOSCALING_POLICY_NAME ,
177- description = "Name of the policy function or the import path of the policy . "
178- "Will be the concatenation of the policy module and the policy name if user passed a callable. " ,
180+ description = "Policy function can be a string import path or a function callable . "
181+ "If it's a string import path, it must be of the form `path.to.module:function_name`. " ,
179182 )
180183
184+ def __init__ (self , ** kwargs ):
185+ super ().__init__ (** kwargs )
186+ self .serialize_policy ()
187+
188+ def serialize_policy (self ) -> None :
189+ """Serialize policy with cloudpickle.
190+
191+ Import the policy if it's passed in as a string import path. Then cloudpickle
192+ the policy and set `serialized_policy_def` if it's empty.
193+ """
194+ policy_path = self .policy_function
195+
196+ if isinstance (policy_path , Callable ):
197+ policy_path = f"{ policy_path .__module__ } .{ policy_path .__name__ } "
198+
199+ if not self ._serialized_policy_def :
200+ self ._serialized_policy_def = cloudpickle .dumps (import_attr (policy_path ))
201+
202+ self .policy_function = policy_path
203+
204+ def get_policy (self ) -> Callable :
205+ """Deserialize policy from cloudpickled bytes."""
206+ return cloudpickle .loads (self ._serialized_policy_def )
207+
181208
182209@PublicAPI (stability = "stable" )
183210class AutoscalingConfig (BaseModel ):
@@ -247,9 +274,6 @@ class AutoscalingConfig(BaseModel):
247274 description = "Function used to aggregate metrics across a time window." ,
248275 )
249276
250- # Cloudpickled policy definition.
251- _serialized_policy_def : bytes = PrivateAttr (default = b"" )
252-
253277 # Autoscaling policy. This policy is deployment scoped. Defaults to the request-based autoscaler.
254278 policy : AutoscalingPolicy = Field (
255279 default_factory = AutoscalingPolicy ,
@@ -298,27 +322,6 @@ def aggregation_function_valid(cls, v: Union[str, AggregationFunction]):
298322 return v
299323 return AggregationFunction (str (v ).lower ())
300324
301- def __init__ (self , ** kwargs ):
302- super ().__init__ (** kwargs )
303- self .serialize_policy ()
304-
305- def serialize_policy (self ) -> None :
306- """Serialize policy with cloudpickle.
307-
308- Import the policy if it's passed in as a string import path. Then cloudpickle
309- the policy and set `serialized_policy_def` if it's empty.
310- """
311- policy = self .policy
312- policy_name = policy .name
313-
314- if isinstance (policy_name , Callable ):
315- policy_name = f"{ policy_name .__module__ } .{ policy_name .__name__ } "
316-
317- if not self ._serialized_policy_def :
318- self ._serialized_policy_def = cloudpickle .dumps (import_attr (policy_name ))
319-
320- self .policy = AutoscalingPolicy (name = policy_name )
321-
322325 @classmethod
323326 def default (cls ):
324327 return cls (
@@ -327,10 +330,6 @@ def default(cls):
327330 max_replicas = 100 ,
328331 )
329332
330- def get_policy (self ) -> Callable :
331- """Deserialize policy from cloudpickled bytes."""
332- return cloudpickle .loads (self ._serialized_policy_def )
333-
334333 def get_upscaling_factor (self ) -> PositiveFloat :
335334 if self .upscaling_factor :
336335 return self .upscaling_factor
0 commit comments