Skip to content

Commit f033fc5

Browse files
jroachgolf84potiuk
authored andcommitted
[v2-11-test] Ensuring XCom return value can be mapped for dynamically-mapped @task_group's (#51668)
* xcom_arg expansion fix for v2-11 * Removing previously-added file, not needed for v2-11
1 parent 8d8d422 commit f033fc5

2 files changed

Lines changed: 61 additions & 0 deletions

File tree

airflow/decorators/task_group.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ListOfDictsExpandInput,
3939
MappedArgument,
4040
)
41+
from airflow.models.mappedoperator import ensure_xcomarg_return_value
4142
from airflow.models.taskmixin import DAGNode
4243
from airflow.models.xcom_arg import XComArg
4344
from airflow.typing_compat import ParamSpec
@@ -134,6 +135,11 @@ def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
134135
self._validate_arg_names("expand", kwargs)
135136
prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="mapping already partial")
136137
expand_input = DictOfListsExpandInput(kwargs)
138+
139+
# Similar to @task, @task_group should not be "mappable" over an XCom with a custom key. This will
140+
# raise an exception, rather than having an ambiguous exception similar to the one found in #51109.
141+
ensure_xcomarg_return_value(expand_input.value)
142+
137143
return self._create_task_group(
138144
functools.partial(MappedTaskGroup, expand_input=expand_input),
139145
**self.partial_kwargs,
@@ -163,6 +169,11 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
163169
map_kwargs = (k for k in self.function_signature.parameters if k not in self.partial_kwargs)
164170

165171
expand_input = ListOfDictsExpandInput(kwargs)
172+
173+
# Similar to @task, @task_group should not be "mappable" over an XCom with a custom key. This will
174+
# raise an exception, rather than having an ambiguous exception similar to the one found in #51109.
175+
ensure_xcomarg_return_value(expand_input.value)
176+
166177
return self._create_task_group(
167178
functools.partial(MappedTaskGroup, expand_input=expand_input),
168179
**self.partial_kwargs,

tests/decorators/test_task_group.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,31 @@ def tg(a, b):
178178
assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")}
179179

180180

181+
def test_expand_invalid_xcomarg_return_value():
182+
saved = {}
183+
184+
@dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
185+
def pipeline():
186+
@task
187+
def t():
188+
return {"values": ["value_1", "value_2"]}
189+
190+
@task_group()
191+
def tg(a, b):
192+
saved["a"] = a
193+
saved["b"] = b
194+
195+
tg.partial(a=1).expand(b=t()["values"])
196+
197+
with pytest.raises(ValueError) as ctx:
198+
pipeline()
199+
200+
assert (
201+
str(ctx.value)
202+
== "cannot map over XCom with custom key 'values' from <Task(_PythonDecoratedOperator): t>"
203+
)
204+
205+
181206
def test_expand_kwargs_no_wildcard():
182207
@dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
183208
def pipeline():
@@ -262,6 +287,31 @@ def t2():
262287
assert "missing upstream values: ['b']" not in caplog.text
263288

264289

290+
def test_expand_kwargs_invalid_xcomarg_return_value():
291+
saved = {}
292+
293+
@dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
294+
def pipeline():
295+
@task
296+
def t():
297+
return {"values": [{"b": 2}, {"b": 3}]}
298+
299+
@task_group()
300+
def tg(a, b):
301+
saved["a"] = a
302+
saved["b"] = b
303+
304+
tg.partial(a=1).expand_kwargs(t()["values"])
305+
306+
with pytest.raises(ValueError) as ctx:
307+
pipeline()
308+
309+
assert (
310+
str(ctx.value)
311+
== "cannot map over XCom with custom key 'values' from <Task(_PythonDecoratedOperator): t>"
312+
)
313+
314+
265315
def test_override_dag_default_args():
266316
@dag(
267317
dag_id="test_dag",

0 commit comments

Comments
 (0)