9191)
9292from sentry .spans .consumers .process_segments .types import attribute_value
9393from sentry .spans .debug_trace_logger import DebugTraceLogger
94- from sentry .spans .segment_key import SegmentKey , parse_segment_key , segment_key_to_span_id
94+ from sentry .spans .segment_key import (
95+ DistributedPayloadKey ,
96+ SegmentKey ,
97+ parse_segment_key ,
98+ segment_key_to_span_id ,
99+ )
95100from sentry .utils import metrics , redis
96101from sentry .utils .outcomes import Outcome , track_outcome
97102
@@ -137,6 +142,7 @@ class FlushedSegment(NamedTuple):
137142 queue_key : QueueKey
138143 spans : list [OutputSpan ]
139144 project_id : int # Used to track outcomes
145+ distributed_payload_keys : list [DistributedPayloadKey ] = [] # For cleanup
140146
141147
142148class SpansBuffer :
@@ -152,6 +158,7 @@ def __init__(self, assigned_shards: list[int], slice_id: int | None = None):
152158 self ._buffer_logger = BufferLogger ()
153159 self ._flusher_logger = FlusherLogger ()
154160 self ._debug_trace_logger : DebugTraceLogger | None = None
161+ self ._distributed_payload_keys_map : dict [SegmentKey , list [bytes ]] = {}
155162
156163 @cached_property
157164 def client (self ) -> RedisCluster [bytes ] | StrictRedis [bytes ]:
@@ -164,6 +171,30 @@ def __reduce__(self):
164171 def _get_span_key (self , project_and_trace : str , span_id : str ) -> bytes :
165172 return f"span-buf:s:{{{ project_and_trace } }}:{ span_id } " .encode ("ascii" )
166173
174+ def _get_distributed_payload_key (
175+ self , project_and_trace : str , span_id : str
176+ ) -> DistributedPayloadKey :
177+ return f"span-buf:s:{{{ project_and_trace } :{ span_id } }}:{ span_id } " .encode ("ascii" )
178+
def _get_payload_key_index(self, segment_key: SegmentKey) -> bytes:
    """Return the key of the set that tracks which distributed payload
    keys exist for *segment_key* (``mk`` = member keys).

    The hash tag ``{project:trace}`` matches the merged segment key so
    the index lives on the same cluster slot as the segment itself.
    """
    project_id, trace_id, span_id = parse_segment_key(segment_key)
    return b"span-buf:mk:{" + project_id + b":" + trace_id + b"}:" + span_id
182+
def _cleanup_distributed_keys(self, segment_keys: set[SegmentKey]) -> None:
    """Remove all distributed-payload state for the given segments.

    For each segment this deletes the member-keys tracking set, unlinks
    every distributed payload key recorded for it, and drops the
    segment from ``_distributed_payload_keys_map`` so that
    ``done_flush_segments`` does not attempt the same cleanup again.
    """
    with self.client.pipeline(transaction=False) as pipe:
        for segment_key in segment_keys:
            # pop() both reads and clears the map entry in one step;
            # a missing or empty entry means there is nothing to delete.
            distributed_keys = self._distributed_payload_keys_map.pop(segment_key, None)
            if not distributed_keys:
                continue
            pipe.delete(self._get_payload_key_index(segment_key))
            # UNLINK in batches: reclaims memory asynchronously and
            # keeps each pipeline command at a bounded argument count.
            for chunk in itertools.batched(distributed_keys, 100):
                pipe.unlink(*chunk)
        pipe.execute()
