Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions authx/_internal/_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,24 @@
class _ErrorHandler:
"""Base Handler for FastAPI handling AuthX exceptions"""

MSG_DEFAULT = "AuthX Error"
MSG_TOKEN_ERROR = "Token Error"
MSG_MISSING_TOKEN_ERROR = "Missing JWT in request"
MSG_MISSING_CSRF_ERROR = "Missing CSRF double submit token in request"
MSG_TOKEN_TYPE_ERROR = "Bad token type"
MSG_REVOKED_TOKEN_ERROR = "Invalid token"
MSG_TOKEN_REQUIRED_ERROR = "Token required"
MSG_FRESH_TOKEN_REQUIRED_ERROR = "Fresh token required"
MSG_ACCESS_TOKEN_REQUIRED_ERROR = "Access token required"
MSG_REFRESH_TOKEN_REQUIRED_ERROR = "Refresh token required"
MSG_CSRF_ERROR = "CSRF double submit does not match"
MSG_DECODE_JWT_ERROR = "Invalid Token"
MSG_TokenError = "Token Error"
MSG_MissingTokenError = "Missing JWT in request"
MSG_MissingCSRFTokenError = "Missing CSRF double submit token in request"
MSG_TokenTypeError = "Bad token type"
MSG_RevokedTokenError = "Invalid token"
MSG_TokenRequiredError = "Token required"
MSG_FreshTokenRequiredError = "Fresh token required"
MSG_AccessTokenRequiredError = "Access token required"
MSG_RefreshTokenRequiredError = "Refresh token required"
MSG_CSRFError = "CSRF double submit does not match"
MSG_JWTDecodeError = "Invalid Token"

async def _error_handler(
self,
request: Request,
exc: exceptions.AuthXException,
status_code: int,
message: str,
message: Optional[str],
) -> JSONResponse:
"""Generate the async function to be decorated by `FastAPI.exception_handler` decorator

Expand All @@ -40,10 +39,17 @@ async def _error_handler(
Returns:
JSONResponse: The JSON response.
"""
msg = exc.args[0] if message is None else message
if message is None:
default_message = str(exc)
attr_name = f"MSG_{exc.__class__.__name__}"
message = getattr(self, attr_name, default_message)

return JSONResponse(
status_code=status_code,
content={"message": msg, "error_type": exc.__class__.__name__},
content={
"message": message,
"error_type": exc.__class__.__name__,
},
)

