Skip to content

Commit d944eb0

Browse files
Check permissions for ImportError (#37468)
1 parent 16d2671 commit d944eb0

File tree

4 files changed

+314
-21
lines changed

4 files changed

+314
-21
lines changed

airflow/api_connexion/endpoints/import_error_endpoint.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,59 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING
19+
from typing import TYPE_CHECKING, Sequence
2020

2121
from sqlalchemy import func, select
2222

2323
from airflow.api_connexion import security
24-
from airflow.api_connexion.exceptions import NotFound
24+
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
2525
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
2626
from airflow.api_connexion.schemas.error_schema import (
2727
ImportErrorCollection,
2828
import_error_collection_schema,
2929
import_error_schema,
3030
)
31-
from airflow.auth.managers.models.resource_details import AccessView
31+
from airflow.auth.managers.models.resource_details import AccessView, DagDetails
32+
from airflow.models.dag import DagModel
3233
from airflow.models.errors import ImportError as ImportErrorModel
3334
from airflow.utils.session import NEW_SESSION, provide_session
35+
from airflow.www.extensions.init_auth_manager import get_auth_manager
3436

3537
if TYPE_CHECKING:
3638
from sqlalchemy.orm import Session
3739

3840
from airflow.api_connexion.types import APIResponse
41+
from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
3942

4043

4144
@security.requires_access_view(AccessView.IMPORT_ERRORS)
4245
@provide_session
4346
def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse:
4447
"""Get an import error."""
4548
error = session.get(ImportErrorModel, import_error_id)
46-
4749
if error is None:
4850
raise NotFound(
4951
"Import error not found",
5052
detail=f"The ImportError with import_error_id: `{import_error_id}` was not found",
5153
)
54+
session.expunge(error)
55+
56+
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
57+
if not can_read_all_dags:
58+
readable_dag_ids = security.get_readable_dags()
59+
file_dag_ids = {
60+
dag_id[0]
61+
for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
62+
}
63+
64+
# Can the user read any DAGs in the file?
65+
if not readable_dag_ids.intersection(file_dag_ids):
66+
raise PermissionDenied(detail="You do not have read permission on any of the DAGs in the file")
67+
68+
# Check if user has read access to all the DAGs defined in the file
69+
if not file_dag_ids.issubset(readable_dag_ids):
70+
error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
71+
5272
return import_error_schema.dump(error)
5373

5474

@@ -65,10 +85,41 @@ def get_import_errors(
6585
"""Get all import errors."""
6686
to_replace = {"import_error_id": "id"}
6787
allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
68-
total_entries = session.scalars(func.count(ImportErrorModel.id)).one()
88+
count_query = select(func.count(ImportErrorModel.id))
6989
query = select(ImportErrorModel)
7090
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
91+
92+
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
93+
94+
if not can_read_all_dags:
95+
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
96+
readable_dag_ids = security.get_readable_dags()
97+
dagfiles_subq = (
98+
select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids)).subquery()
99+
)
100+
query = query.where(ImportErrorModel.filename.in_(dagfiles_subq))
101+
count_query = count_query.where(ImportErrorModel.filename.in_(dagfiles_subq))
102+
103+
total_entries = session.scalars(count_query).one()
71104
import_errors = session.scalars(query.offset(offset).limit(limit)).all()
105+
106+
if not can_read_all_dags:
107+
for import_error in import_errors:
108+
# Check if user has read access to all the DAGs defined in the file
109+
file_dag_ids = (
110+
session.query(DagModel.dag_id).filter(DagModel.fileloc == import_error.filename).all()
111+
)
112+
requests: Sequence[IsAuthorizedDagRequest] = [
113+
{
114+
"method": "GET",
115+
"details": DagDetails(id=dag_id[0]),
116+
}
117+
for dag_id in file_dag_ids
118+
]
119+
if not get_auth_manager().batch_is_authorized_dag(requests):
120+
session.expunge(import_error)
121+
import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
122+
72123
return import_error_collection_schema.dump(
73124
ImportErrorCollection(import_errors=import_errors, total_entries=total_entries)
74125
)

airflow/www/views.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@
147147
if TYPE_CHECKING:
148148
from sqlalchemy.orm import Session
149149

150+
from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
150151
from airflow.models.dag import DAG
151152
from airflow.models.operator import Operator
152153

@@ -935,20 +936,44 @@ def index(self):
935936

