Skip to content

Commit 5569418

Browse files
authored
Repair Stream (#932)
* stream changes undo, typing fixes * gps nav skill fix Former-commit-id: fd5a1d5
1 parent beece7c commit 5569418

3 files changed

Lines changed: 29 additions & 26 deletions

File tree

dimos/agents/skills/gps_nav_skill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def set_gps_travel_points(self, *points: dict[str, float]) -> str:
8484
logger.info(f"Set travel points: {new_points}")
8585

8686
if self.gps_goal._transport is not None:
87-
self.gps_goal.publish(new_points)
87+
self.gps_goal.publish(new_points) # type: ignore[arg-type]
8888

8989
if self._set_gps_travel_goal_points:
9090
self._set_gps_travel_goal_points(new_points)

dimos/core/stream.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,16 @@ class State(enum.Enum):
8282

8383
class Transport(Resource, ObservableMixin[T]):
8484
# used by local Output
85-
def broadcast(self, selfstream: Out[T], value: T) -> None: ...
85+
def broadcast(self, selfstream: Out[T], value: T) -> None:
86+
raise NotImplementedError
87+
88+
# used by local Input
89+
def subscribe(self, callback: Callable[[T], Any], selfstream: Stream[T]) -> Callable[[], None]:
90+
raise NotImplementedError
8691

8792
def publish(self, msg: T) -> None:
8893
self.broadcast(None, msg) # type: ignore[arg-type]
8994

90-
# used by local Input
91-
def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... # type: ignore[valid-type]
92-
9395

9496
class Stream(Generic[T]):
9597
_transport: Transport | None # type: ignore[type-arg]
@@ -139,9 +141,11 @@ def __str__(self) -> str:
139141

140142
class Out(Stream[T], ObservableMixin[T]):
141143
_transport: Transport # type: ignore[type-arg]
144+
_subscribers: list[Callable[[T], Any]]
142145

143146
def __init__(self, *argv, **kwargs) -> None: # type: ignore[no-untyped-def]
144147
super().__init__(*argv, **kwargs)
148+
self._subscribers = []
145149

146150
@property
147151
def transport(self) -> Transport[T]:
@@ -168,22 +172,19 @@ def __reduce__(self): # type: ignore[no-untyped-def]
168172
),
169173
)
170174

171-
def publish(self, msg) -> None: # type: ignore[no-untyped-def]
172-
if not hasattr(self, "_transport") or self._transport is None:
173-
logger.warning(f"Trying to publish on Out {self} without a transport")
174-
return
175-
self._transport.broadcast(self, msg)
175+
def publish(self, msg: T) -> None:
176+
if hasattr(self, "_transport") and self._transport is not None:
177+
self._transport.broadcast(self, msg)
178+
for cb in self._subscribers:
179+
cb(msg)
176180

177-
def subscribe(self, cb) -> Callable[[], None]: # type: ignore[no-untyped-def]
178-
"""Subscribe to this output stream.
181+
def subscribe(self, cb: Callable[[T], Any]) -> Callable[[], None]:
182+
self._subscribers.append(cb)
179183

180-
Args:
181-
cb: Callback function to receive messages
184+
def unsubscribe() -> None:
185+
self._subscribers.remove(cb)
182186

183-
Returns:
184-
Unsubscribe function
185-
"""
186-
return self.transport.subscribe(cb, self) # type: ignore[arg-type, func-returns-value, no-any-return]
187+
return unsubscribe
187188

188189

189190
class RemoteStream(Stream[T]):
@@ -206,8 +207,8 @@ class RemoteOut(RemoteStream[T]):
206207
def connect(self, other: RemoteIn[T]): # type: ignore[no-untyped-def]
207208
return other.connect(self)
208209

209-
def subscribe(self, cb) -> Callable[[], None]: # type: ignore[no-untyped-def]
210-
return self.transport.subscribe(cb, self) # type: ignore[arg-type, func-returns-value, no-any-return]
210+
def subscribe(self, cb: Callable[[T], Any]) -> Callable[[], None]:
211+
return self.transport.subscribe(cb, self)
211212

212213

213214
# representation of Input
@@ -249,8 +250,8 @@ def state(self) -> State:
249250
return State.UNBOUND if self.owner is None else State.READY
250251

251252
# returns unsubscribe function
252-
def subscribe(self, cb) -> Callable[[], None]: # type: ignore[no-untyped-def]
253-
return self.transport.subscribe(cb, self) # type: ignore[arg-type, func-returns-value, no-any-return]
253+
def subscribe(self, cb: Callable[[T], Any]) -> Callable[[], None]:
254+
return self.transport.subscribe(cb, self)
254255

255256

256257
# representation of input outside of module

dimos/core/transport.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
TypeVar,
2626
)
2727

28-
from dimos.core.stream import In, Transport
28+
from dimos.core.stream import In, Out, Stream, Transport
2929
from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory
3030
from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic
3131
from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory
@@ -60,18 +60,20 @@ def __init__(self, topic: str, **kwargs) -> None: # type: ignore[no-untyped-def
6060
def __reduce__(self): # type: ignore[no-untyped-def]
6161
return (pLCMTransport, (self.topic,))
6262

63-
def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def]
63+
def broadcast(self, _: Out[T] | None, msg: T) -> None:
6464
if not self._started:
6565
self.lcm.start()
6666
self._started = True
6767

6868
self.lcm.publish(self.topic, msg)
6969

70-
def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override]
70+
def subscribe(
71+
self, callback: Callable[[T], Any], selfstream: Stream[T] | None = None
72+
) -> Callable[[], None]:
7173
if not self._started:
7274
self.lcm.start()
7375
self._started = True
74-
return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value]
76+
return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg))
7577

7678
def start(self) -> None: ...
7779

0 commit comments

Comments
 (0)