Skip to content

Commit 0ebcf9f

Browse files
authored
Merge pull request #653 from dimensionalOS/suppress-echos-with-counter
Suppress echos with counter Former-commit-id: bcfee11 [formerly d9a0440] Former-commit-id: ab6fa0e
1 parent ac60cc6 commit 0ebcf9f

1 file changed

Lines changed: 39 additions & 18 deletions

File tree

dimos/protocol/pubsub/shmpubsub.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import time
2727
from collections import defaultdict
2828
from dataclasses import dataclass
29-
from typing import Any, Callable, Dict, Optional
29+
from typing import Any, Callable, Dict, Optional, Tuple
3030

3131
import numpy as np
3232

@@ -82,12 +82,13 @@ class _TopicState:
8282
"cp",
8383
"last_local_payload",
8484
"suppress_counts",
85+
"message_counter",
8586
)
8687

8788
def __init__(self, channel, capacity: int, cp_mod):
8889
self.channel = channel
8990
self.capacity = int(capacity)
90-
self.shape = (self.capacity + 4,) # +4 for uint32 length header
91+
self.shape = (self.capacity + 12,) # +12 for header: length(4) + pid(4) + counter(4)
9192
self.dtype = np.uint8
9293
self.subs: list[Callable[[bytes, str], None]] = []
9394
self.stop = threading.Event()
@@ -96,7 +97,8 @@ def __init__(self, channel, capacity: int, cp_mod):
9697
# TODO: implement an initializer variable for is_cuda once CUDA IPC is in
9798
self.cp = cp_mod
9899
self.last_local_payload: Optional[bytes] = None
99-
self.suppress_counts = defaultdict(int)
100+
self.suppress_counts: Dict[Tuple[int, int], int] = defaultdict(int)
101+
self.message_counter = 0
100102

101103
# ----- init / lifecycle -------------------------------------------------
102104

@@ -158,9 +160,13 @@ def publish(self, topic: str, message: bytes) -> None:
158160
logger.error(f"Payload too large: {L} > capacity {st.capacity}")
159161
raise ValueError(f"Payload too large: {L} > capacity {st.capacity}")
160162

161-
# Mark this payload to suppress its single echo (handles back-to-back publishes)
162-
payload_hash = hashlib.md5(payload_bytes).digest()
163-
st.suppress_counts[payload_hash] += 1
163+
# Create a unique identifier using PID and incrementing counter
164+
pid = os.getpid()
165+
st.message_counter += 1
166+
message_id = (pid, st.message_counter)
167+
168+
# Mark this message to suppress its echo
169+
st.suppress_counts[message_id] += 1
164170

165171
# Synchronous local delivery first (zero extra copies)
166172
for cb in list(st.subs):
@@ -170,11 +176,14 @@ def publish(self, topic: str, message: bytes) -> None:
170176
logger.warn(f"Payload couldn't be pushed to topic: {topic}")
171177
pass
172178

173-
# Build host frame [len:4] + payload and publish
179+
# Build host frame [len:4] + [pid:4] + [counter:4] + payload and publish
180+
# We embed the message ID in the frame for echo suppression
174181
host = np.zeros(st.shape, dtype=st.dtype)
175-
host[:4] = np.frombuffer(struct.pack("<I", L), dtype=np.uint8)
182+
# Pack: length(4) + pid(4) + counter(4) + payload
183+
header = struct.pack("<III", L + 8, pid, st.message_counter) # L+8 for pid+counter
184+
host[:12] = np.frombuffer(header, dtype=np.uint8)
176185
if L:
177-
host[4 : 4 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8)
186+
host[12 : 12 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8)
178187

179188
st.channel.publish(host)
180189

@@ -231,7 +240,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict:
231240
"""Change payload capacity (bytes) for a topic; returns new descriptor."""
232241
st = self._ensure_topic(topic)
233242
new_cap = int(capacity)
234-
new_shape = (new_cap + 4,)
243+
new_shape = (new_cap + 12,) # +12 for header: length(4) + pid(4) + counter(4)
235244
desc = st.channel.reconfigure(new_shape, np.uint8)
236245
st.capacity = new_cap
237246
st.shape = new_shape
@@ -254,7 +263,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]:
254263
return f"psm_{h}_data", f"psm_{h}_ctrl"
255264

256265
data_name, ctrl_name = _names_for_topic(topic, cap)
257-
ch = CpuShmChannel((cap + 4,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
266+
ch = CpuShmChannel((cap + 12,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
258267
st = SharedMemoryPubSubBase._TopicState(ch, cap, None)
259268
self._topics[topic] = st
260269
return st
@@ -270,20 +279,32 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None:
270279
host = np.array(view, copy=True)
271280

272281
try:
273-
L = struct.unpack("<I", host[:4].tobytes())[0]
274-
if L == 0 or L < 0 or L > st.capacity:
282+
# Read header: length(4) + pid(4) + counter(4)
283+
header = struct.unpack("<III", host[:12].tobytes())
284+
L = header[0]
285+
286+
if L < 8 or L > st.capacity + 8:
275287
continue
276288

277-
payload = host[4 : 4 + L].tobytes()
289+
# Extract PID and counter
290+
pid = header[1]
291+
counter = header[2]
292+
message_id = (pid, counter)
293+
294+
# Extract actual payload (after removing the 8 bytes for pid+counter)
295+
payload_len = L - 8
296+
if payload_len > 0:
297+
payload = host[12 : 12 + payload_len].tobytes()
298+
else:
299+
continue
278300

279301
# Drop exactly the number of local echoes we created
280-
payload_hash = hashlib.md5(payload).digest()
281-
cnt = st.suppress_counts.get(payload_hash, 0)
302+
cnt = st.suppress_counts.get(message_id, 0)
282303
if cnt > 0:
283304
if cnt == 1:
284-
del st.suppress_counts[payload_hash]
305+
del st.suppress_counts[message_id]
285306
else:
286-
st.suppress_counts[payload_hash] = cnt - 1
307+
st.suppress_counts[message_id] = cnt - 1
287308
continue # suppressed
288309

289310
except Exception:

0 commit comments

Comments
 (0)