55import typing as t
66from io import BytesIO
77from queue import Empty , Queue
8+ from typing import Iterator , Sequence , Tuple , Union
89
910from twisted .internet import reactor
1011from twisted .internet .protocol import Protocol
3637if t .TYPE_CHECKING :
3738 from _typeshed .wsgi import WSGIEnvironment
3839
39-
4040LOG = logging .getLogger (__name__ )
4141
4242
@@ -148,25 +148,46 @@ def to_websocket_environment(request: Request) -> WebSocketEnvironment:
148148 return environ
149149
150150
151+ class TwistedHeaderAdapter (TwistedHeaders ):
152+ """
153+ Custom twisted server Headers object to handle header casing. This was introduced to abstract away the refactoring
154+ that happened in https://github.com/twisted/twisted/pull/12264.
155+ """
156+
157+ def __init__ (self , * args , ** kwargs ):
158+ super ().__init__ (* args , ** kwargs )
159+ self ._caseMappings = {}
160+
161+ def rememberHeaderCasing (self , name : Union [str , bytes ]) -> None :
162+ """
163+ Receives a raw header in its original casing and stores it to later restore the header casing in
164+ ``getAllRawHeaders``.
165+ """
166+ self ._caseMappings [name .lower ()] = name
167+
168+ def getAllRawHeaders (self ) -> Iterator [Tuple [bytes , Sequence [bytes ]]]:
169+ for k , v in self ._rawHeaders .items ():
170+ yield self ._caseMappings .get (k .lower (), k ), v
171+
172+
151173class TwistedRequestAdapter (TwistedRequest ):
152174 """
153175 Custom twisted server Request object to handle header casing.
154176 """
155177
156- rawHeaderList : list [tuple [bytes , bytes ]]
178+ requestHeaders : TwistedHeaderAdapter
179+ responseHeaders : TwistedHeaderAdapter
157180
158181 def __init__ (self , * args , ** kwargs ):
159182 super ().__init__ (* args , ** kwargs )
160- # instantiate case mappings, these are used by `getAllRawHeaders` to restore casing
161- # by default, they are class attributes, so we would override them globally
162- self .requestHeaders ._caseMappings = dict (self .requestHeaders ._caseMappings )
163- self .responseHeaders ._caseMappings = dict (self .responseHeaders ._caseMappings )
183+ self .requestHeaders = TwistedHeaderAdapter ()
184+ self .responseHeaders = TwistedHeaderAdapter ()
164185
165186
166187class HeaderPreservingHTTPChannel (HTTPChannel ):
167188 """
168189 Special HTTPChannel implementation that uses ``Headers._caseMappings`` to retain header casing both for
169- request headers (server -> WSGI), and response headers (WSGI -> client).
190+ request headers (server -> WSGI), and response headers (WSGI -> client).
170191 """
171192
172193 requestFactory = TwistedRequestAdapter
@@ -178,20 +199,30 @@ def protocol_factory():
178199 def headerReceived (self , line ):
179200 if not super ().headerReceived (line ):
180201 return False
181- # remember casing of headers for requests
202+ # remember casing of headers for requests, note that this will only work if TwistedRequestAdapter is used
203+ # as the Request object type, which requires a correct setup of the `Site` object.
182204 header , data = line .split (b":" , 1 )
183205 request : TwistedRequestAdapter = self .requests [- 1 ]
184- request .requestHeaders ._caseMappings [ header . lower ()] = header
206+ request .requestHeaders .rememberHeaderCasing ( header )
185207 return True
186208
187- def writeHeaders (self , version , code , reason , headers ):
209+ def writeHeaders (
210+ self , version : bytes , code : bytes , reason : bytes , headers : list | TwistedHeaders
211+ ):
188212 """Alternative implementation that writes the raw headers instead of sanitized versions."""
189213 responseLine = version + b" " + code + b" " + reason + b"\r \n "
190214 headerSequence = [responseLine ]
191215
192- for name , value in headers :
193- line = name + b": " + value + b"\r \n "
194- headerSequence .append (line )
216+ if isinstance (headers , list ):
217+ # older twisted versions sometime before 24.10 passed a list to this method
218+ for name , value in headers :
219+ line = name + b": " + value + b"\r \n "
220+ headerSequence .append (line )
221+ else :
222+ # newer twisted versions instead pass the headers object
223+ for name , values in headers .getAllRawHeaders ():
224+ line = name + b": " + b"," .join (values ) + b"\r \n "
225+ headerSequence .append (line )
195226
196227 headerSequence .append (b"\r \n " )
197228 self .transport .writeSequence (headerSequence )
@@ -216,7 +247,7 @@ def startResponse(self, *args, **kwargs):
216247 # headers
217248 for header , _ in self .headers :
218249 header = header .encode ("latin-1" )
219- self .request .responseHeaders ._caseMappings [ header . lower ()] = header
250+ self .request .responseHeaders .rememberHeaderCasing ( header )
220251 return result
221252
222253
@@ -441,6 +472,7 @@ class TwistedGateway(Site):
441472
442473 def __init__ (self , gateway : Gateway ):
443474 super ().__init__ (
444- GatewayResource (gateway , reactor , reactor .getThreadPool ()), TwistedRequestAdapter
475+ resource = GatewayResource (gateway , reactor , reactor .getThreadPool ()),
476+ requestFactory = TwistedRequestAdapter ,
445477 )
446478 self .protocol = HeaderPreservingHTTPChannel .protocol_factory
0 commit comments