diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 6a7bbf6d..0502c8c0 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -19,6 +19,7 @@ from typing import Generic from typing import Optional from typing import TypeVar +from typing import overload import entrypoints @@ -26,8 +27,10 @@ DEBUG_ANEMOI_REGISTRY = int(os.environ.get("DEBUG_ANEMOI_REGISTRY", "0")) +T = TypeVar("T", bound=Callable[..., Any]) -class Wrapper: + +class Wrapper(Generic[T]): """A wrapper for the registry. Parameters @@ -42,7 +45,7 @@ def __init__(self, name: str, registry: "Registry"): self.name = name self.registry = registry - def __call__(self, factory: Callable) -> Callable: + def __call__(self, factory: T) -> T: """Register a factory with the registry. Parameters @@ -120,6 +123,15 @@ def lookup_kind(cls, kind: str) -> Optional["Registry"]: """ return _BY_KIND.get(kind) + @overload + def register( + self, name: str, factory: Callable[..., T], source: Any | None = None, aliases: list[str] | None = None + ) -> None: ... + @overload + def register( + self, name: str, factory: None = None, source: Any | None = None, aliases: list[str] | None = None + ) -> Wrapper: ... + def register( self, name: str, factory: Callable | None = None, source: Any | None = None, aliases: list[str] | None = None ) -> Wrapper | None: