Skip to content

Commit 4bb7da1

Browse files
Manimaran-techManimaran
andauthored
Replace old-style typing hints in handlers/state_param_scheduler.py (#3592)
## Description Modernize type hints in `ignite/handlers/state_param_scheduler.py` to use Python 3.10+ syntax. Related to #3481 ### Changes made: - `Union[A, B]` → `A | B` - `List[...]` → `list[...]` - `Tuple[...]` → `tuple[...]` - Removed unused imports (`List`, `Tuple`, `Union`) from `typing` ### Files changed: - `ignite/handlers/state_param_scheduler.py` All changes are limited to type annotations in function signatures and local variable declarations across the 6 scheduler classes (`StateParamScheduler`, `LambdaStateScheduler`, `PiecewiseLinearStateScheduler`, `ExpStateScheduler`, `StepStateScheduler`, `MultiStepStateScheduler`). No functional changes. --------- Co-authored-by: Manimaran <manimarantech@gmail.com>
1 parent ca3371f commit 4bb7da1

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

ignite/handlers/state_param_scheduler.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numbers
22
import warnings
33
from bisect import bisect_right
4-
from typing import Any, Callable, List, Sequence, Tuple, Union
4+
from collections.abc import Callable, Sequence
5+
from typing import Any
56

67
from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
78
from ignite.handlers.param_scheduler import BaseParamScheduler
@@ -32,7 +33,7 @@ def __init__(self, param_name: str, save_history: bool = False, create_new: bool
3233
def attach(
3334
self,
3435
engine: Engine,
35-
event: Union[str, Events, CallableEventWithFilter, EventsList] = Events.ITERATION_COMPLETED,
36+
event: str | Events | CallableEventWithFilter | EventsList = Events.ITERATION_COMPLETED,
3637
) -> None:
3738
"""Attach the handler to the engine. Once the handler is attached, the ``Engine.state`` will have a new
3839
attribute with the name ``param_name``. Then the current value of the parameter can be retrieved from
@@ -73,7 +74,7 @@ def __call__(self, engine: Engine) -> None:
7374
engine.state.param_history[self.param_name].append(value) # type: ignore[attr-defined]
7475

7576
@classmethod
76-
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]:
77+
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> list[list[int]]:
7778
"""Method to simulate scheduled engine state parameter values during `num_events` events.
7879
7980
Args:
@@ -185,7 +186,7 @@ def print_param():
185186

186187
def __init__(
187188
self,
188-
lambda_obj: Callable[[int], Union[List[float], float]],
189+
lambda_obj: Callable[[int], list[float] | float],
189190
param_name: str,
190191
save_history: bool = False,
191192
create_new: bool = False,
@@ -198,7 +199,7 @@ def __init__(
198199
self.lambda_obj = lambda_obj
199200
self._state_attrs += ["lambda_obj"]
200201

201-
def get_param(self) -> Union[List[float], float]:
202+
def get_param(self) -> list[float] | float:
202203
return self.lambda_obj(self.event_index)
203204

204205

@@ -267,7 +268,7 @@ def print_param():
267268

268269
def __init__(
269270
self,
270-
milestones_values: List[Tuple[int, float]],
271+
milestones_values: list[tuple[int, float]],
271272
param_name: str,
272273
save_history: bool = False,
273274
create_new: bool = False,
@@ -283,8 +284,8 @@ def __init__(
283284
f"Argument milestones_values should be with at least one value, but given {milestones_values}"
284285
)
285286

286-
values: List[float] = []
287-
milestones: List[int] = []
287+
values: list[float] = []
288+
milestones: list[int] = []
288289
for pair in milestones_values:
289290
if not isinstance(pair, tuple) or len(pair) != 2:
290291
raise ValueError("Argument milestones_values should be a list of pairs (milestone, param_value)")
@@ -303,7 +304,7 @@ def __init__(
303304
self._index = 0
304305
self._state_attrs += ["values", "milestones", "_index"]
305306

306-
def _get_start_end(self) -> Tuple[int, int, float, float]:
307+
def _get_start_end(self) -> tuple[int, int, float, float]:
307308
if self.milestones[0] > self.event_index:
308309
return self.event_index - 1, self.event_index, self.values[0], self.values[0]
309310
elif self.milestones[-1] <= self.event_index:
@@ -319,7 +320,7 @@ def _get_start_end(self) -> Tuple[int, int, float, float]:
319320
self._index += 1
320321
return self._get_start_end()
321322

322-
def get_param(self) -> Union[List[float], float]:
323+
def get_param(self) -> list[float] | float:
323324
start_index, end_index, start_value, end_value = self._get_start_end()
324325
return start_value + (end_value - start_value) * (self.event_index - start_index) / (end_index - start_index)
325326

@@ -387,7 +388,7 @@ def __init__(
387388
self.gamma = gamma
388389
self._state_attrs += ["initial_value", "gamma"]
389390

390-
def get_param(self) -> Union[List[float], float]:
391+
def get_param(self) -> list[float] | float:
391392
return self.initial_value * self.gamma**self.event_index
392393

393394

@@ -466,7 +467,7 @@ def __init__(
466467
self.step_size = step_size
467468
self._state_attrs += ["initial_value", "gamma", "step_size"]
468469

469-
def get_param(self) -> Union[List[float], float]:
470+
def get_param(self) -> list[float] | float:
470471
return self.initial_value * self.gamma ** (self.event_index // self.step_size)
471472

472473

@@ -542,7 +543,7 @@ def __init__(
542543
self,
543544
initial_value: float,
544545
gamma: float,
545-
milestones: List[int],
546+
milestones: list[int],
546547
param_name: str,
547548
save_history: bool = False,
548549
create_new: bool = False,
@@ -553,5 +554,5 @@ def __init__(
553554
self.milestones = milestones
554555
self._state_attrs += ["initial_value", "gamma", "milestones"]
555556

556-
def get_param(self) -> Union[List[float], float]:
557+
def get_param(self) -> list[float] | float:
557558
return self.initial_value * self.gamma ** bisect_right(self.milestones, self.event_index)

0 commit comments

Comments
 (0)