936937
owner_links_dict = DagOwnerAttributes.get_all(session)
937938

938-
import_errors = select(errors.ImportError).order_by(errors.ImportError.id)
939-
940-
if not get_auth_manager().is_authorized_dag(method="GET"):
941-
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
942-
import_errors = import_errors.join(
943-
DagModel, DagModel.fileloc == errors.ImportError.filename
944-
).where(DagModel.dag_id.in_(filter_dag_ids))
939+
if get_auth_manager().is_authorized_view(access_view=AccessView.IMPORT_ERRORS):
940+
import_errors = select(errors.ImportError).order_by(errors.ImportError.id)
941+
942+
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
943+
if not can_read_all_dags:
944+
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
945+
import_errors = import_errors.where(
946+
errors.ImportError.filename.in_(
947+
select(DagModel.fileloc)
948+
.distinct()
949+
.where(DagModel.dag_id.in_(filter_dag_ids))
950+
.subquery()
951+
)
952+
)
945953

946-
import_errors = session.scalars(import_errors)
947-
for import_error in import_errors:
948-
flash(
949-
f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}",
950-
"dag_import_error",
951-
)
954+
import_errors = session.scalars(import_errors)
955+
for import_error in import_errors:
956+
stacktrace = import_error.stacktrace
957+
if not can_read_all_dags:
958+
# Check if user has read access to all the DAGs defined in the file
959+
file_dag_ids = (
960+
session.query(DagModel.dag_id)
961+
.filter(DagModel.fileloc == import_error.filename)
962+
.all()
963+
)
964+
requests: Sequence[IsAuthorizedDagRequest] = [
965+
{
966+
"method": "GET",
967+
"details": DagDetails(id=dag_id[0]),
968+
}
969+
for dag_id in file_dag_ids
970+
]
971+
if not get_auth_manager().batch_is_authorized_dag(requests):
972+
stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
973+
flash(
974+
f"Broken DAG: [{import_error.filename}]\r{stacktrace}",
975+
"dag_import_error",
976+
)
952977

953978
from airflow.plugins_manager import import_errors as plugin_import_errors
954979

tests/api_connexion/endpoints/test_import_error_endpoint.py

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,19 @@
2121
import pytest
2222

2323
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
24+
from airflow.models.dag import DagModel
2425
from airflow.models.errors import ImportError
2526
from airflow.security import permissions
2627
from airflow.utils import timezone
2728
from airflow.utils.session import provide_session
2829
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
2930
from tests.test_utils.config import conf_vars
30-
from tests.test_utils.db import clear_db_import_errors
31+
from tests.test_utils.db import clear_db_dags, clear_db_import_errors
3132

3233
pytestmark = pytest.mark.db_test
3334

35+
TEST_DAG_IDS = ["test_dag", "test_dag2"]
36+
3437

3538
@pytest.fixture(scope="module")
3639
def configured_app(minimal_app_for_api):
@@ -39,14 +42,34 @@ def configured_app(minimal_app_for_api):
3942
app, # type:ignore
4043
username="test",
4144
role_name="Test",
42-
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore
45+
permissions=[
46+
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
47+
(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR),
48+
], # type: ignore
4349
)
4450
create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
51+
create_user(
52+
app, # type:ignore
53+
username="test_single_dag",
54+
role_name="TestSingleDAG",
55+
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore
56+
)
57+
# For some reason, DAG level permissions are not synced when in the above list of perms,
58+
# so do it manually here:
59+
app.appbuilder.sm.bulk_sync_roles(
60+
[
61+
{
62+
"role": "TestSingleDAG",
63+
"perms": [(permissions.ACTION_CAN_READ, permissions.resource_name_for_dag(TEST_DAG_IDS[0]))],
64+
}
65+
]
66+
)
4567

46-
yield minimal_app_for_api
68+
yield app
4769

4870
delete_user(app, username="test") # type: ignore
4971
delete_user(app, username="test_no_permissions") # type: ignore
72+
delete_user(app, username="test_single_dag") # type: ignore
5073

5174

5275
class TestBaseImportError:
@@ -58,9 +81,11 @@ def setup_attrs(self, configured_app) -> None:
5881
self.client = self.app.test_client() # type:ignore
5982

6083
clear_db_import_errors()
84+
clear_db_dags()
6185

6286
def teardown_method(self) -> None:
6387
clear_db_import_errors()
88+
clear_db_dags()
6489

