Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
DagBuilder.adjust_general_task_params(task_params)

expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
expand_kwargs_kwargs: List[Dict[str, Union[Dict[str, Any], Any]]] = {}
# expand available only in airflow >= 2.3.0
if (
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
Expand All @@ -442,11 +443,18 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
task_params.update(partial_kwargs)

task: Union[BaseOperator, MappedOperator] = (
operator_obj(**task_params)
if not expand_kwargs
else operator_obj.partial(**task_params).expand(**expand_kwargs)
)
# expand_kwargs available only in airflow >= 2.4.0
if (
utils.check_dict_key(task_params, "expand_kwargs")
) and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
expand_kwargs_kwargs = task_params["expand_kwargs"]
del task_params["expand_kwargs"]

task: Union[BaseOperator, MappedOperator] = operator_obj(**task_params)
if expand_kwargs:
task = operator_obj.partial(**task_params).expand(**expand_kwargs)
elif expand_kwargs_kwargs:
task = operator_obj.partial(**task_params).expand_kwargs(expand_kwargs_kwargs)
except Exception as err:
raise DagFactoryException(f"Failed to create {operator_obj} task") from err
return task
Expand Down Expand Up @@ -882,6 +890,12 @@ def build(self) -> Dict[str, Union[str, DAG]]:
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
else:
task_conf = self.replace_expand_values(task_conf, tasks_dict)
if task_conf.get("expand_kwargs"):
if version.parse(AIRFLOW_VERSION) < version.parse("2.4.0"):
raise DagFactoryConfigException("Dynamic task mapping with multiple Parameter available only in Airflow >= 2.4.0")
# TODO
# else:
# task_conf = self.replace_expand_values(task_conf, tasks_dict)

task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task(operator=operator, task_params=params)
tasks_dict[task.task_id]: BaseOperator = task
Expand Down
9 changes: 5 additions & 4 deletions dagfactory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ def get_expand_partial_kwargs(

expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
partial_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
for expand_key, expand_value in task_params["expand"].items():
expand_kwargs[expand_key] = expand_value
# remove dag-factory specific parameter
del task_params["expand"]
if check_dict_key(task_params, "expand"):
for expand_key, expand_value in task_params["expand"].items():
expand_kwargs[expand_key] = expand_value
# remove dag-factory specific parameter
del task_params["expand"]
if check_dict_key(task_params, "partial"):
for partial_key, partial_value in task_params["partial"].items():
partial_kwargs[partial_key] = partial_value
Expand Down
22 changes: 22 additions & 0 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,28 @@ def test_dynamic_task_mapping():
assert isinstance(actual, MappedOperator)


def test_multi_parameter_dynamic_task_mapping():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
if version.parse(AIRFLOW_VERSION) < version.parse("2.4.0"):
error_message = "Dynamic task mapping with multiple Parameter available only in Airflow >= 2.4.0"
with pytest.raises(Exception, match=error_message):
td.build()
else:
operator = "airflow.operators.python_operator.PythonOperator"
task_params = {
"task_id": "process",
"python_callable_name": "expand_task",
"python_callable_file": os.path.realpath(__file__),
"partial": {"op_kwargs": {"test_id": "test"}},
"expand_kwargs": [
{"op_args": {"request_output": "request.output"}},
{"op_args": {"request_output": "temp"}}
],
}
actual = td.make_task(operator, task_params)
assert isinstance(actual, MappedOperator)


def test_replace_expand_string_with_xcom():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
Expand Down
Loading