
Commit 417698f

PysparkOperator and not only a decorator (#60041)
* feat: PysparkOperator and not only a decorator
* clean
* clean2
* remove sc arg

Co-authored-by: raphaelauv <raphaelauv@users.noreply.github.com>
1 parent 5496e15 commit 417698f

9 files changed: 197 additions & 75 deletions


providers/apache/spark/docs/decorators/pyspark.rst

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ Example
 -------

 The following example shows how to use the ``@task.pyspark`` decorator. Note
-that the ``spark`` and ``sc`` objects are injected into the function.
+that the ``spark`` object is injected into the function.

 .. exampleinclude:: /../tests/system/apache/spark/example_pyspark.py
    :language: python
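
To make the documented change concrete, here is a minimal sketch of decorator usage after this commit; the connection id and the callable body are illustrative assumptions, not part of the diff:

    from airflow.decorators import task

    @task.pyspark(conn_id="spark_connect")  # hypothetical Spark Connect connection
    def count_even(spark):
        # Only ``spark`` is injected now; the ``sc`` argument was removed.
        return spark.range(10).filter("id % 2 = 0").count()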

providers/apache/spark/docs/operators.rst

Lines changed: 25 additions & 0 deletions
@@ -29,6 +29,8 @@ Prerequisite
   and :doc:`JDBC connection <apache-airflow-providers-jdbc:connections/jdbc>`.
 * :class:`~airflow.providers.apache.spark.operators.spark_sql.SparkSqlOperator`
   gets all the configurations from operator parameters.
+* To use :class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`
+  you can configure a :doc:`Spark Connect connection <connections/spark-connect>`.

 .. _howto/operator:SparkJDBCOperator:

@@ -56,6 +58,29 @@ Reference
 For further information, look at `Apache Spark DataFrameWriter documentation <https://spark.apache.org/docs/2.4.5/api/scala/index.html#org.apache.spark.sql.DataFrameWriter>`_.

+.. _howto/operator:PySparkOperator:
+
+PySparkOperator
+---------------
+
+Launches applications on an Apache Spark Connect server, or directly in standalone mode.
+
+For parameter definition take a look at :class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`.
+
+Using the operator
+""""""""""""""""""
+
+.. exampleinclude:: /../tests/system/apache/spark/example_spark_dag.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_spark_pyspark]
+    :end-before: [END howto_operator_spark_pyspark]
+
+Reference
+"""""""""
+
+For further information, look at the `Spark Connect Python quickstart <https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_connect.html>`_.
+
 .. _howto/operator:SparkSqlOperator:

 SparkSqlOperator
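
For readers wiring this up, one way to supply the Spark Connect connection referenced above is Airflow's ``AIRFLOW_CONN_{CONN_ID}`` environment-variable mechanism with a JSON-serialized connection. A minimal sketch; the connection id, host, and port below are illustrative assumptions, not part of this commit:

    import json
    import os

    # Hypothetical connection resolved by conn_id="spark_connect".
    os.environ["AIRFLOW_CONN_SPARK_CONNECT"] = json.dumps(
        {
            "conn_type": "spark_connect",
            "host": "sc://localhost",  # a Spark Connect server endpoint
            "port": 15002,
        }
    )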

providers/apache/spark/provider.yaml

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ operators:
       - airflow.providers.apache.spark.operators.spark_jdbc
       - airflow.providers.apache.spark.operators.spark_sql
       - airflow.providers.apache.spark.operators.spark_submit
+      - airflow.providers.apache.spark.operators.spark_pyspark

 hooks:
   - integration-name: Apache Spark

providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py

Lines changed: 13 additions & 67 deletions
@@ -19,38 +19,33 @@

 import inspect
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING, Any

-from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.providers.apache.spark.operators.spark_pyspark import SPARK_CONTEXT_KEYS, PySparkOperator
 from airflow.providers.common.compat.sdk import (
-    BaseHook,
     DecoratedOperator,
     TaskDecorator,
     task_decorator_factory,
 )
-from airflow.providers.common.compat.standard.operators import PythonOperator

-if TYPE_CHECKING:
-    from airflow.providers.common.compat.sdk import Context
-SPARK_CONTEXT_KEYS = ["spark", "sc"]

-
-class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
+class _PySparkDecoratedOperator(DecoratedOperator, PySparkOperator):
     custom_operator_name = "@task.pyspark"

-    template_fields: Sequence[str] = ("op_args", "op_kwargs")
-
     def __init__(
         self,
+        *,
         python_callable: Callable,
-        op_args: Sequence | None = None,
-        op_kwargs: dict | None = None,
         conn_id: str | None = None,
         config_kwargs: dict | None = None,
+        op_args: Sequence | None = None,
+        op_kwargs: dict | None = None,
         **kwargs,
-    ):
-        self.conn_id = conn_id
-        self.config_kwargs = config_kwargs or {}
+    ) -> None:
+        kwargs_to_upstream = {
+            "python_callable": python_callable,
+            "op_args": op_args,
+            "op_kwargs": op_kwargs,
+        }

         signature = inspect.signature(python_callable)
         parameters = [
@@ -61,65 +56,16 @@ def __init__(
         # see https://github.com/python/mypy/issues/12472
         python_callable.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]

-        kwargs_to_upstream = {
-            "python_callable": python_callable,
-            "op_args": op_args,
-            "op_kwargs": op_kwargs,
-        }
         super().__init__(
             kwargs_to_upstream=kwargs_to_upstream,
             python_callable=python_callable,
+            config_kwargs=config_kwargs,
+            conn_id=conn_id,
             op_args=op_args,
             op_kwargs=op_kwargs,
             **kwargs,
         )

-    def execute(self, context: Context):
-        from pyspark import SparkConf
-        from pyspark.sql import SparkSession
-
-        conf = SparkConf()
-        conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")
-
-        url = "local[*]"
-        if self.conn_id:
-            # we handle both spark connect and spark standalone
-            conn = BaseHook.get_connection(self.conn_id)
-            if conn.conn_type == SparkConnectHook.conn_type:
-                url = SparkConnectHook(self.conn_id).get_connection_url()
-            elif conn.port:
-                url = f"{conn.host}:{conn.port}"
-            elif conn.host:
-                url = conn.host
-
-            for key, value in conn.extra_dejson.items():
-                conf.set(key, value)
-
-        # you cannot have both remote and master
-        if url.startswith("sc://"):
-            conf.set("spark.remote", url)
-
-        # task can override connection config
-        for key, value in self.config_kwargs.items():
-            conf.set(key, value)
-
-        if not conf.get("spark.remote") and not conf.get("spark.master"):
-            conf.set("spark.master", url)
-
-        spark = SparkSession.builder.config(conf=conf).getOrCreate()
-
-        if not self.op_kwargs:
-            self.op_kwargs = {}
-
-        op_kwargs: dict[str, Any] = dict(self.op_kwargs)
-        op_kwargs["spark"] = spark
-
-        # spark context is not available when using spark connect
-        op_kwargs["sc"] = spark.sparkContext if not conf.get("spark.remote") else None
-
-        self.op_kwargs = op_kwargs
-        return super().execute(context)
-

 def pyspark_task(
     python_callable: Callable | None = None,
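
The signature-rewriting trick the decorator now shares with the operator can be shown in isolation. A minimal, runnable sketch; the callable and its parameters are hypothetical:

    import inspect

    SPARK_CONTEXT_KEYS = ["spark", "sc"]

    def my_job(spark, limit=10):  # hypothetical user callable
        return limit

    signature = inspect.signature(my_job)
    parameters = [
        # Give injected names a default so Airflow's argument binding
        # does not require callers to pass them explicitly.
        param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
        for param in signature.parameters.values()
    ]
    my_job.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]

    print(inspect.signature(my_job))  # (spark=None, limit=10)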

providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ def get_provider_info():
                 "airflow.providers.apache.spark.operators.spark_jdbc",
                 "airflow.providers.apache.spark.operators.spark_sql",
                 "airflow.providers.apache.spark.operators.spark_submit",
+                "airflow.providers.apache.spark.operators.spark_pyspark",
             ],
         }
     ],
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pyspark.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from collections.abc import Callable, Sequence
+
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.providers.common.compat.sdk import BaseHook
+from airflow.providers.common.compat.standard.operators import PythonOperator
+
+SPARK_CONTEXT_KEYS = ["spark", "sc"]
+
+
+class PySparkOperator(PythonOperator):
+    """Submit a pyspark job to an external spark-connect service, or run it directly in standalone mode."""
+
+    template_fields: Sequence[str] = ("conn_id", "config_kwargs", *PythonOperator.template_fields)
+
+    def __init__(
+        self,
+        python_callable: Callable,
+        conn_id: str | None = None,
+        config_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        self.conn_id = conn_id
+        self.config_kwargs = config_kwargs or {}
+
+        signature = inspect.signature(python_callable)
+        parameters = [
+            param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
+            for param in signature.parameters.values()
+        ]
+        # mypy does not understand __signature__ attribute
+        # see https://github.com/python/mypy/issues/12472
+        python_callable.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]
+
+        super().__init__(
+            python_callable=python_callable,
+            **kwargs,
+        )
+
+    def execute_callable(self):
+        from pyspark import SparkConf
+        from pyspark.sql import SparkSession
+
+        conf = SparkConf()
+        conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")
+
+        url = "local[*]"
+        if self.conn_id:
+            # we handle both spark connect and spark standalone
+            conn = BaseHook.get_connection(self.conn_id)
+            if conn.conn_type == SparkConnectHook.conn_type:
+                url = SparkConnectHook(self.conn_id).get_connection_url()
+            elif conn.port:
+                url = f"{conn.host}:{conn.port}"
+            elif conn.host:
+                url = conn.host
+
+            for key, value in conn.extra_dejson.items():
+                conf.set(key, value)
+
+        # you cannot have both remote and master
+        if url.startswith("sc://"):
+            conf.set("spark.remote", url)
+
+        # task can override connection config
+        for key, value in self.config_kwargs.items():
+            conf.set(key, value)
+
+        if not conf.get("spark.remote") and not conf.get("spark.master"):
+            conf.set("spark.master", url)
+
+        spark_session = SparkSession.builder.config(conf=conf).getOrCreate()
+
+        try:
+            self.op_kwargs = {**self.op_kwargs, "spark": spark_session}
+            return super().execute_callable()
+        finally:
+            spark_session.stop()
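
A usage sketch of the new operator; the connection id and the Spark setting below are illustrative assumptions, not part of the commit:

    from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator

    def count_rows(spark):
        # ``spark`` is injected by execute_callable() above.
        print(spark.range(1_000).count())

    count_task = PySparkOperator(
        task_id="count_rows",
        python_callable=count_rows,
        conn_id="spark_connect",  # hypothetical Spark Connect connection
        # Per the precedence above, task-level settings override connection extras.
        config_kwargs={"spark.sql.shuffle.partitions": "8"},
    )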

providers/apache/spark/tests/system/apache/spark/example_spark_dag.py

Lines changed: 11 additions & 0 deletions
@@ -27,6 +27,7 @@

 from airflow.models import DAG
 from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
+from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator
 from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator
 from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

@@ -75,6 +76,16 @@
     )
     # [END howto_operator_spark_sql]

+    # [START howto_operator_spark_pyspark]
+    def my_pyspark_job(spark):
+        df = spark.range(100).filter("id % 2 = 0")
+        print(df.count())
+
+    spark_pyspark_job = PySparkOperator(
+        python_callable=my_pyspark_job, conn_id="spark_connect", task_id="spark_pyspark_job"
+    )
+    # [END howto_operator_spark_pyspark]
+
 from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
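
Note that ``conn_id`` is optional: per the resolution logic in ``execute_callable()``, omitting it falls back to a local master. A sketch continuing the example above (assumes pyspark is importable on the worker):

    local_pyspark_job = PySparkOperator(
        task_id="local_pyspark_job",
        python_callable=my_pyspark_job,
        # No conn_id: the operator defaults to spark.master = "local[*]".
    )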

providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py

Lines changed: 4 additions & 7 deletions
@@ -100,11 +100,10 @@ def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_maker):
         conf_mock.return_value = config

         @task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
-        def f(spark, sc):
+        def f(spark):
             import random

             assert spark is not None
-            assert sc is not None
             return [random.random() for _ in range(100)]

         with dag_maker():
@@ -129,7 +128,7 @@ def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
         e = 2

         @task.pyspark
-        def f():
+        def f(spark):
             return e

         with dag_maker():
@@ -148,9 +147,8 @@ def test_spark_connect(self, spark_mock, conf_mock, dag_maker):
         conf_mock.return_value = config

         @task.pyspark(conn_id="spark-connect")
-        def f(spark, sc):
+        def f(spark):
             assert spark is not None
-            assert sc is None

             return True

@@ -172,9 +170,8 @@ def test_spark_connect_auth(self, spark_mock, conf_mock, dag_maker):
         conf_mock.return_value = config

         @task.pyspark(conn_id="spark-connect-auth")
-        def f(spark, sc):
+        def f(spark):
             assert spark is not None
-            assert sc is None

             return True

providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pyspark.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.models.dag import DAG
+from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator
+from airflow.utils import timezone
+
+DEFAULT_DATE = timezone.datetime(2024, 2, 1, tzinfo=timezone.utc)
+
+
+class TestSparkPySparkOperator:
+    _config = {
+        "conn_id": "spark_special_conn_id",
+    }
+
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        self.dag = DAG("test_dag_id", schedule=None, default_args=args)
+
+    def test_execute(self):
+        def my_spark_fn(spark):
+            pass
+
+        operator = PySparkOperator(
+            task_id="spark_pyspark_job", python_callable=my_spark_fn, dag=self.dag, **self._config
+        )
+
+        assert self._config["conn_id"] == operator.conn_id
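
One property this test does not yet cover is templating: the operator declares ``conn_id`` and ``config_kwargs`` as template fields. A possible follow-up assertion; a sketch, not part of this commit:

    def test_template_fields(self):
        # conn_id and config_kwargs are declared in template_fields,
        # so both can carry Jinja expressions.
        assert "conn_id" in PySparkOperator.template_fields
        assert "config_kwargs" in PySparkOperator.template_fields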
