1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ from collections .abc import Generator
1718import threading
1819import time
20+ from typing import Any
1921
2022import pytest
2123
2224from dimos .protocol .pubsub .benchmark .testdata import testdata
23- from dimos .protocol .pubsub .benchmark .type import BenchmarkResult , BenchmarkResults
25+ from dimos .protocol .pubsub .benchmark .type import (
26+ BenchmarkResult ,
27+ BenchmarkResults ,
28+ MsgGen ,
29+ PubSubContext ,
30+ TestCase ,
31+ )
2432
2533# Message sizes for throughput benchmarking (powers of 2 from 64B to 10MB)
2634MSG_SIZES = [
@@ -57,16 +65,16 @@ def size_id(size: int) -> str:
5765 return f"{ size } B"
5866
5967
60- def pubsub_id (testcase ) -> str :
68+ def pubsub_id (testcase : TestCase [ Any , Any ] ) -> str :
6169 """Extract pubsub implementation name from context manager function name."""
62- name = testcase .pubsub_context .__name__
70+ name : str = testcase .pubsub_context .__name__
6371 # Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory"
6472 prefix = name .replace ("_pubsub_channel" , "" ).replace ("_" , " " )
6573 return prefix .upper () if len (prefix ) <= 3 else prefix .title ().replace (" " , "" )
6674
6775
6876@pytest .fixture (scope = "module" )
69- def benchmark_results ():
77+ def benchmark_results () -> Generator [ BenchmarkResults , None , None ] :
7078 """Module-scoped fixture to collect benchmark results."""
7179 results = BenchmarkResults ()
7280 yield results
@@ -79,7 +87,12 @@ def benchmark_results():
7987@pytest .mark .tool
8088@pytest .mark .parametrize ("msg_size" , MSG_SIZES , ids = [size_id (s ) for s in MSG_SIZES ])
8189@pytest .mark .parametrize ("pubsub_context, msggen" , testdata , ids = [pubsub_id (t ) for t in testdata ])
82- def test_throughput (pubsub_context , msggen , msg_size , benchmark_results ):
90+ def test_throughput (
91+ pubsub_context : PubSubContext [Any , Any ],
92+ msggen : MsgGen [Any , Any ],
93+ msg_size : int ,
94+ benchmark_results : BenchmarkResults ,
95+ ) -> None :
8396 """Measure throughput for publishing and receiving messages over a fixed duration."""
8497 with pubsub_context () as pubsub :
8598 topic , msg = msggen (msg_size )
@@ -88,7 +101,7 @@ def test_throughput(pubsub_context, msggen, msg_size, benchmark_results):
88101 lock = threading .Lock ()
89102 all_received = threading .Event ()
90103
91- def callback (message , _topic ) :
104+ def callback (message : Any , _topic : Any ) -> None :
92105 nonlocal received_count
93106 with lock :
94107 received_count += 1
@@ -136,7 +149,10 @@ def callback(message, _topic):
136149 latency = latency_end - publish_end
137150
138151 # Record result (duration is publish time only for throughput calculation)
139- transport_name = pubsub_id (type ("TC" , (), {"pubsub_context" : pubsub_context })())
152+ # Extract transport name from context manager function name
153+ ctx_name = pubsub_context .__name__
154+ prefix = ctx_name .replace ("_pubsub_channel" , "" ).replace ("_" , " " )
155+ transport_name = prefix .upper () if len (prefix ) <= 3 else prefix .title ().replace (" " , "" )
140156 result = BenchmarkResult (
141157 transport = transport_name ,
142158 duration = publish_end - start ,
0 commit comments