Skip to content

Commit c850d3a

Browse files
leshy and claude committed
Fix race conditions in ROS pubsub and benchmark test
- Capture executor/node references in local variables before use to prevent NoneType errors when stop() is called concurrently
- Replace dynamic class creation with direct string manipulation for transport name extraction

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 3950e01 commit c850d3a

File tree

2 files changed

+38
-15
lines changed

2 files changed

+38
-15
lines changed

dimos/protocol/pubsub/benchmark/test_benchmark.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,13 +14,21 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from collections.abc import Generator
1718
import threading
1819
import time
20+
from typing import Any
1921

2022
import pytest
2123

2224
from dimos.protocol.pubsub.benchmark.testdata import testdata
23-
from dimos.protocol.pubsub.benchmark.type import BenchmarkResult, BenchmarkResults
25+
from dimos.protocol.pubsub.benchmark.type import (
26+
BenchmarkResult,
27+
BenchmarkResults,
28+
MsgGen,
29+
PubSubContext,
30+
TestCase,
31+
)
2432

2533
# Message sizes for throughput benchmarking (powers of 2 from 64B to 10MB)
2634
MSG_SIZES = [
@@ -57,16 +65,16 @@ def size_id(size: int) -> str:
5765
return f"{size}B"
5866

5967

60-
def pubsub_id(testcase) -> str:
68+
def pubsub_id(testcase: TestCase[Any, Any]) -> str:
6169
"""Extract pubsub implementation name from context manager function name."""
62-
name = testcase.pubsub_context.__name__
70+
name: str = testcase.pubsub_context.__name__
6371
# Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory"
6472
prefix = name.replace("_pubsub_channel", "").replace("_", " ")
6573
return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "")
6674

6775

6876
@pytest.fixture(scope="module")
69-
def benchmark_results():
77+
def benchmark_results() -> Generator[BenchmarkResults, None, None]:
7078
"""Module-scoped fixture to collect benchmark results."""
7179
results = BenchmarkResults()
7280
yield results
@@ -79,7 +87,12 @@ def benchmark_results():
7987
@pytest.mark.tool
8088
@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES])
8189
@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata])
82-
def test_throughput(pubsub_context, msggen, msg_size, benchmark_results):
90+
def test_throughput(
91+
pubsub_context: PubSubContext[Any, Any],
92+
msggen: MsgGen[Any, Any],
93+
msg_size: int,
94+
benchmark_results: BenchmarkResults,
95+
) -> None:
8396
"""Measure throughput for publishing and receiving messages over a fixed duration."""
8497
with pubsub_context() as pubsub:
8598
topic, msg = msggen(msg_size)
@@ -88,7 +101,7 @@ def test_throughput(pubsub_context, msggen, msg_size, benchmark_results):
88101
lock = threading.Lock()
89102
all_received = threading.Event()
90103

91-
def callback(message, _topic):
104+
def callback(message: Any, _topic: Any) -> None:
92105
nonlocal received_count
93106
with lock:
94107
received_count += 1
@@ -136,7 +149,10 @@ def callback(message, _topic):
136149
latency = latency_end - publish_end
137150

138151
# Record result (duration is publish time only for throughput calculation)
139-
transport_name = pubsub_id(type("TC", (), {"pubsub_context": pubsub_context})())
152+
# Extract transport name from context manager function name
153+
ctx_name = pubsub_context.__name__
154+
prefix = ctx_name.replace("_pubsub_channel", "").replace("_", " ")
155+
transport_name = prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "")
140156
result = BenchmarkResult(
141157
transport=transport_name,
142158
duration=publish_end - start,

dimos/protocol/pubsub/rospubsub.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -152,17 +152,24 @@ def stop(self) -> None:
152152

153153
def _spin(self) -> None:
154154
"""Background thread for spinning the ROS executor."""
155-
while self._running and self._executor:
156-
self._executor.spin_once(timeout_sec=0) # Non-blocking for max throughput
155+
while self._running:
156+
executor = self._executor
157+
if executor is None:
158+
break
159+
executor.spin_once(timeout_sec=0) # Non-blocking for max throughput
157160

158161
def _get_or_create_publisher(self, topic: ROSTopic) -> Any:
159162
"""Get existing publisher or create a new one."""
160-
if topic.topic not in self._publishers:
161-
qos = topic.qos if topic.qos is not None else self._qos
162-
self._publishers[topic.topic] = self._node.create_publisher(
163-
topic.ros_type, topic.topic, qos
164-
)
165-
return self._publishers[topic.topic]
163+
with self._lock:
164+
if topic.topic not in self._publishers:
165+
node = self._node
166+
if node is None:
167+
raise RuntimeError("Pubsub must be started before publishing")
168+
qos = topic.qos if topic.qos is not None else self._qos
169+
self._publishers[topic.topic] = node.create_publisher(
170+
topic.ros_type, topic.topic, qos
171+
)
172+
return self._publishers[topic.topic]
166173

167174
def publish(self, topic: ROSTopic, message: Any) -> None:
168175
"""Publish a message to a ROS topic.

0 commit comments

Comments
 (0)