diff --git a/config-example.yaml b/config-example.yaml index 3f9c8c2..702eb80 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -9,6 +9,12 @@ postgres: user: "tma_user" password: "tma_password" db: "tma_db" + # pool_size: 30 + # pool_timeout: 30 + # pool_recycle: 3600 + # max_overflow: 20 + # pool_pre_ping: true + # echo_pool: false auth: secret_key: "secret" diff --git a/src/infrastructure/config.py b/src/infrastructure/config.py index 624c93a..f1e0341 100644 --- a/src/infrastructure/config.py +++ b/src/infrastructure/config.py @@ -11,6 +11,12 @@ class PostgresConfig(BaseModel): password: str db: str echo: bool = False + pool_size: int = 30 + pool_timeout: int = 30 + pool_recycle: int = 3600 + max_overflow: int = 20 + pool_pre_ping: bool = True + echo_pool: bool = False @property def url(self) -> str: diff --git a/src/infrastructure/db/factory.py b/src/infrastructure/db/factory.py index 5d63eeb..99e5b71 100644 --- a/src/infrastructure/db/factory.py +++ b/src/infrastructure/db/factory.py @@ -14,20 +14,16 @@ def create_pool(db_config: PostgresConfig) -> async_sessionmaker[AsyncSession]: return create_session_maker(engine) -def create_engine( - db_config: PostgresConfig, - pool_size: int = 30, - pool_timeout: int = 30, - pool_recycle: int = 3600, - max_overflow: int = 20, -) -> AsyncEngine: +def create_engine(db_config: PostgresConfig) -> AsyncEngine: return create_async_engine( url=make_url(db_config.url), echo=db_config.echo, - pool_size=pool_size, - pool_timeout=pool_timeout, - pool_recycle=pool_recycle, - max_overflow=max_overflow, + pool_size=db_config.pool_size, + pool_timeout=db_config.pool_timeout, + pool_recycle=db_config.pool_recycle, + max_overflow=db_config.max_overflow, + pool_pre_ping=db_config.pool_pre_ping, + echo_pool=db_config.echo_pool, ) diff --git a/tests/unit/infrastructure/test_config.py b/tests/unit/infrastructure/test_config.py index 93dc16f..393dabe 100644 --- a/tests/unit/infrastructure/test_config.py +++ b/tests/unit/infrastructure/test_config.py @@ -104,6 +104,40 @@ def test_missing_required_fields(self): with pytest.raises(ValidationError): PostgresConfig() + def test_pool_defaults(self): + config = PostgresConfig( + host="localhost", port=5432, user="user", password="pass", db="db" + ) + + assert config.pool_size == 30 + assert config.pool_timeout == 30 + assert config.pool_recycle == 3600 + assert config.max_overflow == 20 + assert config.pool_pre_ping is True + assert config.echo_pool is False + + def test_pool_custom_values(self): + config = PostgresConfig( + host="localhost", + port=5432, + user="user", + password="pass", + db="db", + pool_size=10, + pool_timeout=15, + pool_recycle=1800, + max_overflow=5, + pool_pre_ping=False, + echo_pool=True, + ) + + assert config.pool_size == 10 + assert config.pool_timeout == 15 + assert config.pool_recycle == 1800 + assert config.max_overflow == 5 + assert config.pool_pre_ping is False + assert config.echo_pool is True + @pytest.mark.parametrize( "port,echo,should_raise", [