def _set_app_exception_handler(
Expand All @@ -53,11 +59,15 @@ def _set_app_exception_handler(
status_code: int,
message: Optional[str],
) -> None:
app.exception_handler(exception)(
lambda request, exc=exception: self._error_handler(
request, exc, status_code, message or self.MSG_DEFAULT
)
)
async def exception_handler_wrapper(
request: Request, exc: exceptions.AuthXException
) -> JSONResponse:
return await self._error_handler(request, exc, status_code, message)

# Add the exception handler to the FastAPI application
# The exception handler will be called when the specified exception is raised, and the status code and message will be returned
# The exception handler will return a JSONResponse with the specified status code and message
app.exception_handler(exception)(exception_handler_wrapper)

def handle_errors(self, app: FastAPI) -> None:
"""Add the `FastAPI.exception_handlers` relative to AuthX exceptions
Expand All @@ -72,53 +82,53 @@ def handle_errors(self, app: FastAPI) -> None:
app,
exception=exceptions.MissingTokenError,
status_code=401,
message=self.MSG_MISSING_TOKEN_ERROR,
message=self.MSG_TokenError,
)
self._set_app_exception_handler(
app,
exception=exceptions.MissingCSRFTokenError,
status_code=401,
message=self.MSG_MISSING_CSRF_ERROR,
message=self.MSG_MissingCSRFTokenError,
)
self._set_app_exception_handler(
app,
exception=exceptions.TokenTypeError,
status_code=401,
message=self.MSG_TOKEN_TYPE_ERROR,
message=self.MSG_TokenTypeError,
)
self._set_app_exception_handler(
app,
exception=exceptions.RevokedTokenError,
status_code=401,
message=self.MSG_REVOKED_TOKEN_ERROR,
message=self.MSG_RevokedTokenError,
)
self._set_app_exception_handler(
app,
exception=exceptions.TokenRequiredError,
status_code=401,
message=self.MSG_TOKEN_REQUIRED_ERROR,
message=self.MSG_TokenRequiredError,
)
self._set_app_exception_handler(
app,
exception=exceptions.FreshTokenRequiredError,
status_code=401,
message=self.MSG_FRESH_TOKEN_REQUIRED_ERROR,
message=self.MSG_FreshTokenRequiredError,
)
self._set_app_exception_handler(
app,
exception=exceptions.AccessTokenRequiredError,
status_code=401,
message=self.MSG_ACCESS_TOKEN_REQUIRED_ERROR,
message=self.MSG_AccessTokenRequiredError,
)
self._set_app_exception_handler(
app,
exception=exceptions.RefreshTokenRequiredError,
status_code=401,
message=self.MSG_REFRESH_TOKEN_REQUIRED_ERROR,
message=self.MSG_RefreshTokenRequiredError,
)
self._set_app_exception_handler(
app,
exception=exceptions.CSRFError,
status_code=401,
message=self.MSG_CSRF_ERROR,
message=self.MSG_CSRFError,
)
7 changes: 7 additions & 0 deletions docs/api/extra/cache.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# HTTPCacheBackend

!!! warning
You need to install dependencies to use The HTTP Cache.

```console
$ pip install authx_extra[redis]
```

::: authx_extra.cache.HTTPCacheBackend
7 changes: 7 additions & 0 deletions docs/api/extra/metrics.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# MetricsMiddleware

!!! warning
You need to install dependencies to use The Prometheus Metrics Middleware.

```console
$ pip install authx_extra[prometheus]
```

::: authx_extra.metrics.MetricsMiddleware
7 changes: 7 additions & 0 deletions docs/api/extra/profiler.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# ProfilerMiddleware

!!! warning
You need to install dependencies to use The Profiler.

```console
$ pip install authx_extra[profiler]
```

::: authx_extra.profiler.ProfilerMiddleware
2 changes: 1 addition & 1 deletion docs/extra/Cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Make sure to have the necessary dependencies installed:
<div class="termy">

```console
$ pip install authx_extra
$ pip install authx_extra[redis]

---> 100%
```
Expand Down
2 changes: 1 addition & 1 deletion docs/extra/Metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Make sure to have the necessary dependencies installed (e.g., `prometheus_client
<div class="termy">

```console
$ pip install authx_extra
$ pip install authx_extra[prometheus]

---> 100%
```
Expand Down
2 changes: 1 addition & 1 deletion docs/extra/profiler.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Make sure to have the necessary dependencies installed:
<div class="termy">

```console
$ pip install authx_extra
$ pip install authx_extra[profiler]

---> 100%
```
Expand Down
88 changes: 88 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from typing import Type

import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

import authx.exceptions as exc
from authx import AuthX
from authx._internal import _ErrorHandler


@pytest.fixture(scope="function")
Expand All @@ -18,6 +21,16 @@ def app():
return FastAPI()


@pytest.fixture
def client(app):
return TestClient(app)


@pytest.fixture(scope="class")
def error_handler():
return _ErrorHandler()


@pytest.mark.asyncio
async def test_error_handler(authx: AuthX):
error_handler = authx._error_handler(
Expand Down Expand Up @@ -73,3 +86,78 @@ def test_invalid_token_init():
errors = ["Invalid signature", "Expired token"]
exception = exc.InvalidToken(errors)
assert exception.errors == errors


async def create_exception_route(app: FastAPI, exception: Type[Exception]):
@app.get("/")
async def route():
raise exception


@pytest.mark.asyncio
@pytest.mark.parametrize(
"exception,status_code,message,expected_message",
[
(exc.TokenError, 401, None, _ErrorHandler.MSG_TokenError),
(
exc.MissingTokenError,
401,
None,
_ErrorHandler.MSG_MissingTokenError,
),
(
exc.MissingCSRFTokenError,
401,
None,
_ErrorHandler.MSG_MissingCSRFTokenError,
),
(exc.TokenTypeError, 401, None, _ErrorHandler.MSG_TokenTypeError),
(
exc.RevokedTokenError,
401,
None,
_ErrorHandler.MSG_RevokedTokenError,
),
(
exc.TokenRequiredError,
401,
None,
_ErrorHandler.MSG_TokenRequiredError,
),
(
exc.FreshTokenRequiredError,
401,
None,
_ErrorHandler.MSG_FreshTokenRequiredError,
),
(
exc.AccessTokenRequiredError,
401,
None,
_ErrorHandler.MSG_AccessTokenRequiredError,
),
(
exc.RefreshTokenRequiredError,
401,
None,
_ErrorHandler.MSG_RefreshTokenRequiredError,
),
(exc.CSRFError, 401, None, _ErrorHandler.MSG_CSRFError),
(exc.JWTDecodeError, 422, None, _ErrorHandler.MSG_JWTDecodeError),
(exc.AuthXException, 500, "Custom message", "Custom message"),
],
)
async def test_set_app_exception_handler(
app, client, error_handler, exception, status_code, message, expected_message
):
error_handler._set_app_exception_handler(app, exception, status_code, message)

await create_exception_route(app, exception)

response = client.get("/")

assert response.status_code == status_code
assert response.json() == {
"message": expected_message,
"error_type": exception.__name__,
}