11import abc
22import logging
3+ import re
34import types as python_types
45import typing
6+ from collections import OrderedDict
57
68from opentelemetry .trace .status import Status
79from opentelemetry .util import types
10+ from opentelemetry .util .tracestate import (
11+ _DELIMITER_PATTERN ,
12+ _MEMBER_PATTERN ,
13+ _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS ,
14+ _is_valid_pair ,
15+ )
816
917_logger = logging .getLogger (__name__ )
1018
@@ -135,7 +143,7 @@ def sampled(self) -> bool:
135143DEFAULT_TRACE_OPTIONS = TraceFlags .get_default ()
136144
137145
138- class TraceState (typing .Dict [str , str ]):
146+ class TraceState (typing .Mapping [str , str ]):
139147 """A list of key-value pairs representing vendor-specific trace info.
140148
141149 Keys and values are strings of up to 256 printable US-ASCII characters.
@@ -146,10 +154,186 @@ class TraceState(typing.Dict[str, str]):
146154 https://www.w3.org/TR/trace-context/#tracestate-field
147155 """
148156
157+ def __init__ (
158+ self ,
159+ entries : typing .Optional [
160+ typing .Sequence [typing .Tuple [str , str ]]
161+ ] = None ,
162+ ) -> None :
163+ self ._dict = OrderedDict () # type: OrderedDict[str, str]
164+ if entries is None :
165+ return
166+ if len (entries ) > _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS :
167+ _logger .warning (
168+ "There can't be more than %s key/value pairs." ,
169+ _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS ,
170+ )
171+ return
172+
173+ for key , value in entries :
174+ if _is_valid_pair (key , value ):
175+ if key in self ._dict :
176+ _logger .warning ("Duplicate key: %s found." , key )
177+ continue
178+ self ._dict [key ] = value
179+ else :
180+ _logger .warning (
181+ "Invalid key/value pair (%s, %s) found." , key , value
182+ )
183+
184+ def __getitem__ (self , key : str ) -> typing .Optional [str ]: # type: ignore
185+ return self ._dict .get (key )
186+
187+ def __iter__ (self ) -> typing .Iterator [str ]:
188+ return iter (self ._dict )
189+
190+ def __len__ (self ) -> int :
191+ return len (self ._dict )
192+
193+ def __repr__ (self ) -> str :
194+ pairs = [
195+ "{key=%s, value=%s}" % (key , value )
196+ for key , value in self ._dict .items ()
197+ ]
198+ return str (pairs )
199+
200+ def add (self , key : str , value : str ) -> "TraceState" :
201+ """Adds a key-value pair to tracestate. The provided pair should
202+ adhere to w3c tracestate identifiers format.
203+
204+ Args:
205+ key: A valid tracestate key to add
206+ value: A valid tracestate value to add
207+
208+ Returns:
209+ A new TraceState with the modifications applied.
210+
211+ If the provided key-value pair is invalid or results in tracestate
212+ that violates tracecontext specification, they are discarded and
213+ same tracestate will be returned.
214+ """
215+ if not _is_valid_pair (key , value ):
216+ _logger .warning (
217+ "Invalid key/value pair (%s, %s) found." , key , value
218+ )
219+ return self
220+ # There can be a maximum of 32 pairs
221+ if len (self ) >= _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS :
222+ _logger .warning ("There can't be more 32 key/value pairs." )
223+ return self
224+ # Duplicate entries are not allowed
225+ if key in self ._dict :
226+ _logger .warning ("The provided key %s already exists." , key )
227+ return self
228+ new_state = [(key , value )] + list (self ._dict .items ())
229+ return TraceState (new_state )
230+
231+ def update (self , key : str , value : str ) -> "TraceState" :
232+ """Updates a key-value pair in tracestate. The provided pair should
233+ adhere to w3c tracestate identifiers format.
234+
235+ Args:
236+ key: A valid tracestate key to update
237+ value: A valid tracestate value to update for key
238+
239+ Returns:
240+ A new TraceState with the modifications applied.
241+
242+ If the provided key-value pair is invalid or results in tracestate
243+ that violates tracecontext specification, they are discarded and
244+ same tracestate will be returned.
245+ """
246+ if not _is_valid_pair (key , value ):
247+ _logger .warning (
248+ "Invalid key/value pair (%s, %s) found." , key , value
249+ )
250+ return self
251+ prev_state = self ._dict .copy ()
252+ prev_state [key ] = value
253+ prev_state .move_to_end (key , last = False )
254+ new_state = list (prev_state .items ())
255+ return TraceState (new_state )
256+
257+ def delete (self , key : str ) -> "TraceState" :
258+ """Deletes a key-value from tracestate.
259+
260+ Args:
261+ key: A valid tracestate key to remove key-value pair from tracestate
262+
263+ Returns:
264+ A new TraceState with the modifications applied.
265+
266+ If the provided key-value pair is invalid or results in tracestate
267+ that violates tracecontext specification, they are discarded and
268+ same tracestate will be returned.
269+ """
270+ if key not in self ._dict :
271+ _logger .warning ("The provided key %s doesn't exist." , key )
272+ return self
273+ prev_state = self ._dict .copy ()
274+ prev_state .pop (key )
275+ new_state = list (prev_state .items ())
276+ return TraceState (new_state )
277+
278+ def to_header (self ) -> str :
279+ """Creates a w3c tracestate header from a TraceState.
280+
281+ Returns:
282+ A string that adheres to the w3c tracestate
283+ header format.
284+ """
285+ return "," .join (key + "=" + value for key , value in self ._dict .items ())
286+
287+ @classmethod
288+ def from_header (cls , header_list : typing .List [str ]) -> "TraceState" :
289+ """Parses one or more w3c tracestate header into a TraceState.
290+
291+ Args:
292+ header_list: one or more w3c tracestate headers.
293+
294+ Returns:
295+ A valid TraceState that contains values extracted from
296+ the tracestate header.
297+
298+ If the format of one headers is illegal, all values will
299+ be discarded and an empty tracestate will be returned.
300+
301+ If the number of keys is beyond the maximum, all values
302+ will be discarded and an empty tracestate will be returned.
303+ """
304+ pairs = OrderedDict ()
305+ for header in header_list :
306+ for member in re .split (_DELIMITER_PATTERN , header ):
307+ # empty members are valid, but no need to process further.
308+ if not member :
309+ continue
310+ match = _MEMBER_PATTERN .fullmatch (member )
311+ if not match :
312+ _logger .warning (
313+ "Member doesn't match the w3c identifiers format %s" ,
314+ member ,
315+ )
316+ return cls ()
317+ key , _eq , value = match .groups ()
318+ # duplicate keys are not legal in header
319+ if key in pairs :
320+ return cls ()
321+ pairs [key ] = value
322+ return cls (list (pairs .items ()))
323+
149324 @classmethod
150325 def get_default (cls ) -> "TraceState" :
151326 return cls ()
152327
328+ def keys (self ) -> typing .KeysView [str ]:
329+ return self ._dict .keys ()
330+
331+ def items (self ) -> typing .ItemsView [str , str ]:
332+ return self ._dict .items ()
333+
334+ def values (self ) -> typing .ValuesView [str ]:
335+ return self ._dict .values ()
336+
153337
154338DEFAULT_TRACE_STATE = TraceState .get_default ()
155339
0 commit comments