11import numbers
22import warnings
33from bisect import bisect_right
4- from typing import Any , Callable , List , Sequence , Tuple , Union
4+ from collections .abc import Callable , Sequence
5+ from typing import Any
56
67from ignite .engine import CallableEventWithFilter , Engine , Events , EventsList
78from ignite .handlers .param_scheduler import BaseParamScheduler
@@ -32,7 +33,7 @@ def __init__(self, param_name: str, save_history: bool = False, create_new: bool
3233 def attach (
3334 self ,
3435 engine : Engine ,
35- event : Union [ str , Events , CallableEventWithFilter , EventsList ] = Events .ITERATION_COMPLETED ,
36+ event : str | Events | CallableEventWithFilter | EventsList = Events .ITERATION_COMPLETED ,
3637 ) -> None :
3738 """Attach the handler to the engine. Once the handler is attached, the ``Engine.state`` will have a new
3839 attribute with the name ``param_name``. Then the current value of the parameter can be retrieved from
@@ -73,7 +74,7 @@ def __call__(self, engine: Engine) -> None:
7374 engine .state .param_history [self .param_name ].append (value ) # type: ignore[attr-defined]
7475
7576 @classmethod
76- def simulate_values (cls , num_events : int , ** scheduler_kwargs : Any ) -> List [ List [int ]]:
77+ def simulate_values (cls , num_events : int , ** scheduler_kwargs : Any ) -> list [ list [int ]]:
7778 """Method to simulate scheduled engine state parameter values during `num_events` events.
7879
7980 Args:
@@ -185,7 +186,7 @@ def print_param():
185186
186187 def __init__ (
187188 self ,
188- lambda_obj : Callable [[int ], Union [ List [ float ], float ] ],
189+ lambda_obj : Callable [[int ], list [ float ] | float ],
189190 param_name : str ,
190191 save_history : bool = False ,
191192 create_new : bool = False ,
@@ -198,7 +199,7 @@ def __init__(
198199 self .lambda_obj = lambda_obj
199200 self ._state_attrs += ["lambda_obj" ]
200201
201- def get_param (self ) -> Union [ List [ float ], float ] :
202+ def get_param (self ) -> list [ float ] | float :
202203 return self .lambda_obj (self .event_index )
203204
204205
@@ -267,7 +268,7 @@ def print_param():
267268
268269 def __init__ (
269270 self ,
270- milestones_values : List [ Tuple [int , float ]],
271+ milestones_values : list [ tuple [int , float ]],
271272 param_name : str ,
272273 save_history : bool = False ,
273274 create_new : bool = False ,
@@ -283,8 +284,8 @@ def __init__(
283284 f"Argument milestones_values should be with at least one value, but given { milestones_values } "
284285 )
285286
286- values : List [float ] = []
287- milestones : List [int ] = []
287+ values : list [float ] = []
288+ milestones : list [int ] = []
288289 for pair in milestones_values :
289290 if not isinstance (pair , tuple ) or len (pair ) != 2 :
290291 raise ValueError ("Argument milestones_values should be a list of pairs (milestone, param_value)" )
@@ -303,7 +304,7 @@ def __init__(
303304 self ._index = 0
304305 self ._state_attrs += ["values" , "milestones" , "_index" ]
305306
306- def _get_start_end (self ) -> Tuple [int , int , float , float ]:
307+ def _get_start_end (self ) -> tuple [int , int , float , float ]:
307308 if self .milestones [0 ] > self .event_index :
308309 return self .event_index - 1 , self .event_index , self .values [0 ], self .values [0 ]
309310 elif self .milestones [- 1 ] <= self .event_index :
@@ -319,7 +320,7 @@ def _get_start_end(self) -> Tuple[int, int, float, float]:
319320 self ._index += 1
320321 return self ._get_start_end ()
321322
322- def get_param (self ) -> Union [ List [ float ], float ] :
323+ def get_param (self ) -> list [ float ] | float :
323324 start_index , end_index , start_value , end_value = self ._get_start_end ()
324325 return start_value + (end_value - start_value ) * (self .event_index - start_index ) / (end_index - start_index )
325326
@@ -387,7 +388,7 @@ def __init__(
387388 self .gamma = gamma
388389 self ._state_attrs += ["initial_value" , "gamma" ]
389390
390- def get_param (self ) -> Union [ List [ float ], float ] :
391+ def get_param (self ) -> list [ float ] | float :
391392 return self .initial_value * self .gamma ** self .event_index
392393
393394
@@ -466,7 +467,7 @@ def __init__(
466467 self .step_size = step_size
467468 self ._state_attrs += ["initial_value" , "gamma" , "step_size" ]
468469
469- def get_param (self ) -> Union [ List [ float ], float ] :
470+ def get_param (self ) -> list [ float ] | float :
470471 return self .initial_value * self .gamma ** (self .event_index // self .step_size )
471472
472473
@@ -542,7 +543,7 @@ def __init__(
542543 self ,
543544 initial_value : float ,
544545 gamma : float ,
545- milestones : List [int ],
546+ milestones : list [int ],
546547 param_name : str ,
547548 save_history : bool = False ,
548549 create_new : bool = False ,
@@ -553,5 +554,5 @@ def __init__(
553554 self .milestones = milestones
554555 self ._state_attrs += ["initial_value" , "gamma" , "milestones" ]
555556
556- def get_param (self ) -> Union [ List [ float ], float ] :
557+ def get_param (self ) -> list [ float ] | float :
557558 return self .initial_value * self .gamma ** bisect_right (self .milestones , self .event_index )
0 commit comments