diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index c8838fe3c2109..b3d35f1aa145f 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -46,5 +46,6 @@ except (ImportError, ModuleNotFoundError):
     from airflow.providers.standard.operators.python import get_current_context
 
+from airflow.providers.common.compat.version_compat import BaseOperator
 
-__all__ = ["PythonOperator", "_SERIALIZERS", "ShortCircuitOperator", "get_current_context"]
+__all__ = ["BaseOperator", "PythonOperator", "_SERIALIZERS", "ShortCircuitOperator", "get_current_context"]
diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
index 48d122b669696..02d0f1ac162b0 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
@@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator
+else:
+    from airflow.models import BaseOperator
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "AIRFLOW_V_3_1_PLUS",
+    "BaseOperator",
+]
diff --git a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
index 0faec858d1c75..00b23bde989f1 100644
--- a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
+++ b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
@@ -20,7 +20,7 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
 
 if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -28,10 +28,8 @@
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk import ObjectStoragePath
-    from airflow.sdk.bases.operator import BaseOperator
 else:
     from airflow.io.path import ObjectStoragePath  # type: ignore[no-redef]
-    from airflow.models import BaseOperator  # type: ignore[no-redef]
 
 
 class FileTransferOperator(BaseOperator):
diff --git a/providers/common/io/src/airflow/providers/common/io/version_compat.py b/providers/common/io/src/airflow/providers/common/io/version_compat.py
index 48d122b669696..e7a259afb357c 100644
--- a/providers/common/io/src/airflow/providers/common/io/version_compat.py
+++ b/providers/common/io/src/airflow/providers/common/io/version_compat.py
@@ -33,3 +33,13 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator
+else:
+    from airflow.models import BaseOperator
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "BaseOperator",
+]
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
index c7839b28aec02..4e3b0a87ce027 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -23,9 +23,9 @@
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger
+from airflow.providers.common.sql.version_compat import BaseOperator
 
 if TYPE_CHECKING:
     import jinja2
@@ -192,7 +192,7 @@ def execute_complete(
 
         )
         self.log.info("Offset increased to %d", offset)
-        self.xcom_push(context=context, key="offset", value=offset)
+        context["ti"].xcom_push(key="offset", value=offset)
 
         self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id)
         self.destination_hook.insert_rows(
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index af04349532a51..250a249d5af19 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -25,9 +25,10 @@
 
 from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator, SkipMixin
+from airflow.models import SkipMixin
 from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.version_compat import BaseOperator
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
diff --git a/providers/common/sql/src/airflow/providers/common/sql/version_compat.py b/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
index 48d122b669696..b326387fea2a1 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
@@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator, BaseSensorOperator
+else:
+    from airflow.models import BaseOperator
+    from airflow.sensors.base import BaseSensorOperator  # type: ignore[no-redef]
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "BaseOperator",
+    "BaseSensorOperator",
+]
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
index 92b118d703ae0..fe01d68d2f1b4 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
@@ -55,11 +55,6 @@
 
 @pytest.mark.backend("mysql")
 class TestMySql:
-    def setup_method(self):
-        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-        dag = DAG(TEST_DAG_ID, schedule=None, default_args=args)
-        self.dag = dag
-
     def teardown_method(self):
         from airflow.providers.mysql.hooks.mysql import MySqlHook
 
@@ -77,7 +72,7 @@ def teardown_method(self):
             "mysql-connector-python",
         ],
     )
-    def test_mysql_to_mysql(self, client):
+    def test_mysql_to_mysql(self, client, dag_maker):
         class MySqlContext:
             def __init__(self, client):
                 self.client = client
@@ -92,6 +87,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         with MySqlContext(client):
             sql = "SELECT * FROM connection;"
 