6590
@staticmethod
6691
def _normalize_import_errors(import_errors):
@@ -121,6 +146,72 @@ def test_should_raise_403_forbidden(self):
121146
)
122147
assert response.status_code == 403
123148

149+
def test_should_raise_403_forbidden_without_dag_read(self, session):
150+
import_error = ImportError(
151+
filename="Lorem_ipsum.py",
152+
stacktrace="Lorem ipsum",
153+
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
154+
)
155+
session.add(import_error)
156+
session.commit()
157+
158+
response = self.client.get(
159+
f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
160+
)
161+
162+
assert response.status_code == 403
163+
164+
def test_should_return_200_with_single_dag_read(self, session):
165+
dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py")
166+
session.add(dag_model)
167+
import_error = ImportError(
168+
filename="Lorem_ipsum.py",
169+
stacktrace="Lorem ipsum",
170+
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
171+
)
172+
session.add(import_error)
173+
session.commit()
174+
175+
response = self.client.get(
176+
f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
177+
)
178+
179+
assert response.status_code == 200
180+
response_data = response.json
181+
response_data["import_error_id"] = 1
182+
assert {
183+
"filename": "Lorem_ipsum.py",
184+
"import_error_id": 1,
185+
"stack_trace": "Lorem ipsum",
186+
"timestamp": "2020-06-10T12:00:00+00:00",
187+
} == response_data
188+
189+
def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
190+
for dag_id in TEST_DAG_IDS:
191+
dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py")
192+
session.add(dag_model)
193+
import_error = ImportError(
194+
filename="Lorem_ipsum.py",
195+
stacktrace="Lorem ipsum",
196+
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
197+
)
198+
session.add(import_error)
199+
session.commit()
200+
201+
response = self.client.get(
202+
f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
203+
)
204+
205+
assert response.status_code == 200
206+
response_data = response.json
207+
response_data["import_error_id"] = 1
208+
assert {
209+
"filename": "Lorem_ipsum.py",
210+
"import_error_id": 1,
211+
"stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
212+
"timestamp": "2020-06-10T12:00:00+00:00",
213+
} == response_data
214+
124215

125216
class TestGetImportErrorsEndpoint(TestBaseImportError):
126217
def test_get_import_errors(self, session):
@@ -231,6 +322,71 @@ def test_should_raises_401_unauthenticated(self, session):
231322

232323
assert_401(response)
233324

325+
def test_get_import_errors_single_dag(self, session):
326+
for dag_id in TEST_DAG_IDS:
327+
fake_filename = f"/tmp/{dag_id}.py"
328+
dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
329+
session.add(dag_model)
330+
importerror = ImportError(
331+
filename=fake_filename,
332+
stacktrace="Lorem ipsum",
333+
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
334+
)
335+
session.add(importerror)
336+
session.commit()
337+
338+
response = self.client.get(
339+
"/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
340+
)
341+
342+
assert response.status_code == 200
343+
response_data = response.json
344+
self._normalize_import_errors(response_data["import_errors"])
345+
assert {
346+
"import_errors": [
347+
{
348+
"filename": "/tmp/test_dag.py",
349+
"import_error_id": 1,
350+
"stack_trace": "Lorem ipsum",
351+
"timestamp": "2020-06-10T12:00:00+00:00",
352+
},
353+
],
354+
"total_entries": 1,
355+
} == response_data
356+
357+
def test_get_import_errors_single_dag_in_dagfile(self, session):
358+
for dag_id in TEST_DAG_IDS:
359+
fake_filename = "/tmp/all_in_one.py"
360+
dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
361+
session.add(dag_model)
362+
363+
importerror = ImportError(
364+
filename="/tmp/all_in_one.py",
365+
stacktrace="Lorem ipsum",
366+
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
367+
)
368+
session.add(importerror)
369+
session.commit()
370+
371+
response = self.client.get(
372+
"/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
373+
)
374+
375+
assert response.status_code == 200
376+
response_data = response.json
377+
self._normalize_import_errors(response_data["import_errors"])
378+
assert {
379+
"import_errors": [
380+
{
381+
"filename": "/tmp/all_in_one.py",
382+
"import_error_id": 1,
383+
"stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
384+
"timestamp": "2020-06-10T12:00:00+00:00",
385+
},
386+
],
387+
"total_entries": 1,
388+
} == response_data
389+
234390

235391
class TestGetImportErrorsEndpointPagination(TestBaseImportError):
236392
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)