diff --git a/openlibrary/plugins/openlibrary/connection.py b/openlibrary/plugins/openlibrary/connection.py index 4c9deb15643..79204258f14 100644 --- a/openlibrary/plugins/openlibrary/connection.py +++ b/openlibrary/plugins/openlibrary/connection.py @@ -3,6 +3,7 @@ import datetime import json import logging +from typing import cast import web @@ -17,7 +18,7 @@ class ConnectionMiddleware: response_type = "json" - def __init__(self, conn): + def __init__(self, conn: 'client.Connection | ConnectionMiddleware'): self.conn = conn def get_auth_token(self): @@ -545,7 +546,7 @@ class HybridConnection(client.Connection): down the overhead of http calls present in case of remote connections. """ - def __init__(self, reader, writer): + def __init__(self, reader: client.Connection, writer: client.Connection): client.Connection.__init__(self) self.reader = reader self.writer = writer @@ -579,40 +580,39 @@ def _update_infobase_config(): server.update_config(config.infobase) -def create_local_connection(): - _update_infobase_config() - return client.connect(type='local', **web.config.db_parameters) - - -def create_remote_connection(): - return client.connect(type='remote', base_url=config.infobase_server) - - -def create_hybrid_connection(): - local = create_local_connection() - remote = create_remote_connection() - return HybridConnection(local, remote) +def OLConnection() -> client.Connection | ConnectionMiddleware: + """Create a connection to Open Library infobase server.""" + def create_local_connection(): + _update_infobase_config() + return cast( + client.LocalConnection, + client.connect(type='local', **web.config.db_parameters), + ) -def OLConnection(): - """Create a connection to Open Library infobase server.""" + def create_remote_connection(): + return cast( + client.RemoteConnection, + client.connect(type='remote', base_url=config.infobase_server), + ) - def create_connection(): - if config.get("connection_type") == "hybrid": - return create_hybrid_connection() - elif config.get('infobase_server'): - return create_remote_connection() - elif config.get("infobase", {}).get('db_parameters'): - return create_local_connection() - else: - raise Exception("db_parameters are not specified in the configuration") + conn: client.Connection | ConnectionMiddleware + if config.get("connection_type") == "hybrid": + conn = HybridConnection( + reader=create_local_connection(), + writer=create_remote_connection(), + ) + elif config.get('infobase_server'): + conn = create_remote_connection() + elif config.get("infobase", {}).get('db_parameters'): + conn = create_local_connection() + else: + raise Exception("db_parameters are not specified in the configuration") - conn = create_connection() if config.get('memcache_servers'): conn = MemcacheMiddleware(conn, config.get('memcache_servers')) if config.get('upstream_to_www_migration'): conn = MigrationMiddleware(conn) - conn = IAMiddleware(conn) - return conn + return IAMiddleware(conn)