Skip to content

Commit ac3aafe

Browse files
authored
fix header case handling and support twisted >=24.10 (#34)
1 parent 517d2cf commit ac3aafe

File tree

4 files changed

+73
-23
lines changed

4 files changed

+73
-23
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dev = [
3838
"websocket-client>=1.7.0",
3939
"coverage[toml]>=5.0.0",
4040
"coveralls>=3.3",
41-
"localstack-twisted",
41+
"twisted>=24",
4242
"ruff==0.1.0"
4343
]
4444
docs = [

rolo/serving/twisted.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import typing as t
66
from io import BytesIO
77
from queue import Empty, Queue
8+
from typing import Iterator, Sequence, Tuple, Union
89

910
from twisted.internet import reactor
1011
from twisted.internet.protocol import Protocol
@@ -36,7 +37,6 @@
3637
if t.TYPE_CHECKING:
3738
from _typeshed.wsgi import WSGIEnvironment
3839

39-
4040
LOG = logging.getLogger(__name__)
4141

4242

@@ -148,25 +148,46 @@ def to_websocket_environment(request: Request) -> WebSocketEnvironment:
148148
return environ
149149

150150

151+
class TwistedHeaderAdapter(TwistedHeaders):
152+
"""
153+
Custom twisted server Headers object to handle header casing. This was introduced to abstract away the refactoring
154+
that happened in https://github.com/twisted/twisted/pull/12264.
155+
"""
156+
157+
def __init__(self, *args, **kwargs):
158+
super().__init__(*args, **kwargs)
159+
self._caseMappings = {}
160+
161+
def rememberHeaderCasing(self, name: Union[str, bytes]) -> None:
162+
"""
163+
Receives a raw header in its original casing and stores it to later restore the header casing in
164+
``getAllRawHeaders``.
165+
"""
166+
self._caseMappings[name.lower()] = name
167+
168+
def getAllRawHeaders(self) -> Iterator[Tuple[bytes, Sequence[bytes]]]:
169+
for k, v in self._rawHeaders.items():
170+
yield self._caseMappings.get(k.lower(), k), v
171+
172+
151173
class TwistedRequestAdapter(TwistedRequest):
152174
"""
153175
Custom twisted server Request object to handle header casing.
154176
"""
155177

156-
rawHeaderList: list[tuple[bytes, bytes]]
178+
requestHeaders: TwistedHeaderAdapter
179+
responseHeaders: TwistedHeaderAdapter
157180

158181
def __init__(self, *args, **kwargs):
159182
super().__init__(*args, **kwargs)
160-
# instantiate case mappings, these are used by `getAllRawHeaders` to restore casing
161-
# by default, they are class attributes, so we would override them globally
162-
self.requestHeaders._caseMappings = dict(self.requestHeaders._caseMappings)
163-
self.responseHeaders._caseMappings = dict(self.responseHeaders._caseMappings)
183+
self.requestHeaders = TwistedHeaderAdapter()
184+
self.responseHeaders = TwistedHeaderAdapter()
164185

165186

166187
class HeaderPreservingHTTPChannel(HTTPChannel):
167188
"""
168189
Special HTTPChannel implementation that uses ``Headers._caseMappings`` to retain header casing both for
169-
request headers (server -> WSGI), and response headers (WSGI -> client).
190+
request headers (server -> WSGI), and response headers (WSGI -> client).
170191
"""
171192

172193
requestFactory = TwistedRequestAdapter
@@ -178,20 +199,30 @@ def protocol_factory():
178199
def headerReceived(self, line):
179200
if not super().headerReceived(line):
180201
return False
181-
# remember casing of headers for requests
202+
# remember casing of headers for requests, note that this will only work if TwistedRequestAdapter is used
203+
# as the Request object type, which requires a correct setup of the `Site` object.
182204
header, data = line.split(b":", 1)
183205
request: TwistedRequestAdapter = self.requests[-1]
184-
request.requestHeaders._caseMappings[header.lower()] = header
206+
request.requestHeaders.rememberHeaderCasing(header)
185207
return True
186208

187-
def writeHeaders(self, version, code, reason, headers):
209+
def writeHeaders(
210+
self, version: bytes, code: bytes, reason: bytes, headers: list | TwistedHeaders
211+
):
188212
"""Alternative implementation that writes the raw headers instead of sanitized versions."""
189213
responseLine = version + b" " + code + b" " + reason + b"\r\n"
190214
headerSequence = [responseLine]
191215

192-
for name, value in headers:
193-
line = name + b": " + value + b"\r\n"
194-
headerSequence.append(line)
216+
if isinstance(headers, list):
217+
# older twisted versions sometime before 24.10 passed a list to this method
218+
for name, value in headers:
219+
line = name + b": " + value + b"\r\n"
220+
headerSequence.append(line)
221+
else:
222+
# newer twisted versions instead pass the headers object
223+
for name, values in headers.getAllRawHeaders():
224+
line = name + b": " + b",".join(values) + b"\r\n"
225+
headerSequence.append(line)
195226

196227
headerSequence.append(b"\r\n")
197228
self.transport.writeSequence(headerSequence)
@@ -216,7 +247,7 @@ def startResponse(self, *args, **kwargs):
216247
# headers
217248
for header, _ in self.headers:
218249
header = header.encode("latin-1")
219-
self.request.responseHeaders._caseMappings[header.lower()] = header
250+
self.request.responseHeaders.rememberHeaderCasing(header)
220251
return result
221252

222253

@@ -441,6 +472,7 @@ class TwistedGateway(Site):
441472

442473
def __init__(self, gateway: Gateway):
443474
super().__init__(
444-
GatewayResource(gateway, reactor, reactor.getThreadPool()), TwistedRequestAdapter
475+
resource=GatewayResource(gateway, reactor, reactor.getThreadPool()),
476+
requestFactory=TwistedRequestAdapter,
445477
)
446478
self.protocol = HeaderPreservingHTTPChannel.protocol_factory

rolo/testing/pytest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,11 @@ def serve_twisted_websocket_listener(twisted_reactor, serve_twisted_tcp_server):
257257
"""
258258
from twisted.web.server import Site
259259

260-
from rolo.serving.twisted import HeaderPreservingWSGIResource, WebsocketResourceDecorator
260+
from rolo.serving.twisted import (
261+
HeaderPreservingWSGIResource,
262+
TwistedRequestAdapter,
263+
WebsocketResourceDecorator,
264+
)
261265

262266
def _create(websocket_listener: WebSocketListener):
263267
site = Site(
@@ -266,7 +270,8 @@ def _create(websocket_listener: WebSocketListener):
266270
twisted_reactor, twisted_reactor.getThreadPool(), None
267271
),
268272
websocketListener=websocket_listener,
269-
)
273+
),
274+
requestFactory=TwistedRequestAdapter,
270275
)
271276
site.protocol = HeaderPreservingHTTPChannel.protocol_factory
272277
return serve_twisted_tcp_server(site)

tests/gateway/test_headers.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
import pytest
44
import requests
55

6+
from rolo import Response
67
from rolo.gateway import Gateway, HandlerChain, RequestContext
78

89

910
@pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True)
1011
def test_raw_header_handling(serve_gateway):
11-
def handler(chain: HandlerChain, context: RequestContext, response):
12+
def handler(chain: HandlerChain, context: RequestContext, response: Response):
1213
response.data = json.dumps({"headers": dict(context.request.headers)})
1314
response.mimetype = "application/json"
1415
response.headers["X-fOO_bar"] = "FooBar"
16+
response.headers["content-md5"] = "af5e58f9a7c4682e1b410f2e9392a539"
1517
return response
1618

1719
gateway = Gateway(request_handlers=[handler])
@@ -22,7 +24,18 @@ def handler(chain: HandlerChain, context: RequestContext, response):
2224
srv.url,
2325
headers={"x-mIxEd-CaSe": "myheader", "X-UPPER__CASE": "uppercase"},
2426
)
25-
returned_headers = response.json()["headers"]
26-
assert "X-UPPER__CASE" in returned_headers
27-
assert "x-mIxEd-CaSe" in returned_headers
28-
assert "X-fOO_bar" in dict(response.headers)
27+
request_headers = response.json()["headers"]
28+
29+
# test default headers
30+
assert "User-Agent" in request_headers
31+
assert "Connection" in request_headers
32+
assert "Host" in request_headers
33+
34+
# test custom headers
35+
assert "X-UPPER__CASE" in request_headers
36+
assert "x-mIxEd-CaSe" in request_headers
37+
38+
response_headers = dict(response.headers)
39+
assert "X-fOO_bar" in response_headers
40+
# even though it's a standard header, it should be in the original case
41+
assert "content-md5" in response_headers

0 commit comments

Comments
 (0)