Skip to content

Commit f0d5daf

Browse files
authored
feat(spans): Distribute span payload keys across Redis cluster (#110593)
Spread span payload sets across Redis cluster nodes to avoid concentrated large traces on a single node. Instead of merging all payloads under {project_id:trace_id}, write them to {project_id:trace_id:span_id} so they shard across nodes. A member-keys tracking set (span-buf:mk) indexes which distributed keys belong to each segment. Three-phase rollout (similar to the ZSET to SET change): - Phase 1 (`write-distributed-payloads` set to True): Dual-write to both key formats, read from merged set keys. - Phase 2 (`read-distributed-payloads` set to True): Dual-write continues, flusher reads from distributed keys. - Phase 3 (`write-merged-payloads` set to False): Stop writing merged payloads.
1 parent c76c8e6 commit f0d5daf

File tree

5 files changed

+382
-15
lines changed

5 files changed

+382
-15
lines changed

src/sentry/options/defaults.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3299,6 +3299,27 @@
32993299
flags=FLAG_PRIORITIZE_DISK | FLAG_AUTOMATOR_MODIFIABLE,
33003300
)
33013301

3302+
# Write payload sets to per-span distributed keys AND merged keys.
3303+
# Flusher reads merged keys as before.
3304+
register(
3305+
"spans.buffer.write-distributed-payloads",
3306+
default=False,
3307+
flags=FLAG_PRIORITIZE_DISK | FLAG_AUTOMATOR_MODIFIABLE,
3308+
)
3309+
# Switch flusher to read from distributed keys instead of merged.
3310+
register(
3311+
"spans.buffer.read-distributed-payloads",
3312+
default=False,
3313+
flags=FLAG_PRIORITIZE_DISK | FLAG_AUTOMATOR_MODIFIABLE,
3314+
)
3315+
# Set to False to stop writing merged keys and skip set merges.
3316+
# Disable after read-distributed-payloads is stable. Rollback: re-enable
3317+
# this flag to resume merged writes before reverting read-distributed-payloads.
3318+
register(
3319+
"spans.buffer.write-merged-payloads",
3320+
default=True,
3321+
flags=FLAG_PRIORITIZE_DISK | FLAG_AUTOMATOR_MODIFIABLE,
3322+
)
33023323
# List of trace_ids to enable debug logging for. Empty = debug off.
33033324
# When set, logs detailed metrics about zunionstore set sizes, key existence, and trace structure.
33043325
register(

src/sentry/scripts/spans/add-buffer.lua

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ ARGS:
2929
- max_segment_bytes -- int -- The maximum number of bytes the segment can contain.
3030
- byte_count -- int -- The total number of bytes in the subsegment.
3131
- zero_copy_dest_threshold -- int -- When > 0, use SMEMBERS+SADD instead of SUNIONSTORE when the destination set exceeds this many bytes.
32+
- write_distributed_payloads -- "true" or "false" -- When true, maintain member-keys tracking sets for distributed payload keys.
33+
- write_merged_payloads -- "true" or "false" -- When false, skip set merges and set-key expire commands.
3234
- *span_id -- str[] -- The span ids in the subsegment.
3335
3436
RETURNS:
@@ -49,7 +51,9 @@ local set_timeout = tonumber(ARGV[4])
4951
local max_segment_bytes = tonumber(ARGV[5])
5052
local byte_count = tonumber(ARGV[6])
5153
local zero_copy_dest_threshold = tonumber(ARGV[7])
52-
local NUM_ARGS = 7
54+
local write_distributed_payloads = ARGV[8] == "true"
55+
local write_merged_payloads = ARGV[9] == "true"
56+
local NUM_ARGS = 9
5357

5458
local function get_time_ms()
5559
local time = redis.call("TIME")
@@ -102,7 +106,7 @@ local sunionstore_args = {}
102106
-- Updating the redirect set instead is needed when we receive higher level spans
103107
-- for a tree we are assembling as the segment root each span points at in the
104108
-- redirect set changes when a new root is found.
105-
if set_span_id ~= parent_span_id and redis.call("scard", parent_key) > 0 then
109+
if write_merged_payloads and set_span_id ~= parent_span_id and redis.call("scard", parent_key) > 0 then
106110
table.insert(sunionstore_args, parent_key)
107111
end
108112

@@ -113,7 +117,7 @@ for i = NUM_ARGS + 1, NUM_ARGS + num_spans do
113117
table.insert(hset_args, span_id)
114118
table.insert(hset_args, set_span_id)
115119

116-
if not is_root_span then
120+
if not is_root_span and write_merged_payloads then
117121
local span_key = string.format("span-buf:s:{%s}:%s", project_and_trace, span_id)
118122
table.insert(sunionstore_args, span_key)
119123
end
@@ -128,7 +132,83 @@ table.insert(latency_table, {"sunionstore_args_step_latency_ms", sunionstore_arg
128132
-- Merge spans into the parent span set.
129133
-- Used outside the if statement
130134
local arg_cleanup_end_time_ms = sunionstore_args_end_time_ms
131-
if #sunionstore_args > 0 then
135+
-- Maintain member-keys (span-buf:mk) tracking sets so the flusher
136+
-- knows which distributed keys to fetch. This runs in both write-only and
137+
-- full distributed mode.
138+
if write_distributed_payloads then
139+
local member_keys_key = string.format("span-buf:mk:{%s}:%s", project_and_trace, set_span_id)
140+
redis.call("sadd", member_keys_key, parent_span_id)
141+
142+
-- Merge child tracking sets from span_ids that were previously segment roots.
143+
for i = NUM_ARGS + 1, NUM_ARGS + num_spans do
144+
local span_id = ARGV[i]
145+
if span_id ~= parent_span_id then
146+
local child_mk_key = string.format("span-buf:mk:{%s}:%s", project_and_trace, span_id)
147+
local child_members = redis.call("smembers", child_mk_key)
148+
if #child_members > 0 then
149+
redis.call("sadd", member_keys_key, unpack(child_members))
150+
redis.call("del", child_mk_key)
151+
end
152+
end
153+
end
154+
155+
-- Merge parent's tracking set if parent_span_id redirected to a different root.
156+
if set_span_id ~= parent_span_id then
157+
local parent_mk_key = string.format("span-buf:mk:{%s}:%s", project_and_trace, parent_span_id)
158+
local parent_members = redis.call("smembers", parent_mk_key)
159+
if #parent_members > 0 then
160+
redis.call("sadd", member_keys_key, unpack(parent_members))
161+
redis.call("del", parent_mk_key)
162+
end
163+
end
164+
165+
redis.call("expire", member_keys_key, set_timeout)
166+
arg_cleanup_end_time_ms = get_time_ms()
167+
table.insert(latency_table, {"distributed_tracking_step_latency_ms", arg_cleanup_end_time_ms - sunionstore_args_end_time_ms})
168+
end
169+
170+
-- When write_merged_payloads is false, merged set merges are skipped but we
171+
-- still need to merge ic/ibc counters from child keys into the segment root.
172+
if not write_merged_payloads then
173+
local ingested_count_key = string.format("span-buf:ic:%s", set_key)
174+
local ingested_byte_count_key = string.format("span-buf:ibc:%s", set_key)
175+
for i = NUM_ARGS + 1, NUM_ARGS + num_spans do
176+
local span_id = ARGV[i]
177+
if span_id ~= parent_span_id then
178+
local child_merged = string.format("span-buf:s:{%s}:%s", project_and_trace, span_id)
179+
local child_ic_key = string.format("span-buf:ic:%s", child_merged)
180+
local child_ibc_key = string.format("span-buf:ibc:%s", child_merged)
181+
local child_count = redis.call("get", child_ic_key)
182+
local child_byte_count = redis.call("get", child_ibc_key)
183+
if child_count then
184+
redis.call("incrby", ingested_count_key, child_count)
185+
redis.call("del", child_ic_key)
186+
end
187+
if child_byte_count then
188+
redis.call("incrby", ingested_byte_count_key, child_byte_count)
189+
redis.call("del", child_ibc_key)
190+
end
191+
end
192+
end
193+
if set_span_id ~= parent_span_id then
194+
local parent_merged = string.format("span-buf:s:{%s}:%s", project_and_trace, parent_span_id)
195+
local parent_ic_key = string.format("span-buf:ic:%s", parent_merged)
196+
local parent_ibc_key = string.format("span-buf:ibc:%s", parent_merged)
197+
local parent_count = redis.call("get", parent_ic_key)
198+
local parent_byte_count = redis.call("get", parent_ibc_key)
199+
if parent_count then
200+
redis.call("incrby", ingested_count_key, parent_count)
201+
redis.call("del", parent_ic_key)
202+
end
203+
if parent_byte_count then
204+
redis.call("incrby", ingested_byte_count_key, parent_byte_count)
205+
redis.call("del", parent_ibc_key)
206+
end
207+
end
208+
arg_cleanup_end_time_ms = get_time_ms()
209+
table.insert(latency_table, {"distributed_ibc_merge_step_latency_ms", arg_cleanup_end_time_ms - sunionstore_args_end_time_ms})
210+
211+
elseif #sunionstore_args > 0 then
132212
local dest_memory = redis.call("memory", "usage", set_key) or 0
133213
local ingested_byte_count_key = string.format("span-buf:ibc:%s", set_key)
134214
local dest_bytes = tonumber(redis.call("get", ingested_byte_count_key) or 0)
@@ -210,7 +290,9 @@ redis.call("incrby", ingested_byte_count_key, byte_count)
210290
redis.call("expire", ingested_count_key, set_timeout)
211291
redis.call("expire", ingested_byte_count_key, set_timeout)
212292

213-
redis.call("expire", set_key, set_timeout)
293+
if write_merged_payloads then
294+
redis.call("expire", set_key, set_timeout)
295+
end
214296

215297
local ingested_count_end_time_ms = get_time_ms()
216298
local ingested_count_step_latency_ms = ingested_count_end_time_ms - arg_cleanup_end_time_ms

src/sentry/spans/buffer.py

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@
9191
)
9292
from sentry.spans.consumers.process_segments.types import attribute_value
9393
from sentry.spans.debug_trace_logger import DebugTraceLogger
94-
from sentry.spans.segment_key import SegmentKey, parse_segment_key, segment_key_to_span_id
94+
from sentry.spans.segment_key import (
95+
DistributedPayloadKey,
96+
SegmentKey,
97+
parse_segment_key,
98+
segment_key_to_span_id,
99+
)
95100
from sentry.utils import metrics, redis
96101
from sentry.utils.outcomes import Outcome, track_outcome
97102

@@ -137,6 +142,7 @@ class FlushedSegment(NamedTuple):
137142
queue_key: QueueKey
138143
spans: list[OutputSpan]
139144
project_id: int # Used to track outcomes
145+
distributed_payload_keys: list[DistributedPayloadKey] = [] # For cleanup
140146

141147

142148
class SpansBuffer:
@@ -152,6 +158,7 @@ def __init__(self, assigned_shards: list[int], slice_id: int | None = None):
152158
self._buffer_logger = BufferLogger()
153159
self._flusher_logger = FlusherLogger()
154160
self._debug_trace_logger: DebugTraceLogger | None = None
161+
self._distributed_payload_keys_map: dict[SegmentKey, list[bytes]] = {}
155162

156163
@cached_property
157164
def client(self) -> RedisCluster[bytes] | StrictRedis[bytes]:
@@ -164,6 +171,30 @@ def __reduce__(self):
164171
def _get_span_key(self, project_and_trace: str, span_id: str) -> bytes:
165172
return f"span-buf:s:{{{project_and_trace}}}:{span_id}".encode("ascii")
166173

174+
def _get_distributed_payload_key(
175+
self, project_and_trace: str, span_id: str
176+
) -> DistributedPayloadKey:
177+
return f"span-buf:s:{{{project_and_trace}:{span_id}}}:{span_id}".encode("ascii")
178+
179+
def _get_payload_key_index(self, segment_key: SegmentKey) -> bytes:
180+
project_id, trace_id, span_id = parse_segment_key(segment_key)
181+
return b"span-buf:mk:{%s:%s}:%s" % (project_id, trace_id, span_id)
182+
183+
def _cleanup_distributed_keys(self, segment_keys: set[SegmentKey]) -> None:
184+
"""Delete member-keys tracking sets and distributed payload keys for the
185+
given segments, and remove them from the payload keys map so
186+
done_flush_segments doesn't try again."""
187+
with self.client.pipeline(transaction=False) as p:
188+
for key in segment_keys:
189+
payload_keys = self._distributed_payload_keys_map.get(key, [])
190+
if payload_keys:
191+
mk_key = self._get_payload_key_index(key)
192+
p.delete(mk_key)
193+
for batch in itertools.batched(payload_keys, 100):
194+
p.unlink(*batch)
195+
self._distributed_payload_keys_map.pop(key, None)
196+
p.execute()
197+
167198
@metrics.wraps("spans.buffer.process_spans")
168199
def process_spans(self, spans: Sequence[Span], now: int):
169200
"""
@@ -186,6 +217,8 @@ def process_spans(self, spans: Sequence[Span], now: int):
186217
max_segment_bytes = options.get("spans.buffer.max-segment-bytes")
187218
max_spans_per_evalsha = options.get("spans.buffer.max-spans-per-evalsha")
188219
zero_copy_threshold = options.get("spans.buffer.zero-copy-dest-threshold-bytes")
220+
write_distributed_payloads = options.get("spans.buffer.write-distributed-payloads")
221+
write_merged_payloads = options.get("spans.buffer.write-merged-payloads")
189222

190223
result_meta = []
191224
is_root_span_count = 0
@@ -214,8 +247,17 @@ def process_spans(self, spans: Sequence[Span], now: int):
214247
with self.client.pipeline(transaction=False) as p:
215248
for (project_and_trace, parent_span_id), subsegment in batch:
216249
set_members = self._prepare_payloads(subsegment)
217-
set_key = self._get_span_key(project_and_trace, parent_span_id)
218-
p.sadd(set_key, *set_members.keys())
250+
if write_distributed_payloads:
251+
# Write to distributed key.
252+
dist_key = self._get_distributed_payload_key(
253+
project_and_trace, parent_span_id
254+
)
255+
p.sadd(dist_key, *set_members.keys())
256+
p.expire(dist_key, redis_ttl)
257+
258+
if write_merged_payloads:
259+
set_key = self._get_span_key(project_and_trace, parent_span_id)
260+
p.sadd(set_key, *set_members.keys())
219261

220262
p.execute()
221263

@@ -257,6 +299,8 @@ def process_spans(self, spans: Sequence[Span], now: int):
257299
max_segment_bytes,
258300
byte_count,
259301
zero_copy_threshold,
302+
"true" if write_distributed_payloads else "false",
303+
"true" if write_merged_payloads else "false",
260304
*span_ids,
261305
)
262306

@@ -565,6 +609,7 @@ def flush_segments(self, now: int) -> dict[SegmentKey, FlushedSegment]:
565609
queue_key=queue_key,
566610
spans=output_spans,
567611
project_id=int(project_id.decode("ascii")),
612+
distributed_payload_keys=self._distributed_payload_keys_map.get(segment_key, []),
568613
)
569614
num_has_root_spans += int(has_root_span)
570615

@@ -659,13 +704,50 @@ def _load_segment_data(self, segment_keys: list[SegmentKey]) -> dict[SegmentKey,
659704

660705
page_size = options.get("spans.buffer.segment-page-size")
661706
max_segment_bytes = options.get("spans.buffer.max-segment-bytes")
707+
read_distributed_payloads = options.get("spans.buffer.read-distributed-payloads")
708+
write_distributed_payloads = options.get("spans.buffer.write-distributed-payloads")
662709

663710
payloads: dict[SegmentKey, list[bytes]] = {key: [] for key in segment_keys}
664-
cursors = {key: 0 for key in segment_keys}
665711
sizes: dict[SegmentKey, int] = {key: 0 for key in segment_keys}
666712
self._last_decompress_latency_ms = 0
667713
decompress_latency_ms = 0.0
668714

715+
# Maps each scan key back to the segment it belongs to. For merged
716+
# keys these are the same; for distributed keys many map to one segment.
717+
scan_key_to_segment: dict[SegmentKey | DistributedPayloadKey, SegmentKey] = {}
718+
719+
# When read_distributed_payloads is off, scan merged segment keys directly.
720+
# When on, skip them — all data lives in distributed keys.
721+
cursors: dict[bytes, int] = {}
722+
if not read_distributed_payloads:
723+
for key in segment_keys:
724+
scan_key_to_segment[key] = key
725+
cursors[key] = 0
726+
727+
self._distributed_payload_keys_map = {}
728+
729+
if write_distributed_payloads:
730+
with self.client.pipeline(transaction=False) as p:
731+
for key in segment_keys:
732+
p.smembers(self._get_payload_key_index(key))
733+
mk_results = p.execute()
734+
735+
for key, sub_span_ids in zip(segment_keys, mk_results):
736+
project_id, trace_id, _ = parse_segment_key(key)
737+
pat = f"{project_id.decode('ascii')}:{trace_id.decode('ascii')}"
738+
distributed_keys: list[bytes] = []
739+
for sub_span_id in sub_span_ids:
740+
distributed_key = self._get_distributed_payload_key(
741+
pat, sub_span_id.decode("ascii")
742+
)
743+
distributed_keys.append(distributed_key)
744+
if read_distributed_payloads:
745+
scan_key_to_segment[distributed_key] = key
746+
cursors[distributed_key] = 0
747+
self._distributed_payload_keys_map[key] = distributed_keys
748+
749+
dropped_segments: set[SegmentKey] = set()
750+
669751
def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
670752
"""
671753
Decompress and add spans to the segment. Returns False if the
@@ -683,6 +765,7 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
683765
logger.warning("Skipping too large segment, byte size %s", sizes[key])
684766
payloads.pop(key, None)
685767
sizes.pop(key, None)
768+
dropped_segments.add(key)
686769
return False
687770

688771
payloads[key].extend(decompressed)
@@ -698,10 +781,15 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
698781
scan_results = p.execute()
699782

700783
for key, (cursor, scan_values) in zip(current_keys, scan_results):
784+
segment_key = scan_key_to_segment[key]
785+
if segment_key in dropped_segments:
786+
cursors.pop(key, None)
787+
continue
788+
701789
size_exceeded = False
702790
for scan_value in scan_values:
703-
if key in payloads:
704-
if not _add_spans(key, scan_value):
791+
if segment_key in payloads:
792+
if not _add_spans(segment_key, scan_value):
705793
size_exceeded = True
706794

707795
if size_exceeded:
@@ -711,6 +799,9 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
711799
else:
712800
cursors[key] = cursor
713801

802+
if dropped_segments:
803+
self._cleanup_distributed_keys(dropped_segments)
804+
714805
# Fetch ingested counts for all segments to calculate dropped spans
715806
with self.client.pipeline(transaction=False) as p:
716807
for key in segment_keys:
@@ -743,18 +834,18 @@ def _add_spans(key: SegmentKey, raw_data: bytes) -> bool:
743834
continue
744835

745836
project_id_bytes, _, _ = parse_segment_key(key)
746-
project_id = int(project_id_bytes)
837+
project_id_int = int(project_id_bytes)
747838
try:
748-
project = Project.objects.get_from_cache(id=project_id)
839+
project = Project.objects.get_from_cache(id=project_id_int)
749840
except Project.DoesNotExist:
750841
logger.warning(
751842
"Project does not exist for segment with dropped spans",
752-
extra={"project_id": project_id},
843+
extra={"project_id": project_id_int},
753844
)
754845
else:
755846
track_outcome(
756847
org_id=project.organization_id,
757-
project_id=project_id,
848+
project_id=project_id_int,
758849
key_id=None,
759850
outcome=Outcome.INVALID,
760851
reason="segment_too_large",
@@ -800,6 +891,14 @@ def done_flush_segments(self, segment_keys: dict[SegmentKey, FlushedSegment]):
800891
span_ids = [output_span.payload["span_id"] for output_span in span_batch]
801892
p.hdel(redirect_map_key, *span_ids)
802893

894+
if flushed_segment.distributed_payload_keys:
895+
mk_key = self._get_payload_key_index(segment_key)
896+
p.delete(mk_key)
897+
for distributed_key_batch in itertools.batched(
898+
flushed_segment.distributed_payload_keys, 100
899+
):
900+
p.unlink(*distributed_key_batch)
901+
803902
for queue_key, keys in queue_removals.items():
804903
for key_batch in itertools.batched(keys, 100):
805904
p.zrem(queue_key, *key_batch)

0 commit comments

Comments
 (0)