197+
167198 @metrics .wraps ("spans.buffer.process_spans" )
168199 def process_spans (self , spans : Sequence [Span ], now : int ):
169200 """
@@ -186,6 +217,8 @@ def process_spans(self, spans: Sequence[Span], now: int):
186217 max_segment_bytes = options .get ("spans.buffer.max-segment-bytes" )
187218 max_spans_per_evalsha = options .get ("spans.buffer.max-spans-per-evalsha" )
188219 zero_copy_threshold = options .get ("spans.buffer.zero-copy-dest-threshold-bytes" )
220+ write_distributed_payloads = options .get ("spans.buffer.write-distributed-payloads" )
221+ write_merged_payloads = options .get ("spans.buffer.write-merged-payloads" )
189222
190223 result_meta = []
191224 is_root_span_count = 0
@@ -214,8 +247,17 @@ def process_spans(self, spans: Sequence[Span], now: int):
214247 with self .client .pipeline (transaction = False ) as p :
215248 for (project_and_trace , parent_span_id ), subsegment in batch :
216249 set_members = self ._prepare_payloads (subsegment )
217- set_key = self ._get_span_key (project_and_trace , parent_span_id )
218- p .sadd (set_key , * set_members .keys ())
250+ if write_distributed_payloads :
251+ # Write to distributed key.
252+ dist_key = self ._get_distributed_payload_key (
253+ project_and_trace , parent_span_id
254+ )
255+ p .sadd (dist_key , * set_members .keys ())
256+ p .expire (dist_key , redis_ttl )
257+
258+ if write_merged_payloads :
259+ set_key = self ._get_span_key (project_and_trace , parent_span_id )
260+ p .sadd (set_key , * set_members .keys ())
219261
220262 p .execute ()
221263
@@ -257,6 +299,8 @@ def process_spans(self, spans: Sequence[Span], now: int):
257299 max_segment_bytes ,
258300 byte_count ,
259301 zero_copy_threshold ,
302+ "true" if write_distributed_payloads else "false" ,
303+ "true" if write_merged_payloads else "false" ,
260304 * span_ids ,
261305 )
262306
@@ -565,6 +609,7 @@ def flush_segments(self, now: int) -> dict[SegmentKey, FlushedSegment]:
565609 queue_key = queue_key ,
566610 spans = output_spans ,
567611 project_id = int (project_id .decode ("ascii" )),
612+ distributed_payload_keys = self ._distributed_payload_keys_map .get (segment_key , []),
568613 )
569614 num_has_root_spans += int (has_root_span )
570615
@@ -659,13 +704,50 @@ def _load_segment_data(self, segment_keys: list[SegmentKey]) -> dict[SegmentKey,
659704
660705 page_size = options .get ("spans.buffer.segment-page-size" )
661706 max_segment_bytes = options .get ("spans.buffer.max-segment-bytes" )
707+ read_distributed_payloads = options .get ("spans.buffer.read-distributed-payloads" )
708+ write_distributed_payloads = options .get ("spans.buffer.write-distributed-payloads" )
662709
663710 payloads : dict [SegmentKey , list [bytes ]] = {key : [] for key in segment_keys }
664- cursors = {key : 0 for key in segment_keys }
665711 sizes : dict [SegmentKey , int ] = {key : 0 for key in segment_keys }
666712 self ._last_decompress_latency_ms = 0
667713 decompress_latency_ms = 0.0
668714
715+ # Maps each scan key back to the segment it belongs to. For merged
716+ # keys these are the same; for distributed keys many map to one segment.
717+ scan_key_to_segment : dict [SegmentKey | DistributedPayloadKey , SegmentKey ] = {}
718+
719+ # When read_distributed_payloads is off, scan merged segment keys directly.
720+ # When on, skip them — all data lives in distributed keys.
721+ cursors : dict [bytes , int ] = {}
722+ if not read_distributed_payloads :
723+ for key in segment_keys :
724+ scan_key_to_segment [key ] = key
725+ cursors [key ] = 0
726+
727+ self ._distributed_payload_keys_map = {}
728+
729+ if write_distributed_payloads :
730+ with self .client .pipeline (transaction = False ) as p :
731+ for key in segment_keys :
732+ p .smembers (self ._get_payload_key_index (key ))
733+ mk_results = p .execute ()
734+
735+ for key , sub_span_ids in zip (segment_keys , mk_results ):
736+ project_id , trace_id , _ = parse_segment_key (key )
737+ pat = f"{ project_id .decode ('ascii' )} :{ trace_id .decode ('ascii' )} "
738+ distributed_keys : list [bytes ] = []
739+ for sub_span_id in sub_span_ids :
740+ distributed_key = self ._get_distributed_payload_key (
741+ pat , sub_span_id .decode ("ascii" )
742+ )
743+ distributed_keys .append (distributed_key )
744+ if read_distributed_payloads :
745+ scan_key_to_segment [distributed_key ] = key
746+ cursors [distributed_key ] = 0
747+ self ._distributed_payload_keys_map [key ] = distributed_keys
748+
749+ dropped_segments : set [SegmentKey ] = set ()
750+
669751 def _add_spans (key : SegmentKey , raw_data : bytes ) -> bool :
670752 """
671753 Decompress and add spans to the segment. Returns False if the
@@ -683,6 +765,7 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
683765 logger .warning ("Skipping too large segment, byte size %s" , sizes [key ])
684766 payloads .pop (key , None )
685767 sizes .pop (key , None )
768+ dropped_segments .add (key )
686769 return False
687770
688771 payloads [key ].extend (decompressed )
@@ -698,10 +781,15 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
698781 scan_results = p .execute ()
699782
700783 for key , (cursor , scan_values ) in zip (current_keys , scan_results ):
784+ segment_key = scan_key_to_segment [key ]
785+ if segment_key in dropped_segments :
786+ cursors .pop (key , None )
787+ continue
788+
701789 size_exceeded = False
702790 for scan_value in scan_values :
703- if key in payloads :
704- if not _add_spans (key , scan_value ):
791+ if segment_key in payloads :
792+ if not _add_spans (segment_key , scan_value ):
705793 size_exceeded = True
706794
707795 if size_exceeded :
@@ -711,6 +799,9 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
711799 else :
712800 cursors [key ] = cursor
713801
802+ if dropped_segments :
803+ self ._cleanup_distributed_keys (dropped_segments )
804+
714805 # Fetch ingested counts for all segments to calculate dropped spans
715806 with self .client .pipeline (transaction = False ) as p :
716807 for key in segment_keys :
@@ -743,18 +834,18 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
743834 continue
744835
745836 project_id_bytes , _ , _ = parse_segment_key (key )
746- project_id = int (project_id_bytes )
837+ project_id_int = int (project_id_bytes )
747838 try :
748- project = Project .objects .get_from_cache (id = project_id )
839+ project = Project .objects .get_from_cache (id = project_id_int )
749840 except Project .DoesNotExist :
750841 logger .warning (
751842 "Project does not exist for segment with dropped spans" ,
752- extra = {"project_id" : project_id },
843+ extra = {"project_id" : project_id_int },
753844 )
754845 else :
755846 track_outcome (
756847 org_id = project .organization_id ,
757- project_id = project_id ,
848+ project_id = project_id_int ,
758849 key_id = None ,
759850 outcome = Outcome .INVALID ,
760851 reason = "segment_too_large" ,
@@ -800,6 +891,14 @@ def done_flush_segments(self, segment_keys: dict[SegmentKey, FlushedSegment]):
800891 span_ids = [output_span .payload ["span_id" ] for output_span in span_batch ]
801892 p .hdel (redirect_map_key , * span_ids )
802893
894+ if flushed_segment .distributed_payload_keys :
895+ mk_key = self ._get_payload_key_index (segment_key )
896+ p .delete (mk_key )
897+ for distributed_key_batch in itertools .batched (
898+ flushed_segment .distributed_payload_keys , 100
899+ ):
900+ p .unlink (* distributed_key_batch )
901+
803902 for queue_key , keys in queue_removals .items ():
804903 for key_batch in itertools .batched (keys , 100 ):
805904 p .zrem (queue_key , * key_batch )
0 commit comments