+            with dag_maker(f"TEST_DAG_ID_{client}", start_date=DEFAULT_DATE):
+                op = GenericTransfer(
+                    task_id="test_m2m",
+                    preoperator=[
+                        "DROP TABLE IF EXISTS test_mysql_to_mysql",
+                        "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE connection",
+                    ],
+                    source_conn_id="airflow_db",
+                    destination_conn_id="airflow_db",
+                    destination_table="test_mysql_to_mysql",
+                    sql=sql,
+                )
+
+            dag_maker.run_ti(op.task_id)
+
+    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
+    def test_mysql_to_mysql_replace(self, mock_insert, dag_maker):
+        sql = "SELECT * FROM connection LIMIT 10;"
+        with dag_maker("TEST_DAG_ID", start_date=DEFAULT_DATE):
             op = GenericTransfer(
                 task_id="test_m2m",
                 preoperator=[
@@ -102,27 +116,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
                 destination_conn_id="airflow_db",
                 destination_table="test_mysql_to_mysql",
                 sql=sql,
-                dag=self.dag,
+                insert_args={"replace": True},
             )
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
-    def test_mysql_to_mysql_replace(self, mock_insert):
-        sql = "SELECT * FROM connection LIMIT 10;"
-        op = GenericTransfer(
-            task_id="test_m2m",
-            preoperator=[
-                "DROP TABLE IF EXISTS test_mysql_to_mysql",
-                "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE connection",
-            ],
-            source_conn_id="airflow_db",
-            destination_conn_id="airflow_db",
-            destination_table="test_mysql_to_mysql",
-            sql=sql,
-            dag=self.dag,
-            insert_args={"replace": True},
-        )
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        dag_maker.run_ti(op.task_id)
         assert mock_insert.called
         _, kwargs = mock_insert.call_args
         assert "replace" in kwargs
@@ -140,7 +137,7 @@ def teardown_method(self):
     def test_postgres_to_postgres(self, dag_maker):
         sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;"
         with dag_maker(default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True):
-            op = GenericTransfer(
+            _ = GenericTransfer(
                 task_id="test_p2p",
                 preoperator=[
                     "DROP TABLE IF EXISTS test_postgres_to_postgres",
@@ -151,14 +148,14 @@ def test_postgres_to_postgres(self, dag_maker):
                 destination_table="test_postgres_to_postgres",
                 sql=sql,
             )
-        dag_maker.create_dagrun()
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        dr = dag_maker.create_dagrun()
+        dag_maker.run_ti("test_p2p", dr)
 
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
     def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
         sql = "SELECT id, conn_id, conn_type FROM connection LIMIT 10;"
         with dag_maker(default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True):
-            op = GenericTransfer(
+            _ = GenericTransfer(
                 task_id="test_p2p",
                 preoperator=[
                     "DROP TABLE IF EXISTS test_postgres_to_postgres",
@@ -174,8 +171,8 @@ def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
                     "replace_index": "id",
                 },
             )
-        dag_maker.create_dagrun()
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        dr = dag_maker.create_dagrun()
+        dag_maker.run_ti("test_p2p", dr)
         assert mock_insert.called
         _, kwargs = mock_insert.call_args
         assert "replace" in kwargs
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index c576037630440..b3e02f8d7f8ce 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -1095,6 +1095,30 @@ def setup_method(self):
         self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
         self.branch_3 = None
 
+    def get_ti(self, task_id, dr=None):
+        if dr is None:
+            if AIRFLOW_V_3_0_PLUS:
+                dagrun_kwargs = {
+                    "logical_date": DEFAULT_DATE,
+                    "run_after": DEFAULT_DATE,
+                    "triggered_by": DagRunTriggeredByType.TEST,
+                }
+            else:
+                dagrun_kwargs = {"execution_date": DEFAULT_DATE}
+            dr = self.dag.create_dagrun(
+                run_id=f"manual__{timezone.utcnow().isoformat()}",
+                run_type=DagRunType.MANUAL,
+                start_date=timezone.utcnow(),
+                state=State.RUNNING,
+                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+                **dagrun_kwargs,
+            )
+
+        ti = dr.get_task_instance(task_id)
+        ti.task = self.dag.get_task(ti.task_id)
+
+        return ti
+
     def teardown_method(self):
         with create_session() as session:
             session.query(DagRun).delete()
@@ -1124,7 +1148,7 @@ def test_unsupported_conn_type(self):
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+            op.execute({})
 
     def test_invalid_conn(self):
         """Check if BranchSQLOperator throws an exception for invalid connection"""
@@ -1138,7 +1162,7 @@ def test_invalid_conn(self):
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+            op.execute({})
 
     def test_invalid_follow_task_true(self):
         """Check if BranchSQLOperator throws an exception for invalid connection"""
@@ -1152,7 +1176,7 @@ def test_invalid_follow_task_true(self):
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+            op.execute({})
 
     def test_invalid_follow_task_false(self):
         """Check if BranchSQLOperator throws an exception for invalid connection"""
@@ -1166,12 +1190,13 @@ def test_invalid_follow_task_false(self):
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+            op.execute({})
 
     @pytest.mark.backend("mysql")
     def test_sql_branch_operator_mysql(self, branch_op):
         """Check if BranchSQLOperator works with backend"""
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        branch_op.execute({"ti": mock.MagicMock(task=branch_op)})
 
     @pytest.mark.backend("postgres")
     def test_sql_branch_operator_postgres(self):
@@ -1184,7 +1209,7 @@ def test_sql_branch_operator_postgres(self):
             follow_task_ids_if_false=["branch_2"],
             dag=self.dag,
         )
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        self.get_ti(branch_op.task_id).run()
 
     @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
     def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op):
@@ -1223,8 +1248,9 @@ def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op):
 
             assert exc_info.value.tasks == [("branch_2", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
+
             for ti in tis:
                 if ti.task_id == "make_choice":
                     assert ti.state == State.SUCCESS
@@ -1267,11 +1293,11 @@ def test_branch_true_with_dag_run(self, mock_get_db_hook, true_value, branch_op)
             from airflow.exceptions import DownstreamTasksSkipped
 
             with pytest.raises(DownstreamTasksSkipped) as exc_info:
-                branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                branch_op.execute({})
 
             assert exc_info.value.tasks == [("branch_2", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
             for ti in tis:
                 if ti.task_id == "make_choice":
@@ -1315,11 +1341,12 @@ def test_branch_false_with_dag_run(self, mock_get_db_hook, false_value, branch_o
             from airflow.exceptions import DownstreamTasksSkipped
 
             with pytest.raises(DownstreamTasksSkipped) as exc_info:
-                branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                branch_op.execute({})
 
             assert exc_info.value.tasks == [("branch_1", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
+
             for ti in tis:
                 if ti.task_id == "make_choice":
                     assert ti.state == State.SUCCESS
@@ -1375,7 +1402,7 @@ def test_branch_list_with_dag_run(self, mock_get_db_hook):
                 branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
             assert exc_info.value.tasks == [("branch_3", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
             for ti in tis:
                 if ti.task_id == "make_choice":
@@ -1416,7 +1443,7 @@ def test_invalid_query_result_with_dag_run(self, mock_get_db_hook, branch_op):
             mock_get_records.return_value = ["Invalid Value"]
 
             with pytest.raises(AirflowException):
-                branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                branch_op.execute({})
 
     @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
     def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook, branch_op):
@@ -1447,7 +1474,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook, bra
             for true_value in SUPPORTED_TRUE_VALUES:
                 mock_get_records.return_value = [true_value]
 
-                branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                self.get_ti(branch_op.task_id, dr).run()
 
                 tis = dr.get_task_instances()
                 for ti in tis:
@@ -1493,7 +1520,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook, fa
                 branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
             assert exc_info.value.tasks == [("branch_1", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
             for ti in tis:
                 if ti.task_id == "make_choice":
diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index b61b699774e78..9f52f80a4eb41 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -58,23 +58,20 @@
 
 @pytest.mark.db_test
 class TestSnowflakeOperator:
-    def setup_method(self):
-        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-        dag = DAG(TEST_DAG_ID, schedule=None, default_args=args)
-        self.dag = dag
-
     @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_snowflake_operator(self, mock_get_db_hook):
+    def test_snowflake_operator(self, mock_get_db_hook, dag_maker):
         sql = """
         CREATE TABLE IF NOT EXISTS test_airflow (
             dummy VARCHAR(50)
         );
         """
-        operator = SQLExecuteQueryOperator(
-            task_id="basic_snowflake", sql=sql, dag=self.dag, do_xcom_push=False, conn_id="snowflake_default"
-        )
+
+        with dag_maker(TEST_DAG_ID):
+            operator = SQLExecuteQueryOperator(
+                task_id="basic_snowflake", sql=sql, do_xcom_push=False, conn_id="snowflake_default"
+            )
         # do_xcom_push=False because otherwise the XCom test will fail due to the mocking (it actually works)
 
-        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        dag_maker.run_ti(operator.task_id)
 
 
 class TestSnowflakeOperatorForParams: