diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 9680d7807edeb..fc2ad5e44003c 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -31,6 +31,7 @@ import traceback from collections.abc import Collection, Mapping, MutableMapping, Sequence from concurrent.futures import ProcessPoolExecutor +from functools import cache from typing import TYPE_CHECKING, Any from celery import Celery, Task, states as celery_states @@ -84,24 +85,24 @@ # Make it constant for unit test. CELERY_FETCH_ERR_MSG_HEADER = "Error fetching Celery task state" -celery_configuration = None +@cache +def get_celery_configuration() -> dict[str, Any]: + """Get the Celery configuration dictionary.""" + if conf.has_option("celery", "celery_config_options"): + return conf.getimport("celery", "celery_config_options") -@providers_configuration_loaded -def _get_celery_app() -> Celery: - """Init providers before importing the configuration, so the _SECRET and _CMD options work.""" - global celery_configuration + from airflow.providers.celery.executors.default_celery import DEFAULT_CELERY_CONFIG - if conf.has_option("celery", "celery_config_options"): - celery_configuration = conf.getimport("celery", "celery_config_options") - else: - from airflow.providers.celery.executors.default_celery import DEFAULT_CELERY_CONFIG + return DEFAULT_CELERY_CONFIG - celery_configuration = DEFAULT_CELERY_CONFIG +@providers_configuration_loaded +def _get_celery_app() -> Celery: + """Init providers before importing the configuration, so the _SECRET and _CMD options work.""" celery_app_name = conf.get("celery", "CELERY_APP_NAME") - return Celery(celery_app_name, config_source=celery_configuration) + return Celery(celery_app_name, config_source=get_celery_configuration()) app = _get_celery_app() diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index bf5a20a1dfb50..07c3a2c431c3b 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -79,7 +79,7 @@ def _prepare_app(broker_url=None, execute=None): execute_name = "execute_command" execute = execute or celery_executor_utils.execute_command.__wrapped__ - test_config = dict(celery_executor_utils.celery_configuration) + test_config = dict(celery_executor_utils.get_celery_configuration()) test_config.update({"broker_url": broker_url}) test_app = Celery(broker_url, config_source=test_config) test_execute = test_app.task(execute) @@ -168,7 +168,7 @@ def fake_execute_workload(command): run_id="abc", try_number=0, priority_weight=1, - queue=celery_executor_utils.celery_configuration["task_default_queue"], + queue=celery_executor_utils.get_celery_configuration()["task_default_queue"], executor_config=executor_config, ) keys = [ diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index 0a1a49345ed62..00b40d59a2a10 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -86,7 +86,7 @@ def _prepare_app(broker_url=None, execute=None): execute_name = "execute_command" execute = execute or celery_executor_utils.execute_command.__wrapped__ - test_config = dict(celery_executor_utils.celery_configuration) + test_config = dict(celery_executor_utils.get_celery_configuration()) test_config.update({"broker_url": broker_url}) test_app = Celery(broker_url, config_source=test_config) test_execute = test_app.task(execute)