Commit 959ebd8

[v3-2-test] Fix double-serialization issue by unwrapping serialized kwargs in encode_trigger (#64626) (#64642)
(cherry picked from commit d292e0e)
Co-authored-by: Jason(Zhe-You) Liu <68415893+jason810496@users.noreply.github.com>
1 parent d1d2416 commit 959ebd8

4 files changed: 281 additions & 1 deletion
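
For context on the fix: when a DAG that was already serialized is re-processed (e.g. when asset-watcher triggers are synced to the database), encode_trigger receives a dict whose non-primitive kwargs are already wrapped as {"__type": ..., "__var": ...}. Serializing those wrapped dicts again nests the tags a second time, so the encoded kwargs, and hence the trigger hash, drift on every pass. A minimal standalone sketch of the failure mode; wrap() and the TYPE/VAR tags below are illustrative stand-ins for Airflow's BaseSerialization and Encoding, not the real API:

# Stand-ins for Airflow's Encoding.TYPE / Encoding.VAR tags (illustrative only).
TYPE, VAR = "__type", "__var"

def wrap(value):
    # Mimics how BaseSerialization tags non-primitive values.
    if isinstance(value, tuple):
        return {TYPE: "tuple", VAR: [wrap(v) for v in value]}
    if isinstance(value, dict):
        return {TYPE: "dict", VAR: {k: wrap(v) for k, v in value.items()}}
    return value  # str/int/float/bool pass through unchanged

once = {k: wrap(v) for k, v in {"topics": ("fizz_buzz",)}.items()}
twice = {k: wrap(v) for k, v in once.items()}  # naive re-encode of already-encoded kwargs

# once["topics"]  == {'__type': 'tuple', '__var': ['fizz_buzz']}
# twice["topics"] == {'__type': 'dict', '__var': {'__type': 'tuple', '__var': ['fizz_buzz']}}
assert once != twice  # a hash computed over kwargs therefore changes on every cycle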


airflow-core/src/airflow/models/trigger.py

Lines changed: 4 additions & 1 deletion
@@ -32,7 +32,6 @@
 from airflow._shared.timezones import timezone
 from airflow.assets.manager import AssetManager
 from airflow.configuration import conf
-from airflow.models import Callback
 from airflow.models.asset import AssetWatcherModel
 from airflow.models.base import Base
 from airflow.models.taskinstance import TaskInstance
@@ -210,6 +209,8 @@ def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[
     @provide_session
     def fetch_trigger_ids_with_non_task_associations(cls, session: Session = NEW_SESSION) -> set[str]:
         """Fetch all trigger IDs actively associated with non-task entities like assets and callbacks."""
+        from airflow.models.callback import Callback
+
         query = select(AssetWatcherModel.trigger_id).union_all(
             select(Callback.trigger_id).where(Callback.trigger_id.is_not(None))
         )
@@ -408,6 +409,8 @@ def get_sorted_triggers(
         :param queues: The optional set of trigger queues to filter triggers by.
         :param session: The database session.
         """
+        from airflow.models.callback import Callback
+
         result: list[Row[Any]] = []

         # Add triggers associated to callbacks first, then tasks, then assets
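
Aside from the serialization fix, the trigger.py hunks move the Callback import from module level into the two functions that use it, importing from the concrete airflow.models.callback module. Deferring an import into the function body like this is the usual way to break an import cycle at module load time, which is presumably the motivation here. A minimal sketch of the pattern, with hypothetical module names:

# models_a.py (hypothetical) -- defers importing models_b until call time,
# so models_a <-> models_b can reference each other without a cycle at import.
def trigger_ids_from_callbacks(rows):
    from models_b import Callback  # deferred: resolved on first call, not at load

    return {r.trigger_id for r in rows if isinstance(r, Callback)}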

airflow-core/src/airflow/serialization/encoders.py

Lines changed: 8 additions & 0 deletions
@@ -162,6 +162,14 @@ def _ensure_serialized(d):
     if isinstance(trigger, dict):
         classpath = trigger["classpath"]
         kwargs = trigger["kwargs"]
+        # unwrap any kwargs that are themselves serialized objects, to avoid double-serialization in the trigger's own serialize() method.
+        unwrapped = {}
+        for k, v in kwargs.items():
+            if isinstance(v, dict) and Encoding.TYPE in v:
+                unwrapped[k] = BaseSerialization.deserialize(v)
+            else:
+                unwrapped[k] = v
+        kwargs = unwrapped
     else:
         classpath, kwargs = trigger.serialize()
     return {
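
Read in isolation, the added loop says: any kwarg value that is a dict carrying the Encoding.TYPE tag is already serialized, so deserialize it back to its native value before encode_trigger serializes the full kwargs dict again. A rough standalone equivalent; the real code delegates to BaseSerialization.deserialize, while this sketch handles only the tuple tag:

# Standalone illustration of the unwrap step; handles only the tuple tag.
TYPE, VAR = "__type", "__var"

def unwrap(value):
    """Undo one layer of {'__type': ..., '__var': ...} wrapping."""
    if isinstance(value, dict) and TYPE in value:
        if value[TYPE] == "tuple":
            return tuple(unwrap(v) for v in value[VAR])
        # other tags (dict, set, datetime, ...) are handled by
        # BaseSerialization.deserialize in the real code
    return value

kwargs = {"topics": {TYPE: "tuple", VAR: ["fizz_buzz"]}, "poll_timeout": 1.0}
native = {k: unwrap(v) for k, v in kwargs.items()}
assert native == {"topics": ("fizz_buzz",), "poll_timeout": 1.0}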

airflow-core/tests/unit/dag_processing/test_collection.py

Lines changed: 88 additions & 0 deletions
@@ -182,6 +182,94 @@ def test_add_asset_trigger_references(
         asset_model = session.scalars(select(AssetModel)).one()
         assert len(asset_model.triggers) == expected_num_triggers

+    @pytest.mark.usefixtures("testing_dag_bundle")
+    def test_add_asset_trigger_references_hash_consistency(self, dag_maker, session):
+        """Trigger hash from the DAG-parsed path must equal the hash computed
+        from the DB-stored Trigger row. A mismatch causes the scheduler to
+        recreate trigger rows on every heartbeat.
+        """
+        from airflow.models.trigger import Trigger
+        from airflow.serialization.encoders import encode_trigger
+        from airflow.triggers.base import BaseEventTrigger
+
+        trigger = FileDeleteTrigger(filepath="/tmp/test.txt", poke_interval=5.0)
+        asset = Asset(
+            "test_hash_consistency_asset",
+            watchers=[AssetWatcher(name="file_watcher", trigger=trigger)],
+        )
+
+        with dag_maker(dag_id="test_hash_consistency_dag", schedule=[asset]) as dag:
+            EmptyOperator(task_id="mytask")
+
+        dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)}
+        orm_dags = DagModelOperation(dags, "testing", None).add_dags(session=session)
+        orm_dags[dag.dag_id].is_paused = False
+
+        asset_op = AssetModelOperation.collect(dags)
+        orm_assets = asset_op.sync_assets(session=session)
+        session.flush()
+
+        asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
+        asset_op.activate_assets_if_possible(orm_assets.values(), session=session)
+        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        session.flush()
+
+        # DAG-side hash (same computation as add_asset_trigger_references line 1025)
+        encoded = encode_trigger(trigger)
+        dag_hash = BaseEventTrigger.hash(encoded["classpath"], encoded["kwargs"])
+
+        # DB-side: expire and re-load the Trigger row to force a real DB read
+        asset_model = session.scalars(select(AssetModel)).one()
+        assert len(asset_model.triggers) == 1
+        orm_trigger = asset_model.triggers[0]
+        trigger_id = orm_trigger.id
+        session.expire(orm_trigger)
+        reloaded = session.get(Trigger, trigger_id)
+
+        # DB-side hash (same computation as add_asset_trigger_references line 1033)
+        db_hash = BaseEventTrigger.hash(reloaded.classpath, reloaded.kwargs)
+
+        assert dag_hash == db_hash
+
+    @pytest.mark.usefixtures("testing_dag_bundle")
+    def test_add_asset_trigger_references_idempotent(self, dag_maker, session):
+        """Calling add_asset_trigger_references twice with the same trigger
+        must not create duplicate rows.
+        """
+        from airflow.models.trigger import Trigger
+
+        trigger = FileDeleteTrigger(filepath="/tmp/test.txt", poke_interval=5.0)
+        asset = Asset(
+            "test_idempotent_asset",
+            watchers=[AssetWatcher(name="file_watcher", trigger=trigger)],
+        )
+
+        with dag_maker(dag_id="test_idempotent_dag", schedule=[asset]) as dag:
+            EmptyOperator(task_id="mytask")
+
+        dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)}
+        orm_dags = DagModelOperation(dags, "testing", None).add_dags(session=session)
+        orm_dags[dag.dag_id].is_paused = False
+
+        asset_op = AssetModelOperation.collect(dags)
+        orm_assets = asset_op.sync_assets(session=session)
+        session.flush()
+
+        asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
+        asset_op.activate_assets_if_possible(orm_assets.values(), session=session)
+
+        # First call — creates the trigger
+        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        session.flush()
+        count_after_first = session.scalar(select(func.count(Trigger.id)))
+
+        # Second call — should be a no-op (hashes match, no diff)
+        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        session.flush()
+        count_after_second = session.scalar(select(func.count(Trigger.id)))
+
+        assert count_after_first == count_after_second
+
     @pytest.mark.parametrize(
         ("schedule", "model", "columns", "expected"),
         [
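
Both new tests exercise the same mechanism: add_asset_trigger_references decides whether a Trigger row already exists by comparing a hash of (classpath, kwargs) computed on the DAG side against one computed from the stored row, so idempotence holds only if encoding is stable across passes. A rough sketch of that hash-based dedup under this assumption; trigger_hash and sync_trigger here are simplifications, not the actual BaseEventTrigger.hash or collection code:

import json

def trigger_hash(classpath: str, kwargs: dict) -> int:
    # Canonical JSON keeps logically-equal kwargs dicts hashing identically.
    return hash((classpath, json.dumps(kwargs, sort_keys=True)))

seen = {trigger_hash("pkg.FileDeleteTrigger", {"filepath": "/tmp/test.txt"})}

def sync_trigger(classpath: str, kwargs: dict) -> bool:
    """Create a row only if no row with the same hash exists (idempotent)."""
    h = trigger_hash(classpath, kwargs)
    if h in seen:
        return False  # second sync call with stable kwargs is a no-op
    seen.add(h)
    return True

assert sync_trigger("pkg.FileDeleteTrigger", {"filepath": "/tmp/test.txt"}) is False
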
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from sqlalchemy import delete
+
+from airflow.models.trigger import Trigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger
+from airflow.serialization.encoders import encode_trigger
+from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
+from airflow.triggers.base import BaseEventTrigger
+
+pytest.importorskip("airflow.providers.apache.kafka")
+from airflow.providers.apache.kafka.triggers.await_message import AwaitMessageTrigger
+
+# Trigger fixtures covering primitive-only kwargs (FileDeleteTrigger) and
+# non-primitive kwargs like tuple/dict (AwaitMessageTrigger).
+_TRIGGER_PARAMS = [
+    pytest.param(
+        FileDeleteTrigger(filepath="/tmp/test.txt", poke_interval=5.0),
+        id="primitive_kwargs_only",
+    ),
+    pytest.param(AwaitMessageTrigger(topics=()), id="empty_tuple"),
+    pytest.param(
+        AwaitMessageTrigger(topics=("fizz_buzz",), poll_timeout=1.0, commit_offset=True),
+        id="single_topic_tuple",
+    ),
+    pytest.param(
+        AwaitMessageTrigger(
+            topics=["t1", "t2"],
+            apply_function="my.module.func",
+            apply_function_args=["a", "b"],
+            apply_function_kwargs={"key": "value"},
+            kafka_config_id="my_kafka",
+            poll_interval=2,
+            poll_timeout=3,
+        ),
+        id="all_non_primitive_kwargs",
+    ),
+]
+
+
+class TestEncodeTrigger:
+    """Tests for encode_trigger round-trip correctness.
+
+    When a serialized DAG with asset-watcher triggers is re-serialized
+    (e.g. in ``add_asset_trigger_references``), ``encode_trigger`` receives
+    a dict whose kwargs already contain wrapped values like
+    ``{__type: tuple, __var: [...]}``. The fix ensures these are unwrapped
+    before re-serialization to prevent double-wrapping.
+    """
+
+    def test_encode_from_trigger_object(self):
+        """Non-primitive kwargs are properly serialized from a trigger object."""
+        trigger = AwaitMessageTrigger(topics=())
+        result = encode_trigger(trigger)
+
+        assert (
+            result["classpath"] == "airflow.providers.apache.kafka.triggers.await_message.AwaitMessageTrigger"
+        )
+        # tuple kwarg is wrapped by BaseSerialization
+        assert result["kwargs"]["topics"] == {Encoding.TYPE: DAT.TUPLE, Encoding.VAR: []}
+        # Primitives pass through as-is
+        assert result["kwargs"]["poll_timeout"] == 1
+        assert result["kwargs"]["commit_offset"] is True
+
+    def test_encode_file_delete_trigger(self):
+        """Primitive-only kwargs pass through without wrapping."""
+        trigger = FileDeleteTrigger(filepath="/tmp/test.txt", poke_interval=10.0)
+        result = encode_trigger(trigger)
+
+        assert result["classpath"] == "airflow.providers.standard.triggers.file.FileDeleteTrigger"
+        assert result["kwargs"]["filepath"] == "/tmp/test.txt"
+        assert result["kwargs"]["poke_interval"] == 10.0
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_re_encode_is_idempotent(self, trigger):
+        """Encoding the output of encode_trigger again must not double-wrap kwargs."""
+        first = encode_trigger(trigger)
+        second = encode_trigger(first)
+
+        assert first == second
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_multiple_round_trips_are_stable(self, trigger):
+        """Encoding the same trigger dict many times remains idempotent."""
+        result = encode_trigger(trigger)
+        for _ in range(5):
+            result = encode_trigger(result)
+
+        assert result == encode_trigger(trigger)
+
+
+@pytest.mark.db_test
+class TestTriggerHashConsistency:
+    """Verify ``BaseEventTrigger.hash`` produces the same value for kwargs
+    from the DAG-parsed path and kwargs read back from the database.
+
+    This mirrors the comparison in
+    ``AssetModelOperation.add_asset_trigger_references``
+    (``airflow-core/src/airflow/dag_processing/collection.py``), where:
+
+    * **DAG side** — ``BaseEventTrigger.hash(classpath, encode_trigger(watcher.trigger)["kwargs"])``
+    * **DB side** — ``BaseEventTrigger.hash(trigger.classpath, trigger.kwargs)``
+      where the ``Trigger`` row was persisted with ``encrypt_kwargs`` and
+      read back via ``_decrypt_kwargs``.
+
+    If the hashes diverge, the scheduler sees phantom diffs and keeps
+    recreating trigger rows on every heartbeat.
+    """
+
+    @pytest.fixture(autouse=True)
+    def _clean_triggers(self, session):
+        session.execute(delete(Trigger))
+        session.commit()
+        yield
+        session.execute(delete(Trigger))
+        session.commit()
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_hash_matches_after_db_round_trip(self, trigger, session):
+        """Hash from DAG-parsed kwargs equals hash from a DB-persisted Trigger."""
+        encoded = encode_trigger(trigger)
+        classpath = encoded["classpath"]
+        dag_kwargs = encoded["kwargs"]
+
+        # DAG side hash — what add_asset_trigger_references computes
+        dag_hash = BaseEventTrigger.hash(classpath, dag_kwargs)
+
+        # Persist to DB (same as add_asset_trigger_references lines 1073-1074)
+        trigger_row = Trigger(classpath=classpath, kwargs=dag_kwargs)
+        session.add(trigger_row)
+        session.flush()
+
+        # Force a real DB read — expire the instance and re-select
+        trigger_id = trigger_row.id
+        session.expire(trigger_row)
+        reloaded = session.get(Trigger, trigger_id)
+
+        # DB side hash — what add_asset_trigger_references computes from ORM
+        db_hash = BaseEventTrigger.hash(reloaded.classpath, reloaded.kwargs)
+
+        assert dag_hash == db_hash
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_hash_matches_after_re_encode_and_db_round_trip(self, trigger, session):
+        """Hash stays consistent when encode_trigger output is re-encoded
+        (deserialized-DAG re-serialization path) before DB storage.
+        """
+        re_encoded = encode_trigger(encode_trigger(trigger))
+        classpath = re_encoded["classpath"]
+        dag_kwargs = re_encoded["kwargs"]
+
+        dag_hash = BaseEventTrigger.hash(classpath, dag_kwargs)
+
+        trigger_row = Trigger(classpath=classpath, kwargs=dag_kwargs)
+        session.add(trigger_row)
+        session.flush()
+
+        trigger_id = trigger_row.id
+        session.expire(trigger_row)
+        reloaded = session.get(Trigger, trigger_id)
+
+        db_hash = BaseEventTrigger.hash(reloaded.classpath, reloaded.kwargs)
+
+        assert dag_hash == db_hash
