diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index b6c3feb877..5af5b8604d 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -754,7 +754,7 @@ def string_keys_to_dict(key_string, callback): _RedisCallbacks = { **string_keys_to_dict( "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX " - "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE", + "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE HSETNX SISMEMBER", bool, ), **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float), @@ -803,6 +803,7 @@ def string_keys_to_dict(key_string, callback): "FUNCTION DELETE": bool_ok, "FUNCTION FLUSH": bool_ok, "FUNCTION RESTORE": bool_ok, + "RESTORE": bool_ok, "GEODIST": float_or_none, "HSCAN": parse_hscan, "INFO": parse_info, @@ -830,6 +831,7 @@ def string_keys_to_dict(key_string, callback): "SENTINEL SET": bool_ok, "SLOWLOG GET": parse_slowlog_get, "SLOWLOG RESET": bool_ok, + "SMISMEMBER": lambda r: list(map(bool, r)), "SORT": sort_return_tuples, "SSCAN": parse_scan, "TIME": lambda x: (int(x[0]), int(x[1])), @@ -887,6 +889,7 @@ def string_keys_to_dict(key_string, callback): "SENTINEL MASTERS": parse_sentinel_masters, "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "SMISMEMBER": lambda r: list(map(bool, r)), "STRALGO": parse_stralgo, "XINFO CONSUMERS": parse_list_of_dicts, "XINFO GROUPS": parse_list_of_dicts, @@ -932,6 +935,7 @@ def string_keys_to_dict(key_string, callback): "SENTINEL MASTERS": parse_sentinel_masters_resp3, "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, + "SMISMEMBER": lambda r: list(map(bool, r)), "STRALGO": lambda r, **options: ( {str_if_bytes(key): str_if_bytes(value) for key, value in r.items()} if isinstance(r, dict) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index abc35ad225..86d296497d 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -12,6 +12,7 @@ Dict, Iterable, List, + Literal, Mapping, MutableMapping, Optional, @@ -23,6 +24,7 @@ TypeVar, Union, cast, + overload, ) from redis._parsers.helpers import ( @@ -69,7 +71,18 @@ ResponseError, WatchError, ) -from redis.typing import ChannelT, EncodableT, KeyT +from redis.typing import ( + ChannelT, + EncodableT, + KeyT, + ResponseTypeAnyString, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeListOfAnyStrings, + ResponseTypeLPopRPop, + ResponseTypeOptionalAnyString, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, +) from redis.utils import ( SSL_AVAILABLE, _set_info_logger, @@ -87,10 +100,23 @@ VerifyMode = None VerifyFlags = None +if TYPE_CHECKING: + import sys + + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from redis.asyncio.typing import ( + RedisDecoded, + RedisEncoded, + RedisEncodedOrDecoded, + ) + PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) _ArgT = TypeVar("_ArgT", KeyT, EncodableT) -_RedisT = TypeVar("_RedisT", bound="Redis") _NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) if TYPE_CHECKING: from redis.commands.core import Script @@ -108,7 +134,18 @@ async def __call__(self, response: Any, **kwargs): ... class Redis( - AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands + AbstractRedis, + AsyncRedisModuleCommands, + AsyncCoreCommands[ + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeLPopRPop, + ], + AsyncSentinelCommands, ): """ Implementation of the Redis protocol. @@ -124,14 +161,66 @@ class Redis( response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + # Overload for decode_responses=True + @overload + @classmethod + def from_url( + cls, + url: str, + single_connection_client: bool = False, + auto_close_connection_pool: Optional[bool] = None, + *, + decode_responses: Literal[True], + **kwargs: Any, + ) -> "RedisDecoded": ... + + # Overload for decode_responses=False + @overload @classmethod def from_url( - cls: Type["Redis"], + cls, url: str, single_connection_client: bool = False, auto_close_connection_pool: Optional[bool] = None, - **kwargs, - ) -> "Redis": + *, + decode_responses: Literal[False], + **kwargs: Any, + ) -> "RedisEncoded": ... + + # Overload for decode_responses passed as bool + @overload + @classmethod + def from_url( + cls, + url: str, + single_connection_client: bool = False, + auto_close_connection_pool: Optional[bool] = None, + *, + decode_responses: bool, + **kwargs: Any, + ) -> "RedisEncodedOrDecoded": ... + + # Overload for no decode_responses passed - by default False + @overload + @classmethod + def from_url( + cls, + url: str, + single_connection_client: bool = False, + auto_close_connection_pool: Optional[bool] = None, + **kwargs: Any, + ) -> "RedisEncoded": ... + + @classmethod + def from_url( + cls: Type["RedisEncoded"] + | Type["RedisDecoded"] + | Type["RedisEncodedOrDecoded"], + url: str, + single_connection_client: bool = False, + auto_close_connection_pool: Optional[bool] = None, + **kwargs: Any, + ) -> "RedisEncoded | RedisDecoded | RedisEncodedOrDecoded": """ Return a Redis client object configured from the given URL @@ -195,9 +284,9 @@ class initializer. In the case of conflicting arguments, querystring @classmethod def from_pool( - cls: Type["Redis"], + cls, connection_pool: ConnectionPool, - ) -> "Redis": + ) -> "RedisEncoded | RedisDecoded | RedisEncodedOrDecoded": """ Return a Redis client from the given connection pool. The Redis client will take ownership of the connection pool and @@ -209,6 +298,156 @@ def from_pool( client.auto_close_connection_pool = True return client + # Overload for decode_responses=True + @overload + def __init__( + self: "RedisDecoded", + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: Literal[True], + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + retry_on_error: Optional[list[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_check_hostname: bool = True, + ssl_min_version: Optional[TLSVersion] = None, + ssl_ciphers: Optional[str] = None, + ssl_password: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = None, + lib_version: Optional[str] = None, + driver_info: Optional["DriverInfo"] = None, + username: Optional[str] = None, + auto_close_connection_pool: Optional[bool] = None, + redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + event_dispatcher: Optional[EventDispatcher] = None, + ) -> None: ... + + # Default case - decode_responses=False + @overload + def __init__( + self: "RedisEncoded", + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: Literal[False] = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + retry_on_error: Optional[list[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_check_hostname: bool = True, + ssl_min_version: Optional[TLSVersion] = None, + ssl_ciphers: Optional[str] = None, + ssl_password: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = None, + lib_version: Optional[str] = None, + driver_info: Optional["DriverInfo"] = None, + username: Optional[str] = None, + auto_close_connection_pool: Optional[bool] = None, + redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + event_dispatcher: Optional[EventDispatcher] = None, + ) -> None: ... + + # Runtime bool + @overload + def __init__( + self: "RedisEncodedOrDecoded", + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + retry_on_error: Optional[list[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_check_hostname: bool = True, + ssl_min_version: Optional[TLSVersion] = None, + ssl_ciphers: Optional[str] = None, + ssl_password: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = None, + lib_version: Optional[str] = None, + driver_info: Optional["DriverInfo"] = None, + username: Optional[str] = None, + auto_close_connection_pool: Optional[bool] = None, + redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + event_dispatcher: Optional[EventDispatcher] = None, + ) -> None: ... + @deprecated_args( args_to_warn=["retry_on_timeout"], reason="TimeoutError is included by default.", @@ -239,7 +478,7 @@ def __init__( retry: Retry = Retry( backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 ), - retry_on_error: Optional[list] = None, + retry_on_error: Optional[list[Type[Exception]]] = None, ssl: bool = False, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, @@ -422,7 +661,7 @@ def __repr__(self): def __await__(self): return self.initialize().__await__() - async def initialize(self: _RedisT) -> _RedisT: + async def initialize(self) -> "Self": if self.single_connection_client: async with self._single_conn_lock: if self.connection is None: @@ -619,7 +858,7 @@ def client(self) -> "Redis": connection_pool=self.connection_pool, single_connection_client=True ) - async def __aenter__(self: _RedisT) -> _RedisT: + async def __aenter__(self) -> "Self": """ Async context manager entry. Increments a usage counter so that the connection pool is only closed (via aclose()) when no context is using @@ -1359,7 +1598,7 @@ def __init__( self.scripts: Set[Script] = set() self.explicit_transaction = False - async def __aenter__(self: _RedisT) -> _RedisT: + async def __aenter__(self) -> "Self": return self async def __aexit__(self, exc_type, exc_value, traceback): diff --git a/redis/asyncio/typing.py b/redis/asyncio/typing.py new file mode 100644 index 0000000000..e839282024 --- /dev/null +++ b/redis/asyncio/typing.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis.asyncio.client import Redis + from redis.typing import ( + AnyStringType, + DecodedStringType, + EncodedStringType, + ListOfAnyOptionalStringsType, + ListOfAnyStringsType, + ListOfDecodedStringsType, + ListOfEncodedStringsType, + ListOfOptionalDecodedStringsType, + ListOfOptionalEncodedStringsType, + LMPopAnyReturnType, + LMPopDecodedReturnType, + LMPopEncodedReturnType, + LPopRPopAnyReturnType, + LPopRPopDecodedReturnType, + LPopRPopEncodedReturnType, + OptionalAnyStringType, + OptionalDecodedStringType, + OptionalEncodedStringType, + OptionalListOfAnyStringsType, + OptionalListOfDecodedStringsType, + OptionalListOfEncodedStringsType, + ) + + RedisEncoded = Redis[ + EncodedStringType, + OptionalEncodedStringType, + ListOfEncodedStringsType, + ListOfOptionalEncodedStringsType, + OptionalListOfEncodedStringsType, + LMPopEncodedReturnType, + LPopRPopEncodedReturnType, + ] + RedisDecoded = Redis[ + DecodedStringType, + OptionalDecodedStringType, + ListOfDecodedStringsType, + ListOfOptionalDecodedStringsType, + OptionalListOfDecodedStringsType, + LMPopDecodedReturnType, + LPopRPopDecodedReturnType, + ] + RedisEncodedOrDecoded = Redis[ + AnyStringType, + OptionalAnyStringType, + ListOfAnyStringsType, + ListOfAnyOptionalStringsType, + OptionalListOfAnyStringsType, + LMPopAnyReturnType, + LPopRPopAnyReturnType, + ] diff --git a/redis/client.py b/redis/client.py index d0e4ee7323..f3559ceeb6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -9,11 +9,13 @@ Callable, Dict, List, + Literal, Mapping, Optional, Set, Type, Union, + overload, ) from redis._parsers.encoders import Encoder @@ -61,6 +63,15 @@ MaintNotificationsConfig, ) from redis.retry import Retry +from redis.typing import ( + ResponseTypeAnyString, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeListOfAnyStrings, + ResponseTypeLPopRPop, + ResponseTypeOptionalAnyString, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, +) from redis.utils import ( _set_info_logger, deprecated_args, @@ -74,6 +85,12 @@ import OpenSSL + from redis.typing import ( + RedisDecoded, + RedisEncoded, + RedisEncodedOrDecoded, + ) + SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -112,7 +129,19 @@ class AbstractRedis: pass -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis( + RedisModuleCommands, + CoreCommands[ + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeLPopRPop, + ], + SentinelCommands, +): """ Implementation of the Redis protocol. @@ -127,8 +156,56 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): It is not safe to pass PubSub or Pipeline objects between threads. """ + # Overload for decode_responses=True + @overload + @classmethod + def from_url( + cls, + url: str, + *, + decode_responses: Literal[True], + **kwargs: Any, + ) -> "RedisDecoded": ... + + # Overload for decode_responses=False + @overload @classmethod - def from_url(cls, url: str, **kwargs) -> "Redis": + def from_url( + cls, + url: str, + *, + decode_responses: Literal[False], + **kwargs: Any, + ) -> "RedisEncoded": ... + + # Runtime bool decode_responses=some_bool_var + @overload + @classmethod + def from_url( + cls, + url: str, + *, + decode_responses: bool, + **kwargs: Any, + ) -> "RedisEncodedOrDecoded": ... + + # Overload for no decode_responses passed - by default False + @overload + @classmethod + def from_url( + cls, + url: str, + **kwargs: Any, + ) -> "RedisEncoded": ... + + @classmethod + def from_url( + cls: Type["RedisEncoded"] + | Type["RedisDecoded"] + | Type["RedisEncodedOrDecoded"], + url: str, + **kwargs: Any, + ) -> "RedisEncoded | RedisDecoded | RedisEncodedOrDecoded": """ Return a Redis client object configured from the given URL @@ -180,9 +257,9 @@ class initializer. In the case of conflicting arguments, querystring @classmethod def from_pool( - cls: Type["Redis"], + cls, connection_pool: ConnectionPool, - ) -> "Redis": + ) -> "RedisEncoded | RedisDecoded | RedisEncodedOrDecoded": """ Return a Redis client from the given connection pool. The Redis client will take ownership of the connection pool and @@ -194,6 +271,174 @@ def from_pool( client.auto_close_connection_pool = True return client + # Default case - decode_responses=False + @overload + def __init__( + self: "RedisEncoded", + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: Literal[False] = False, + retry_on_timeout: bool = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + retry_on_error: Optional[List[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_ca_certs: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = True, + ssl_password: Optional[str] = None, + ssl_validate_ocsp: bool = False, + ssl_validate_ocsp_stapled: bool = False, + ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, + ssl_ocsp_expected_cert: Optional[str] = None, + ssl_min_version: Optional["ssl.TLSVersion"] = None, + ssl_ciphers: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = None, + lib_version: Optional[str] = None, + driver_info: Optional["DriverInfo"] = None, + username: Optional[str] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + cache: Optional[CacheInterface] = None, + cache_config: Optional[CacheConfig] = None, + event_dispatcher: Optional[EventDispatcher] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, + ) -> None: ... + + # Overload for decode_responses=True + @overload + def __init__( + self: "RedisDecoded", + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: Literal[True] = ..., + retry_on_timeout: bool = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + retry_on_error: Optional[List[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_ca_certs: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = True, + ssl_password: Optional[str] = None, + ssl_validate_ocsp: bool = False, + ssl_validate_ocsp_stapled: bool = False, + ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, + ssl_ocsp_expected_cert: Optional[str] = None, + ssl_min_version: Optional["ssl.TLSVersion"] = None, + ssl_ciphers: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = None, + lib_version: Optional[str] = None, + driver_info: Optional["DriverInfo"] = None, + username: Optional[str] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + cache: Optional[CacheInterface] = None, + cache_config: Optional[CacheConfig] = None, + event_dispatcher: Optional[EventDispatcher] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, + ) -> None: ... + + # Runtime bool decode_responses=some_bool_var + @overload + def __init__( + self: "RedisEncodedOrDecoded", + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + retry_on_error: Optional[List[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_ca_certs: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = True, + ssl_password: Optional[str] = None, + ssl_validate_ocsp: bool = False, + ssl_validate_ocsp_stapled: bool = False, + ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, + ssl_ocsp_expected_cert: Optional[str] = None, + ssl_min_version: Optional["ssl.TLSVersion"] = None, + ssl_ciphers: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = None, + lib_version: Optional[str] = None, + driver_info: Optional["DriverInfo"] = None, + username: Optional[str] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + cache: Optional[CacheInterface] = None, + cache_config: Optional[CacheConfig] = None, + event_dispatcher: Optional[EventDispatcher] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, + ) -> None: ... + @deprecated_args( args_to_warn=["retry_on_timeout"], reason="TimeoutError is included by default.", diff --git a/redis/commands/core.py b/redis/commands/core.py index 525b31c99d..ba419b8646 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,5 +1,3 @@ -# from __future__ import annotations - import datetime import hashlib import warnings @@ -11,6 +9,7 @@ Awaitable, Callable, Dict, + Generic, Iterable, Iterator, List, @@ -28,19 +27,38 @@ AbsExpiryT, AnyKeyT, BitfieldOffsetT, + BooleanType, ChannelT, CommandsProtocol, ConsumerT, EncodableT, ExpiryT, FieldT, + FloatType, GroupT, + IntegerType, KeysT, KeyT, Number, + OptionalEncodedStringType, + OptionalIntegerType, PatternT, ResponseT, + ResponseTypeAnyString, + ResponseTypeBoolean, + ResponseTypeFloat, + ResponseTypeInteger, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeListOfAnyStrings, + ResponseTypeLPopRPop, + ResponseTypeOptionalAnyString, + ResponseTypeOptionalEncodedString, + ResponseTypeOptionalInteger, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeStrAlgoResult, ScriptTextT, + StrAlgoResultType, StreamIdT, TimeoutSecT, ZScoreBoundT, @@ -1571,12 +1589,25 @@ class DataPersistOptions(Enum): XX = "XX" -class BasicKeyCommands(CommandsProtocol): +class BasicKeyCommands( + CommandsProtocol, + Generic[ + ResponseTypeBoolean, + ResponseTypeFloat, + ResponseTypeInteger, + ResponseTypeOptionalEncodedString, + ResponseTypeStrAlgoResult, + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ], +): """ Redis basic key-based commands """ - def append(self, key: KeyT, value: EncodableT) -> ResponseT: + def append(self, key: KeyT, value: EncodableT) -> ResponseTypeInteger: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1592,7 +1623,7 @@ def bitcount( start: Optional[int] = None, end: Optional[int] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseTypeInteger: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1628,7 +1659,7 @@ def bitfield_ro( encoding: str, offset: BitfieldOffsetT, items: Optional[list] = None, - ) -> ResponseT: + ) -> ResponseTypeInteger: """ Return an array of the specified bitfield values where the first value is found using ``encoding`` and ``offset`` @@ -1645,7 +1676,7 @@ def bitfield_ro( params.extend(["GET", encoding, offset]) return self.execute_command("BITFIELD_RO", *params, keys=[key]) - def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: + def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseTypeInteger: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. @@ -1661,7 +1692,7 @@ def bitpos( start: Optional[int] = None, end: Optional[int] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseTypeInteger: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1691,7 +1722,7 @@ def copy( destination: str, destination_db: Optional[str] = None, replace: bool = False, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1711,7 +1742,7 @@ def copy( params.append("REPLACE") return self.execute_command("COPY", *params) - def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: + def decrby(self, name: KeyT, amount: int = 1) -> ResponseTypeInteger: """ Decrements the value of ``key`` by ``amount``. If no key exists, the value will be initialized as 0 - ``amount`` @@ -1722,7 +1753,7 @@ def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: decr = decrby - def delete(self, *names: KeyT) -> ResponseT: + def delete(self, *names: KeyT) -> ResponseTypeInteger: """ Delete one or more keys specified by ``names`` """ @@ -1739,7 +1770,7 @@ def delex( ifne: Optional[Union[bytes, str]] = None, ifdeq: Optional[str] = None, # hex digest ifdne: Optional[str] = None, # hex digest - ) -> int: + ) -> ResponseTypeInteger: """ Conditionally removes the specified key. @@ -1782,7 +1813,7 @@ def delex( return self.execute_command(*pieces) - def dump(self, name: KeyT) -> ResponseT: + def dump(self, name: KeyT) -> ResponseTypeOptionalEncodedString: """ Return a serialized version of the value stored at the specified key. If key does not exist a nil bulk reply is returned. @@ -1795,7 +1826,7 @@ def dump(self, name: KeyT) -> ResponseT: options[NEVER_DECODE] = [] return self.execute_command("DUMP", name, **options) - def exists(self, *names: KeyT) -> ResponseT: + def exists(self, *names: KeyT) -> ResponseTypeInteger: """ Returns the number of ``names`` that exist @@ -1813,7 +1844,7 @@ def expire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Set an expire flag on key ``name`` for ``time`` seconds with given ``option``. ``time`` can be represented by an integer or a Python timedelta @@ -1850,7 +1881,7 @@ def expireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Set an expire flag on key ``name`` with given ``option``. ``when`` can be represented as an integer indicating unix time or a Python @@ -1879,7 +1910,7 @@ def expireat( return self.execute_command("EXPIREAT", name, when, *exp_option) - def expiretime(self, key: str) -> int: + def expiretime(self, key: str) -> ResponseTypeInteger: """ Returns the absolute Unix timestamp (since January 1, 1970) in seconds at which the given key will expire. @@ -1889,7 +1920,7 @@ def expiretime(self, key: str) -> int: return self.execute_command("EXPIRETIME", key) @experimental_method() - def digest(self, name: KeyT) -> Optional[str]: + def digest(self, name: KeyT) -> ResponseTypeOptionalAnyString: """ Return the digest of the value stored at the specified key. @@ -1914,7 +1945,7 @@ def digest(self, name: KeyT) -> Optional[str]: # Bulk string response is already handled (bytes/str based on decode_responses) return self.execute_command("DIGEST", name) - def get(self, name: KeyT) -> ResponseT: + def get(self, name: KeyT) -> ResponseTypeOptionalAnyString: """ Return the value at key ``name``, or None if the key doesn't exist @@ -1922,7 +1953,7 @@ def get(self, name: KeyT) -> ResponseT: """ return self.execute_command("GET", name, keys=[name]) - def getdel(self, name: KeyT) -> ResponseT: + def getdel(self, name: KeyT) -> ResponseTypeOptionalAnyString: """ Get the value at key ``name`` and delete the key. This command is similar to GET, except for the fact that it also deletes @@ -1941,7 +1972,7 @@ def getex( exat: Optional[AbsExpiryT] = None, pxat: Optional[AbsExpiryT] = None, persist: bool = False, - ) -> ResponseT: + ) -> ResponseTypeOptionalAnyString: """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1975,7 +2006,7 @@ def getex( return self.execute_command("GETEX", name, *exp_options) - def __getitem__(self, name: KeyT): + def __getitem__(self, name: KeyT) -> ResponseTypeAnyString: """ Return the value at key ``name``, raises a KeyError if the key doesn't exist. @@ -1985,7 +2016,7 @@ def __getitem__(self, name: KeyT): return value raise KeyError(name) - def getbit(self, name: KeyT, offset: int) -> ResponseT: + def getbit(self, name: KeyT, offset: int) -> ResponseTypeInteger: """ Returns an integer indicating the value of ``offset`` in ``name`` @@ -1993,7 +2024,7 @@ def getbit(self, name: KeyT, offset: int) -> ResponseT: """ return self.execute_command("GETBIT", name, offset, keys=[name]) - def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: + def getrange(self, key: KeyT, start: int, end: int) -> ResponseTypeAnyString: """ Returns the substring of the string value stored at ``key``, determined by the offsets ``start`` and ``end`` (both are inclusive) @@ -2002,7 +2033,7 @@ def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ return self.execute_command("GETRANGE", key, start, end, keys=[key]) - def getset(self, name: KeyT, value: EncodableT) -> ResponseT: + def getset(self, name: KeyT, value: EncodableT) -> ResponseTypeOptionalAnyString: """ Sets the value at key ``name`` to ``value`` and returns the old value at key ``name`` atomically. @@ -2014,7 +2045,7 @@ def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ return self.execute_command("GETSET", name, value) - def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: + def incrby(self, name: KeyT, amount: int = 1) -> ResponseTypeInteger: """ Increments the value of ``key`` by ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -2025,7 +2056,7 @@ def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: incr = incrby - def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: + def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseTypeFloat: """ Increments the value at key ``name`` by floating ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -2034,7 +2065,7 @@ def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: """ return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseTypeListOfAnyStrings: """ Returns a list of keys matching ``pattern`` @@ -2044,7 +2075,7 @@ def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: def lmove( self, first_list: str, second_list: str, src: str = "LEFT", dest: str = "RIGHT" - ) -> ResponseT: + ) -> ResponseTypeAnyString: """ Atomically returns and removes the first/last element of a list, pushing it as the first/last element on the destination list. @@ -2062,7 +2093,7 @@ def blmove( timeout: int, src: str = "LEFT", dest: str = "RIGHT", - ) -> ResponseT: + ) -> ResponseTypeAnyString: """ Blocking version of lmove. @@ -2071,7 +2102,9 @@ def blmove( params = [first_list, second_list, src, dest, timeout] return self.execute_command("BLMOVE", *params) - def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: + def mget( + self, keys: KeysT, *args: EncodableT + ) -> ResponseTypeListOfAnyOptionalStrings: """ Returns a list of values ordered identically to ``keys`` @@ -2090,7 +2123,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options["keys"] = args return self.execute_command("MGET", *args, **options) - def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseTypeBoolean: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -2116,7 +2149,7 @@ def msetex( exat: Optional[AbsExpiryT] = None, pxat: Optional[AbsExpiryT] = None, keepttl: bool = False, - ) -> Union[Awaitable[int], int]: + ) -> ResponseTypeInteger: """ Sets key/values based on the provided ``mapping`` items. @@ -2172,7 +2205,7 @@ def msetex( return self.execute_command(*pieces, *exp_options) - def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseTypeBoolean: """ Sets key/values based on a mapping if none of the keys are already set. Mapping is a dictionary of key/value pairs. Both keys and values @@ -2190,7 +2223,7 @@ def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: items.extend(pair) return self.execute_command("MSETNX", *items) - def move(self, name: KeyT, db: int) -> ResponseT: + def move(self, name: KeyT, db: int) -> ResponseTypeBoolean: """ Moves the key ``name`` to a different Redis database ``db`` @@ -2198,7 +2231,7 @@ def move(self, name: KeyT, db: int) -> ResponseT: """ return self.execute_command("MOVE", name, db) - def persist(self, name: KeyT) -> ResponseT: + def persist(self, name: KeyT) -> ResponseTypeBoolean: """ Removes an expiration on ``name`` @@ -2214,7 +2247,7 @@ def pexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Set an expire flag on key ``name`` for ``time`` milliseconds with given ``option``. ``time`` can be represented by an @@ -2250,7 +2283,7 @@ def pexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Set an expire flag on key ``name`` with given ``option``. ``when`` can be represented as an integer representing unix time in @@ -2277,7 +2310,7 @@ def pexpireat( exp_option.append("LT") return self.execute_command("PEXPIREAT", name, when, *exp_option) - def pexpiretime(self, key: str) -> int: + def pexpiretime(self, key: str) -> ResponseTypeInteger: """ Returns the absolute Unix timestamp (since January 1, 1970) in milliseconds at which the given key will expire. @@ -2286,7 +2319,9 @@ def pexpiretime(self, key: str) -> int: """ return self.execute_command("PEXPIRETIME", key) - def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): + def psetex( + self, name: KeyT, time_ms: ExpiryT, value: EncodableT + ) -> ResponseTypeBoolean: """ Set the value of key ``name`` to ``value`` that expires in ``time_ms`` milliseconds. ``time_ms`` can be represented by an integer or a Python @@ -2298,7 +2333,7 @@ def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command("PSETEX", name, time_ms, value) - def pttl(self, name: KeyT) -> ResponseT: + def pttl(self, name: KeyT) -> ResponseTypeInteger: """ Returns the number of milliseconds until the key ``name`` will expire @@ -2307,8 +2342,8 @@ def pttl(self, name: KeyT) -> ResponseT: return self.execute_command("PTTL", name) def hrandfield( - self, key: str, count: Optional[int] = None, withvalues: bool = False - ) -> ResponseT: + self, key: str, count: int = None, withvalues: bool = False + ) -> ResponseTypeListOfAnyStrings: """ Return a random field from the hash value stored at key. @@ -2330,7 +2365,7 @@ def hrandfield( return self.execute_command("HRANDFIELD", key, *params) - def randomkey(self, **kwargs) -> ResponseT: + def randomkey(self, **kwargs) -> ResponseTypeAnyString: """ Returns the name of a random key @@ -2338,7 +2373,7 @@ def randomkey(self, **kwargs) -> ResponseT: """ return self.execute_command("RANDOMKEY", **kwargs) - def rename(self, src: KeyT, dst: KeyT) -> ResponseT: + def rename(self, src: KeyT, dst: KeyT) -> ResponseTypeBoolean: """ Rename key ``src`` to ``dst`` @@ -2346,7 +2381,7 @@ def rename(self, src: KeyT, dst: KeyT) -> ResponseT: """ return self.execute_command("RENAME", src, dst) - def renamenx(self, src: KeyT, dst: KeyT): + def renamenx(self, src: KeyT, dst: KeyT) -> ResponseTypeBoolean: """ Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist @@ -2363,7 +2398,7 @@ def restore( absttl: bool = False, idletime: Optional[int] = None, frequency: Optional[int] = None, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Create a key using the provided serialized value, previously obtained using DUMP. @@ -2421,7 +2456,7 @@ def set( ifne: Optional[Union[bytes, str]] = None, ifdeq: Optional[str] = None, # hex digest of current value ifdne: Optional[str] = None, # hex digest of current value - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Set the value at key ``name`` to ``value`` @@ -2520,7 +2555,7 @@ def set( def __setitem__(self, name: KeyT, value: EncodableT): self.set(name, value) - def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: + def setbit(self, name: KeyT, offset: int, value: int) -> ResponseTypeInteger: """ Flag the ``offset`` in ``name`` as ``value``. Returns an integer indicating the previous value of ``offset``. @@ -2530,7 +2565,9 @@ def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: value = value and 1 or 0 return self.execute_command("SETBIT", name, offset, value) - def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: + def setex( + self, name: KeyT, time: ExpiryT, value: EncodableT + ) -> ResponseTypeBoolean: """ Set the value of key ``name`` to ``value`` that expires in ``time`` seconds. ``time`` can be represented by an integer or a Python @@ -2542,7 +2579,7 @@ def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: time = int(time.total_seconds()) return self.execute_command("SETEX", name, time, value) - def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: + def setnx(self, name: KeyT, value: EncodableT) -> ResponseTypeBoolean: """ Set the value of key ``name`` to ``value`` if key doesn't exist @@ -2550,7 +2587,12 @@ def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: """ return self.execute_command("SETNX", name, value) - def setrange(self, name: KeyT, offset: int, value: EncodableT) -> ResponseT: + def setrange( + self, + name: KeyT, + offset: int, + value: EncodableT, + ) -> ResponseTypeInteger: """ Overwrite bytes in the value of ``name`` starting at ``offset`` with ``value``. If ``offset`` plus the length of ``value`` exceeds the @@ -2576,7 +2618,7 @@ def stralgo( minmatchlen: Optional[int] = None, withmatchlen: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseTypeStrAlgoResult: """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -2629,7 +2671,7 @@ def stralgo( **kwargs, ) - def strlen(self, name: KeyT) -> ResponseT: + def strlen(self, name: KeyT) -> ResponseTypeInteger: """ Return the number of bytes stored in the value of ``name`` @@ -2637,14 +2679,14 @@ def strlen(self, name: KeyT) -> ResponseT: """ return self.execute_command("STRLEN", name, keys=[name]) - def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: + def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseTypeAnyString: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ return self.execute_command("SUBSTR", name, start, end, keys=[name]) - def touch(self, *args: KeyT) -> ResponseT: + def touch(self, *args: KeyT) -> ResponseTypeInteger: """ Alters the last access time of a key(s) ``*args``. A key is ignored if it does not exist. @@ -2653,7 +2695,7 @@ def touch(self, *args: KeyT) -> ResponseT: """ return self.execute_command("TOUCH", *args) - def ttl(self, name: KeyT) -> ResponseT: + def ttl(self, name: KeyT) -> ResponseTypeInteger: """ Returns the number of seconds until the key ``name`` will expire @@ -2661,7 +2703,7 @@ def ttl(self, name: KeyT) -> ResponseT: """ return self.execute_command("TTL", name) - def type(self, name: KeyT) -> ResponseT: + def type(self, name: KeyT) -> ResponseTypeAnyString: """ Returns the type of key ``name`` @@ -2685,7 +2727,7 @@ def unwatch(self) -> None: """ warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) - def unlink(self, *names: KeyT) -> ResponseT: + def unlink(self, *names: KeyT) -> ResponseTypeInteger: """ Unlink one or more keys specified by ``names`` @@ -2701,7 +2743,7 @@ def lcs( idx: Optional[bool] = False, minmatchlen: Optional[int] = 0, withmatchlen: Optional[bool] = False, - ) -> Union[str, int, list]: + ) -> ResponseTypeAnyString: """ Find the longest common subsequence between ``key1`` and ``key2``. If ``len`` is true the length of the match will will be returned. @@ -2723,7 +2765,19 @@ def lcs( return self.execute_command("LCS", *pieces, keys=[key1, key2]) -class AsyncBasicKeyCommands(BasicKeyCommands): +class AsyncBasicKeyCommands( + BasicKeyCommands[ + ResponseTypeBoolean, + ResponseTypeFloat, + ResponseTypeInteger, + ResponseTypeOptionalEncodedString, + ResponseTypeStrAlgoResult, + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ], +): def __delitem__(self, name: KeyT): raise TypeError("Async Redis client does not support class deletion") @@ -2743,7 +2797,20 @@ async def unwatch(self) -> None: return super().unwatch() -class ListCommands(CommandsProtocol): +class ListCommands( + CommandsProtocol, + Generic[ + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeOptionalInteger, + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeLPopRPop, + ], +): """ Redis commands for List data type. see: https://redis.io/topics/data-types#lists @@ -2751,7 +2818,7 @@ class ListCommands(CommandsProtocol): def blpop( self, keys: List, timeout: Optional[Number] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseTypeOptionalListOfAnyStrings: """ LPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2772,7 +2839,7 @@ def blpop( def brpop( self, keys: List, timeout: Optional[Number] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseTypeOptionalListOfAnyStrings: """ RPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2793,7 +2860,7 @@ def brpop( def brpoplpush( self, src: KeyT, dst: KeyT, timeout: Optional[Number] = 0 - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + ) -> ResponseTypeOptionalAnyString: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` and then return it. @@ -2815,7 +2882,7 @@ def blmpop( *args: str, direction: str, count: Optional[int] = 1, - ) -> Optional[list]: + ) -> ResponseTypeOptionalLMPop: """ Pop ``count`` values (default 1) from first non-empty in the list of provided key names. @@ -2835,7 +2902,7 @@ def lmpop( *args: str, direction: str, count: Optional[int] = 1, - ) -> Union[Awaitable[list], list]: + ) -> ResponseTypeOptionalLMPop: """ Pop ``count`` values (default 1) first non-empty list key from the list of args provided key names. @@ -2848,9 +2915,7 @@ def lmpop( return self.execute_command("LMPOP", *cmd_args) - def lindex( - self, name: KeyT, index: int - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + def lindex(self, name: KeyT, index: int) -> ResponseTypeOptionalAnyString: """ Return the item from list ``name`` at position ``index`` @@ -2863,7 +2928,7 @@ def lindex( def linsert( self, name: KeyT, where: str, refvalue: str, value: str - ) -> Union[Awaitable[int], int]: + ) -> ResponseTypeInteger: """ Insert ``value`` in list ``name`` either immediately before or after [``where``] ``refvalue`` @@ -2875,7 +2940,7 @@ def linsert( """ return self.execute_command("LINSERT", name, where, refvalue, value) - def llen(self, name: KeyT) -> Union[Awaitable[int], int]: + def llen(self, name: KeyT) -> ResponseTypeInteger: """ Return the length of the list ``name`` @@ -2887,7 +2952,7 @@ def lpop( self, name: KeyT, count: Optional[int] = None, - ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + ) -> ResponseTypeLPopRPop: """ Removes and returns the first elements of the list ``name``. @@ -2902,7 +2967,7 @@ def lpop( else: return self.execute_command("LPOP", name) - def lpush(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: + def lpush(self, name: KeyT, *values: FieldT) -> ResponseTypeInteger: """ Push ``values`` onto the head of the list ``name`` @@ -2910,7 +2975,7 @@ def lpush(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSH", name, *values) - def lpushx(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: + def lpushx(self, name: KeyT, *values: FieldT) -> ResponseTypeInteger: """ Push ``value`` onto the head of the list ``name`` if ``name`` exists @@ -2918,7 +2983,7 @@ def lpushx(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSHX", name, *values) - def lrange(self, name: KeyT, start: int, end: int) -> Union[Awaitable[list], list]: + def lrange(self, name: KeyT, start: int, end: int) -> ResponseTypeListOfAnyStrings: """ Return a slice of the list ``name`` between position ``start`` and ``end`` @@ -2930,7 +2995,7 @@ def lrange(self, name: KeyT, start: int, end: int) -> Union[Awaitable[list], lis """ return self.execute_command("LRANGE", name, start, end, keys=[name]) - def lrem(self, name: KeyT, count: int, value: str) -> Union[Awaitable[int], int]: + def lrem(self, name: KeyT, count: int, value: str) -> ResponseTypeInteger: """ Remove the first ``count`` occurrences of elements equal to ``value`` from the list stored at ``name``. @@ -2944,7 +3009,7 @@ def lrem(self, name: KeyT, count: int, value: str) -> Union[Awaitable[int], int] """ return self.execute_command("LREM", name, count, value) - def lset(self, name: KeyT, index: int, value: str) -> Union[Awaitable[str], str]: + def lset(self, name: KeyT, index: int, value: str) -> ResponseTypeBoolean: """ Set element at ``index`` of list ``name`` to ``value`` @@ -2952,7 +3017,7 @@ def lset(self, name: KeyT, index: int, value: str) -> Union[Awaitable[str], str] """ return self.execute_command("LSET", name, index, value) - def ltrim(self, name: KeyT, start: int, end: int) -> Union[Awaitable[str], str]: + def ltrim(self, name: KeyT, start: int, end: int) -> ResponseTypeBoolean: """ Trim the list ``name``, removing all values not within the slice between ``start`` and ``end`` @@ -2968,7 +3033,7 @@ def rpop( self, name: KeyT, count: Optional[int] = None, - ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + ) -> ResponseTypeLPopRPop: """ Removes and returns the last elements of the list ``name``. @@ -2983,7 +3048,7 @@ def rpop( else: return self.execute_command("RPOP", name) - def rpoplpush(self, src: KeyT, dst: KeyT) -> Union[Awaitable[str], str]: + def rpoplpush(self, src: KeyT, dst: KeyT) -> ResponseTypeOptionalAnyString: """ RPOP a value off of the ``src`` list and atomically LPUSH it on to the ``dst`` list. Returns the value. @@ -2992,7 +3057,7 @@ def rpoplpush(self, src: KeyT, dst: KeyT) -> Union[Awaitable[str], str]: """ return self.execute_command("RPOPLPUSH", src, dst) - def rpush(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: + def rpush(self, name: KeyT, *values: FieldT) -> ResponseTypeInteger: """ Push ``values`` onto the tail of the list ``name`` @@ -3000,7 +3065,7 @@ def rpush(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("RPUSH", name, *values) - def rpushx(self, name: KeyT, *values: str) -> Union[Awaitable[int], int]: + def rpushx(self, name: KeyT, *values: str) -> ResponseTypeInteger: """ Push ``value`` onto the tail of the list ``name`` if ``name`` exists @@ -3015,7 +3080,7 @@ def lpos( rank: Optional[int] = None, count: Optional[int] = None, maxlen: Optional[int] = None, - ) -> Union[str, List, None]: + ) -> ResponseTypeOptionalInteger: """ Get position of ``value`` within the list ``name`` @@ -3064,7 +3129,7 @@ def sort( alpha: bool = False, store: Optional[str] = None, groups: Optional[bool] = False, - ) -> Union[List, int]: + ) -> ResponseTypeListOfAnyStrings: """ Sort and return the list, set or sorted set at ``name``. @@ -3135,7 +3200,7 @@ def sort_ro( get: Optional[List[str]] = None, desc: bool = False, alpha: bool = False, - ) -> list: + ) -> ResponseTypeListOfAnyStrings: """ Returns the elements contained in the list, set or sorted set at key. (read-only variant of the SORT command) @@ -6923,11 +6988,31 @@ def function_stats(self) -> Union[Awaitable[List], List]: class DataAccessCommands( - BasicKeyCommands, + BasicKeyCommands[ + BooleanType, + FloatType, + IntegerType, + OptionalEncodedStringType, + StrAlgoResultType, + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ], HyperlogCommands, HashCommands, GeoCommands, - ListCommands, + ListCommands[ + BooleanType, + IntegerType, + OptionalIntegerType, + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeLPopRPop, + ], ScanCommands, SetCommands, StreamCommands, @@ -6940,11 +7025,31 @@ class DataAccessCommands( class AsyncDataAccessCommands( - AsyncBasicKeyCommands, + AsyncBasicKeyCommands[ + Awaitable[BooleanType], + Awaitable[FloatType], + Awaitable[IntegerType], + Awaitable[OptionalEncodedStringType], + Awaitable[StrAlgoResultType], + Awaitable[ResponseTypeAnyString], + Awaitable[ResponseTypeOptionalAnyString], + Awaitable[ResponseTypeListOfAnyStrings], + Awaitable[ResponseTypeListOfAnyOptionalStrings], + ], AsyncHyperlogCommands, AsyncHashCommands, AsyncGeoCommands, - AsyncListCommands, + AsyncListCommands[ + Awaitable[BooleanType], + Awaitable[IntegerType], + Awaitable[OptionalIntegerType], + Awaitable[ResponseTypeAnyString], + Awaitable[ResponseTypeOptionalAnyString], + Awaitable[ResponseTypeListOfAnyStrings], + Awaitable[ResponseTypeOptionalListOfAnyStrings], + Awaitable[ResponseTypeOptionalLMPop], + Awaitable[ResponseTypeLPopRPop], + ], AsyncScanCommands, AsyncSetCommands, AsyncStreamCommands, @@ -6959,7 +7064,15 @@ class AsyncDataAccessCommands( class CoreCommands( ACLCommands, ClusterCommands, - DataAccessCommands, + DataAccessCommands[ + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeLPopRPop, + ], ManagementCommands, ModuleCommands, PubSubCommands, @@ -6975,7 +7088,15 @@ class CoreCommands( class AsyncCoreCommands( AsyncACLCommands, AsyncClusterCommands, - AsyncDataAccessCommands, + AsyncDataAccessCommands[ + ResponseTypeAnyString, + ResponseTypeOptionalAnyString, + ResponseTypeListOfAnyStrings, + ResponseTypeListOfAnyOptionalStrings, + ResponseTypeOptionalListOfAnyStrings, + ResponseTypeOptionalLMPop, + ResponseTypeLPopRPop, + ], AsyncManagementCommands, AsyncModuleCommands, AsyncPubSubCommands, diff --git a/redis/typing.py b/redis/typing.py index ede5385e2d..094229937e 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,5 +1,3 @@ -# from __future__ import annotations - from datetime import datetime, timedelta from typing import ( TYPE_CHECKING, @@ -7,8 +5,10 @@ Awaitable, Iterable, Mapping, + Optional, Protocol, Type, + TypedDict, TypeVar, Union, ) @@ -48,6 +48,169 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] +BooleanType = bool +IntegerType = int +OptionalIntegerType = Optional[IntegerType] +FloatType = float + +DecodedStringType = str +EncodedStringType = bytes +AnyStringType = Union[DecodedStringType, EncodedStringType] +OptionalDecodedStringType = Optional[DecodedStringType] +OptionalEncodedStringType = Optional[EncodedStringType] +OptionalAnyStringType = Union[OptionalDecodedStringType, OptionalEncodedStringType] + +ListOfDecodedStringsType = list[DecodedStringType] +ListOfEncodedStringsType = list[EncodedStringType] +OptionalListOfDecodedStringsType = Optional[ListOfDecodedStringsType] +OptionalListOfEncodedStringsType = Optional[ListOfEncodedStringsType] +ListOfAnyStringsType = Union[ListOfDecodedStringsType, ListOfEncodedStringsType] +OptionalListOfAnyStringsType = Union[ + OptionalListOfDecodedStringsType, + OptionalListOfEncodedStringsType, +] + +ListOfOptionalDecodedStringsType = list[OptionalDecodedStringType] +ListOfOptionalEncodedStringsType = list[OptionalEncodedStringType] +ListOfAnyOptionalStringsType = Union[ + ListOfOptionalDecodedStringsType, + ListOfOptionalEncodedStringsType, +] + +LPopRPopDecodedReturnType = Union[ + DecodedStringType, # Single value when count not specified + ListOfDecodedStringsType, # List when count is specified + None, # None when list is empty +] +LPopRPopEncodedReturnType = Union[ + EncodedStringType, # Single value when count not specified + ListOfEncodedStringsType, # List when count is specified + None, # None when list is empty +] + +# lpop / rpop can return single value, list, or None +LPopRPopAnyReturnType = Union[ + LPopRPopDecodedReturnType, + LPopRPopEncodedReturnType, +] + +# blmpop / lmpop return types +# Returns a list containing [key_name, [values...]] or None +# PyCharm doesn't like use of Optional here +LMPopDecodedReturnType = Union[ + list[ + Union[ + DecodedStringType, # key_name + list[DecodedStringType], # [values, ...] + ], + ], + None, # or None +] +LMPopEncodedReturnType = Union[ + list[ + Union[ + EncodedStringType, # key_name + list[EncodedStringType], # [values, ...] + ], + ], + None, # or None +] +LMPopAnyReturnType = Union[LMPopDecodedReturnType, LMPopEncodedReturnType] + +# STRALGO return types +# Represents the ranges, e.g., (4, 7) +PositionRange = tuple[IntegerType, IntegerType] + +# Represents a match entry when WITHMATCHLEN is False +# Example: [(4, 7), (5, 8)] +MatchSequence = list[PositionRange] + +# Represents a match entry when WITHMATCHLEN is True +# Example: [4, (4, 7), (5, 8)] <-- First item is int (len), rest are ranges +MatchSequenceWithLen = list[Union[IntegerType, PositionRange]] + + +class StrAlgoIdxResponse(TypedDict): + """ + Return type when IDX=True (No WITHMATCHLEN) + Example: {'matches': [[(4, 7), (5, 8)]], 'len': 6} + """ + + matches: list[MatchSequence] + len: IntegerType + + +class StrAlgoIdxWithLenResponse(TypedDict): + """ + Return type when IDX=True AND WITHMATCHLEN=True + Example: {'matches': [[4, (4, 7), (5, 8)]], 'len': 6} + """ + + matches: list[MatchSequenceWithLen] + len: IntegerType + + +StrAlgoResultType = Union[ + DecodedStringType, # str (default) + IntegerType, # int (LEN argument) + StrAlgoIdxResponse, # dict (IDX argument) + StrAlgoIdxWithLenResponse, # dict (IDX + WITHMATCHLEN argument) +] + + +ResponseTypeBoolean = TypeVar( + "ResponseTypeBoolean", + bound=Awaitable[BooleanType] | BooleanType, +) +ResponseTypeInteger = TypeVar( + "ResponseTypeInteger", + bound=Awaitable[IntegerType] | IntegerType, +) +ResponseTypeFloat = TypeVar( + "ResponseTypeFloat", + bound=Awaitable[FloatType] | FloatType, +) +ResponseTypeAnyString = TypeVar( + "ResponseTypeAnyString", + bound=Awaitable[AnyStringType] | AnyStringType, +) +ResponseTypeOptionalEncodedString = TypeVar( + "ResponseTypeOptionalEncodedString", + bound=Awaitable[OptionalEncodedStringType] | OptionalEncodedStringType, +) +ResponseTypeOptionalAnyString = TypeVar( + "ResponseTypeOptionalAnyString", + bound=Awaitable[OptionalAnyStringType] | OptionalAnyStringType, +) +ResponseTypeListOfAnyStrings = TypeVar( + "ResponseTypeListOfAnyStrings", + bound=Awaitable[ListOfAnyStringsType] | ListOfAnyStringsType, +) +ResponseTypeListOfAnyOptionalStrings = TypeVar( + "ResponseTypeListOfAnyOptionalStrings", + bound=Awaitable[ListOfAnyOptionalStringsType] | ListOfAnyOptionalStringsType, +) +ResponseTypeOptionalInteger = TypeVar( + "ResponseTypeOptionalInteger", + bound=Awaitable[OptionalIntegerType] | OptionalIntegerType, +) +ResponseTypeOptionalListOfAnyStrings = TypeVar( + "ResponseTypeOptionalListOfAnyStrings", + bound=Awaitable[OptionalListOfAnyStringsType] | OptionalListOfAnyStringsType, +) +ResponseTypeLPopRPop = TypeVar( + "ResponseTypeLPopRPop", + bound=Awaitable[LPopRPopAnyReturnType] | LPopRPopAnyReturnType, +) +ResponseTypeOptionalLMPop = TypeVar( + "ResponseTypeOptionalLMPop", + bound=Awaitable[LMPopAnyReturnType] | LMPopAnyReturnType, +) +ResponseTypeStrAlgoResult = TypeVar( + "ResponseTypeStrAlgoResult", + bound=Awaitable[StrAlgoResultType] | StrAlgoResultType, +) + class CommandsProtocol(Protocol): def execute_command(self, *args, **options) -> ResponseT: ... @@ -55,3 +218,35 @@ def execute_command(self, *args, **options) -> ResponseT: ... class ClusterCommandsProtocol(CommandsProtocol): encoder: "Encoder" + + +if TYPE_CHECKING: + from redis.client import Redis + + RedisEncoded = Redis[ + EncodedStringType, + OptionalEncodedStringType, + ListOfEncodedStringsType, + ListOfOptionalEncodedStringsType, + OptionalListOfEncodedStringsType, + LMPopEncodedReturnType, + LPopRPopEncodedReturnType, + ] + RedisDecoded = Redis[ + DecodedStringType, + OptionalDecodedStringType, + ListOfDecodedStringsType, + ListOfOptionalDecodedStringsType, + OptionalListOfDecodedStringsType, + LMPopDecodedReturnType, + LPopRPopDecodedReturnType, + ] + RedisEncodedOrDecoded = Redis[ + AnyStringType, + OptionalAnyStringType, + ListOfAnyStringsType, + ListOfAnyOptionalStringsType, + OptionalListOfAnyStringsType, + LMPopAnyReturnType, + LPopRPopAnyReturnType, + ] diff --git a/whitelist.py b/whitelist.py index 29cd529e4d..e1f1a4f8d7 100644 --- a/whitelist.py +++ b/whitelist.py @@ -16,3 +16,9 @@ AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) TargetNodesT # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46) +RedisDecoded # unused import (redis/client.py:86) - used in TYPE_CHECKING for type annotations +RedisEncoded # unused import (redis/client.py:87) - used in TYPE_CHECKING for type annotations +RedisEncodedOrDecoded # unused import (redis/client.py:88) - used in TYPE_CHECKING for type annotations +RedisDecoded # unused import (redis/asyncio/client.py:109) - used in TYPE_CHECKING for type annotations +RedisEncoded # unused import (redis/asyncio/client.py:110) - used in TYPE_CHECKING for type annotations +RedisEncodedOrDecoded # unused import (redis/asyncio/client.py:111) - used in TYPE_CHECKING for type annotations