2626import time
2727from collections import defaultdict
2828from dataclasses import dataclass
29- from typing import Any , Callable , Dict , Optional
29+ from typing import Any , Callable , Dict , Optional , Tuple
3030
3131import numpy as np
3232
@@ -82,12 +82,13 @@ class _TopicState:
8282 "cp" ,
8383 "last_local_payload" ,
8484 "suppress_counts" ,
85+ "message_counter" ,
8586 )
8687
8788 def __init__ (self , channel , capacity : int , cp_mod ):
8889 self .channel = channel
8990 self .capacity = int (capacity )
90- self .shape = (self .capacity + 4 ,) # +4 for uint32 length header
91+ self .shape = (self .capacity + 12 ,) # +12 for header: length(4) + pid(4) + counter(4)
9192 self .dtype = np .uint8
9293 self .subs : list [Callable [[bytes , str ], None ]] = []
9394 self .stop = threading .Event ()
@@ -96,7 +97,8 @@ def __init__(self, channel, capacity: int, cp_mod):
9697 # TODO: implement an initializer variable for is_cuda once CUDA IPC is in
9798 self .cp = cp_mod
9899 self .last_local_payload : Optional [bytes ] = None
99- self .suppress_counts = defaultdict (int )
100+ self .suppress_counts : Dict [Tuple [int , int ], int ] = defaultdict (int )
101+ self .message_counter = 0
100102
101103 # ----- init / lifecycle -------------------------------------------------
102104
@@ -158,9 +160,13 @@ def publish(self, topic: str, message: bytes) -> None:
158160 logger .error (f"Payload too large: { L } > capacity { st .capacity } " )
159161 raise ValueError (f"Payload too large: { L } > capacity { st .capacity } " )
160162
161- # Mark this payload to suppress its single echo (handles back-to-back publishes)
162- payload_hash = hashlib .md5 (payload_bytes ).digest ()
163- st .suppress_counts [payload_hash ] += 1
163+ # Create a unique identifier using PID and incrementing counter
164+ pid = os .getpid ()
165+ st .message_counter += 1
166+ message_id = (pid , st .message_counter )
167+
168+ # Mark this message to suppress its echo
169+ st .suppress_counts [message_id ] += 1
164170
165171 # Synchronous local delivery first (zero extra copies)
166172 for cb in list (st .subs ):
@@ -170,11 +176,14 @@ def publish(self, topic: str, message: bytes) -> None:
170176 logger .warn (f"Payload couldn't be pushed to topic: { topic } " )
171177 pass
172178
173- # Build host frame [len:4] + payload and publish
179+ # Build host frame [len:4] + [pid:4] + [counter:4] + payload and publish
180+ # We embed the message ID in the frame for echo suppression
174181 host = np .zeros (st .shape , dtype = st .dtype )
175- host [:4 ] = np .frombuffer (struct .pack ("<I" , L ), dtype = np .uint8 )
182+ # Pack: length(4) + pid(4) + counter(4) + payload
183+ header = struct .pack ("<III" , L + 8 , pid , st .message_counter ) # L+8 for pid+counter
184+ host [:12 ] = np .frombuffer (header , dtype = np .uint8 )
176185 if L :
177- host [4 : 4 + L ] = np .frombuffer (memoryview (payload_bytes ), dtype = np .uint8 )
186+ host [12 : 12 + L ] = np .frombuffer (memoryview (payload_bytes ), dtype = np .uint8 )
178187
179188 st .channel .publish (host )
180189
@@ -231,7 +240,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict:
231240 """Change payload capacity (bytes) for a topic; returns new descriptor."""
232241 st = self ._ensure_topic (topic )
233242 new_cap = int (capacity )
234- new_shape = (new_cap + 4 , )
243+ new_shape = (new_cap + 12 ,) # +12 for header: length(4) + pid(4) + counter(4 )
235244 desc = st .channel .reconfigure (new_shape , np .uint8 )
236245 st .capacity = new_cap
237246 st .shape = new_shape
@@ -254,7 +263,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]:
254263 return f"psm_{ h } _data" , f"psm_{ h } _ctrl"
255264
256265 data_name , ctrl_name = _names_for_topic (topic , cap )
257- ch = CpuShmChannel ((cap + 4 ,), np .uint8 , data_name = data_name , ctrl_name = ctrl_name )
266+ ch = CpuShmChannel ((cap + 12 ,), np .uint8 , data_name = data_name , ctrl_name = ctrl_name )
258267 st = SharedMemoryPubSubBase ._TopicState (ch , cap , None )
259268 self ._topics [topic ] = st
260269 return st
@@ -270,20 +279,32 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None:
270279 host = np .array (view , copy = True )
271280
272281 try :
273- L = struct .unpack ("<I" , host [:4 ].tobytes ())[0 ]
274- if L == 0 or L < 0 or L > st .capacity :
282+ # Read header: length(4) + pid(4) + counter(4)
283+ header = struct .unpack ("<III" , host [:12 ].tobytes ())
284+ L = header [0 ]
285+
286+ if L < 8 or L > st .capacity + 8 :
275287 continue
276288
277- payload = host [4 : 4 + L ].tobytes ()
289+ # Extract PID and counter
290+ pid = header [1 ]
291+ counter = header [2 ]
292+ message_id = (pid , counter )
293+
294+ # Extract actual payload (after removing the 8 bytes for pid+counter)
295+ payload_len = L - 8
296+ if payload_len > 0 :
297+ payload = host [12 : 12 + payload_len ].tobytes ()
298+ else :
299+ continue
278300
279301 # Drop exactly the number of local echoes we created
280- payload_hash = hashlib .md5 (payload ).digest ()
281- cnt = st .suppress_counts .get (payload_hash , 0 )
302+ cnt = st .suppress_counts .get (message_id , 0 )
282303 if cnt > 0 :
283304 if cnt == 1 :
284- del st .suppress_counts [payload_hash ]
305+ del st .suppress_counts [message_id ]
285306 else :
286- st .suppress_counts [payload_hash ] = cnt - 1
307+ st .suppress_counts [message_id ] = cnt - 1
287308 continue # suppressed
288309
289310 except Exception :
0 commit comments