Skip to content

Commit 1dd5519

Browse files
authored
Merge pull request #381 - Working LCM / Dask multiprocess with RPC
Multitree go2

Changes:
- Builds a general module structure with generic streaming / RPC between modules
- Dask for module forking & comms stream configs
- LCM pubsub / RPC between modules (supports pluggable pubsub/rpc transport protocols)
- We support ros-style separate run files (unitree_webrtc/multiprocess_individual_node.py)
- Hardware- or time-expensive tests are now tagged with "heavy" and run in parallel in a separate instance in CI
- Example of unitree split into 6 modules / processes (unitree_webrtc/multiprocess_unitree_go2.py)

Former-commit-id: cd86602 [formerly 7be4a88]
Former-commit-id: d672e63
1 parent 30b6d31 commit 1dd5519

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2266
-283
lines changed

.github/workflows/docker.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,18 @@ jobs:
161161
}}
162162
cmd: "pytest"
163163
dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }}
164+
165+
# Heavy tests (`pytest -m heavy`) run as their own job, in parallel with the normal test job, to keep CI fast
166+
run-heavy-tests:
167+
needs: [check-changes, dev]
168+
if: always()
169+
uses: ./.github/workflows/tests.yml
170+
with:
171+
should-run: ${{
172+
needs.check-changes.result == 'success' &&
173+
((needs.dev.result == 'success') ||
174+
(needs.dev.result == 'skipped' &&
175+
needs.check-changes.outputs.tests == 'true'))
176+
}}
177+
cmd: "pytest -m heavy"
178+
dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }}

bin/lfs_push

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ for dir_path in data/*; do
6868
compressed_dirs+=("$dir_name")
6969

7070
# Add the compressed file to git LFS tracking
71-
git add "$compressed_file"
71+
git add -f "$compressed_file"
7272

7373
echo -e " ${GREEN}${NC} git-add $compressed_file"
7474

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:bee487130eb662bca73c7d84f14eaea091bd6d7c3f1bfd5173babf660947bdec
3+
size 553620791

dimos/agents/memory/test_image_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from dimos.stream.video_provider import VideoProvider
2929

3030

31+
@pytest.mark.heavy
3132
class TestImageEmbedding:
3233
"""Test class for CLIP image embedding functionality."""
3334

dimos/core/__init__.py

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,82 @@
1+
from __future__ import annotations
2+
13
import multiprocessing as mp
4+
import time
5+
from typing import Optional
26

37
import pytest
48
from dask.distributed import Client, LocalCluster
9+
from rich.console import Console
510

611
import dimos.core.colors as colors
712
from dimos.core.core import In, Out, RemoteOut, rpc
8-
from dimos.core.module_dask import Module
13+
from dimos.core.module import Module, ModuleBase
914
from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport
15+
from dimos.protocol.rpc.lcmrpc import LCMRPC
16+
from dimos.protocol.rpc.spec import RPC
17+
18+
19+
def patch_actor(actor, cls):
    """Placeholder: the body is intentionally empty and the call returns ``None``."""
    ...
20+
21+
22+
class RPCClient:
    """Client-side proxy for a deployed actor.

    Attribute access on the proxy is resolved in two steps:

    1. names listed in the actor class's ``rpcs`` mapping become callables
       that issue a synchronous call through ``LCMRPC`` on the topic
       ``"<ActorClassName>/<name>"``;
    2. every other name is looked up on the wrapped actor instance.
    """

    # Names that ``__getattr__`` must never forward.  ``__getattr__`` only
    # runs when normal lookup failed, so reaching it with one of these names
    # means the instance is not (fully) initialised; reading ``self.rpcs`` or
    # ``self.actor_instance`` at that point would re-enter ``__getattr__``
    # and recurse without end.
    _LOCAL_NAMES = frozenset(
        {
            "__class__",
            "__init__",
            "__dict__",
            "__getattr__",
            "rpcs",
            "remote_name",
            "remote_instance",
            "actor_instance",
        }
    )

    def __init__(self, actor_instance, actor_class):
        """Wrap ``actor_instance`` and start an RPC client for ``actor_class``.

        Args:
            actor_instance: Object that non-RPC attribute access is forwarded to.
            actor_class: Class of the actor; its ``__name__`` is used as the
                RPC topic prefix and its ``rpcs`` mapping names the methods
                that are routed over RPC.
        """
        self.rpc = LCMRPC()
        self.actor_class = actor_class
        self.remote_name = actor_class.__name__
        self.actor_instance = actor_instance
        self.rpcs = actor_class.rpcs.keys()
        self.rpc.start()

    def __reduce__(self):
        """Pickle as ``(class, (actor_instance, actor_class))``.

        Unpickling therefore re-runs ``__init__`` and creates a fresh RPC
        client instead of trying to serialise the live one.
        """
        return (
            self.__class__,
            (self.actor_instance, self.actor_class),
        )

    def __getattr__(self, name: str):
        """Resolve ``name`` as an RPC call or as an attribute of the actor.

        Raises:
            AttributeError: If ``name`` is one of the proxy's own bookkeeping
                names (see ``_LOCAL_NAMES``) or the wrapped actor does not
                provide it.
        """
        if name in self._LOCAL_NAMES:
            raise AttributeError(f"{name} is not found.")

        if name in self.rpcs:
            # Positional arguments are passed to the RPC layer as one tuple.
            return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args)

        # Use the regular attribute protocol rather than calling
        # ``__getattr__`` on the actor directly: the direct call skips normal
        # lookup and fails outright for objects that do not define
        # ``__getattr__`` themselves.
        return getattr(self.actor_instance, name)
1059

1160

1261
def patchdask(dask_client: Client):
    """Attach a ``deploy`` helper to ``dask_client`` and return the same client.

    ``deploy(actor_class, *args, **kwargs)`` submits ``actor_class`` to the
    cluster as an actor, hands the actor a reference to itself through
    ``set_ref``, prints which worker it landed on, and returns an
    ``RPCClient`` wrapping the actor.
    """

    def deploy(actor_class, *args, **kwargs):
        spinner_text = f"deploying [green]{actor_class.__name__}"
        with Console().status(spinner_text, spinner="arc"):
            pending = dask_client.submit(actor_class, *args, **kwargs, actor=True)
            actor = pending.result()

            # ``set_ref`` reports the worker the actor is running on.
            worker = actor.set_ref(actor).result()
            summary = f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}"
            print(summary)

            return RPCClient(actor, actor_class)

    dask_client.deploy = deploy
    return dask_client
@@ -34,15 +90,20 @@ def dimos():
3490
stop(client)
3591

3692

def start(n: Optional[int] = None) -> Client:
    """Start a local Dask cluster and return a client patched with ``deploy``.

    Args:
        n: Number of worker processes.  When omitted (or otherwise falsy, e.g.
            ``0``) the machine's CPU count is used.

    Returns:
        The ``Client`` connected to the new ``LocalCluster``, after it has
        been passed through ``patchdask``.
    """
    console = Console()
    if not n:
        n = mp.cpu_count()

    # The spinner is only shown while the cluster and client are being created.
    with console.status(
        f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc"
    ):
        cluster = LocalCluster(
            n_workers=n,
            threads_per_worker=4,
        )
        client = Client(cluster)

    console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers")
    return patchdask(client)
47108

48109

dimos/core/core.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class RemoteStream(Stream[T]):
187187
def state(self) -> State: # noqa: D401
188188
return State.UNBOUND if self.owner is None else State.READY
189189

190+
# this won't work but nvm
190191
@property
191192
def transport(self) -> Transport[T]:
192193
return self._transport
@@ -204,6 +205,7 @@ def connect(self, other: RemoteIn[T]):
204205

205206
class In(Stream[T]):
206207
connection: Optional[RemoteOut[T]] = None
208+
_transport: Transport
207209

208210
def __str__(self):
209211
mystr = super().__str__()
@@ -220,21 +222,35 @@ def __reduce__(self): # noqa: D401
220222

221223
@property
def transport(self) -> Transport[T]:
    """Return this input's transport, caching it on first access.

    The transport is taken from ``self.connection`` the first time it is
    needed and stored in ``self._transport`` for later calls.
    """
    # The class-level ``_transport: Transport`` is an annotation only and does
    # not bind a value, so read it defensively.
    # NOTE(review): whether a base class assigns ``_transport`` is not visible
    # from here - confirm; the ``None`` default is harmless either way.
    cached = getattr(self, "_transport", None)
    if cached is None:
        # Explicit ``None`` check: a transport object must not be re-fetched
        # merely because it happens to evaluate as falsy.
        cached = self.connection.transport
        self._transport = cached
    return cached
224228

225229
@property
226230
def state(self) -> State: # noqa: D401
227231
return State.UNBOUND if self.owner is None else State.READY
228232

229233
def subscribe(self, cb):
230-
# print("SUBBING", self, self.connection._transport)
231-
self.connection._transport.subscribe(self, cb)
234+
self.transport.subscribe(self, cb)
232235

233236

234237
class RemoteIn(RemoteStream[T]):
    """Remote handle for an input stream; operations are delegated to ``owner``."""

    def connect(self, other: RemoteOut[T]) -> None:
        """Ask the owner to connect this input to ``other`` and wait for completion."""
        pending = self.owner.connect_stream(self.name, other)
        return pending.result()

    # NOTE(review): the original author flagged this accessor with
    # "this won't work but that's ok"; the reason is not documented here.
    @property
    def transport(self) -> Transport[T]:
        """Return the locally stored transport."""
        return self._transport

    @transport.setter
    def transport(self, value: Transport[T]) -> None:
        """Set the transport on the owner first, then record it locally."""
        self.owner.set_transport(self.name, value).result()
        self._transport = value

    def publish(self, msg):
        """Broadcast ``msg`` for this stream through its transport."""
        self.transport.broadcast(self, msg)
253+
238254

239255
def rpc(fn: Callable[..., Any]) -> Callable[..., Any]:
240256
fn.__rpc__ = True # type: ignore[attr-defined]
Lines changed: 99 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,112 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import inspect
1515
from typing import (
1616
Any,
1717
Callable,
18-
List,
1918
get_args,
2019
get_origin,
2120
get_type_hints,
2221
)
2322

24-
from dask.distributed import Actor
23+
from dask.distributed import Actor, get_worker
2524

25+
from dimos.core import colors
2626
from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport
27+
from dimos.protocol.rpc.lcmrpc import LCMRPC
28+
29+
30+
class _classproperty:
    """Read-only descriptor whose value is computed from the owning class.

    Stacking ``@classmethod`` on ``@property`` was deprecated in Python 3.11
    and no longer works from Python 3.13, so the class-level ``rpcs`` mapping
    is exposed through this small descriptor instead.  It works for access on
    both the class and its instances.
    """

    def __init__(self, fget: Callable[[Any], Any]) -> None:
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        return self.fget(type(instance) if owner is None else owner)


class ModuleBase:
    """Base class for modules: stream discovery, RPC discovery and RPC serving."""

    def __init__(self, *args, **kwargs):
        """Start serving this module's RPCs when running inside a Dask worker.

        ``get_worker()`` raises ``ValueError`` when called outside a worker; in
        that case no RPC server is created and ``self.rpc`` is left unset.
        """
        try:
            get_worker()
        except ValueError:
            # Not running inside a worker: nothing to serve.
            return

        # Kept outside the ``try`` so that a ``ValueError`` raised by the RPC
        # setup itself is reported instead of being mistaken for "no worker".
        self.rpc = LCMRPC()
        self.rpc.serve_module_rpc(self)
        self.rpc.start()

    @property
    def outputs(self) -> dict[str, Out]:
        """Public instance attributes that are ``Out`` streams, keyed by name."""
        return {
            name: s
            for name, s in self.__dict__.items()
            if isinstance(s, Out) and not name.startswith("_")
        }

    @property
    def inputs(self) -> dict[str, In]:
        """Public instance attributes that are ``In`` streams, keyed by name."""
        return {
            name: s
            for name, s in self.__dict__.items()
            if isinstance(s, In) and not name.startswith("_")
        }

    @_classproperty
    def rpcs(cls) -> dict[str, Callable]:
        """Public callables of the class that carry an ``__rpc__`` attribute."""
        found = {}
        for name in dir(cls):
            # Skip private names, and ``rpcs`` itself to prevent recursion.
            if name.startswith("_") or name == "rpcs":
                continue
            member = getattr(cls, name, None)
            if callable(member) and hasattr(member, "__rpc__"):
                found[name] = member
        return found

    def io(self) -> str:
        """Return a multi-line text diagram of inputs, outputs and RPCs."""

        def _box(name: str) -> list[str]:
            rule = "─" * (len(name) + 1)
            return [
                "┌┴" + rule + "┐",
                f"│ {name} │",
                "└┬" + rule + "┘",
            ]

        # can't modify __str__ on a function like we are doing for I/O
        # so we have a separate repr function here
        def repr_rpc(fn: Callable) -> str:
            sig = inspect.signature(fn)
            # Remove 'self' parameter
            params = [p for name, p in sig.parameters.items() if name != "self"]

            # Format parameters with colored types
            param_strs = []
            for param in params:
                param_str = param.name
                if param.annotation != inspect.Parameter.empty:
                    type_name = getattr(param.annotation, "__name__", str(param.annotation))
                    param_str += ": " + colors.green(type_name)
                if param.default != inspect.Parameter.empty:
                    param_str += f" = {param.default}"
                param_strs.append(param_str)

            # Format return type
            return_annotation = ""
            if sig.return_annotation != inspect.Signature.empty:
                return_type = getattr(sig.return_annotation, "__name__", str(sig.return_annotation))
                return_annotation = " -> " + colors.green(return_type)

            return (
                "RPC " + colors.blue(fn.__name__) + f"({', '.join(param_strs)})" + return_annotation
            )

        ret = [
            *(f" ├─ {stream}" for stream in self.inputs.values()),
            *_box(self.__class__.__name__),
            *(f" ├─ {stream}" for stream in self.outputs.values()),
            " │",
            *(f" ├─ {repr_rpc(fn)}" for fn in self.rpcs.values()),
        ]

        return "\n".join(ret)
28113

29-
class Module:
114+
115+
class DaskModule(ModuleBase):
30116
ref: Actor
117+
worker: int
31118

32-
def __init__(self):
119+
def __init__(self, *args, **kwargs):
33120
self.ref = None
34121

35122
for name, ann in get_type_hints(self, include_extras=True).items():
@@ -42,9 +129,13 @@ def __init__(self):
42129
inner, *_ = get_args(ann) or (Any,)
43130
stream = In(inner, name, self)
44131
setattr(self, name, stream)
132+
super().__init__(*args, **kwargs)
45133

def set_ref(self, ref) -> int:
    """Store ``ref`` and the current worker's name; return that worker name."""
    current = get_worker()
    self.ref = ref
    self.worker = current.name
    return current.name
48139

49140
def __str__(self):
50141
return f"{self.__class__.__name__}"
@@ -76,38 +167,6 @@ def dask_receive_msg(self, input_name: str, msg: Any):
76167
def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]):
77168
getattr(self, output_name).transport.dask_register_subscriber(subscriber)
78169

79-
@property
80-
def outputs(self) -> dict[str, Out]:
81-
return {
82-
name: s
83-
for name, s in self.__dict__.items()
84-
if isinstance(s, Out) and not name.startswith("_")
85-
}
86-
87-
@property
88-
def inputs(self) -> dict[str, In]:
89-
return {
90-
name: s
91-
for name, s in self.__dict__.items()
92-
if isinstance(s, In) and not name.startswith("_")
93-
}
94-
95-
@property
96-
def rpcs(self) -> List[Callable]:
97-
return [name for name in dir(self) if hasattr(getattr(self, name), "__rpc__")]
98170

99-
def io(self) -> str:
100-
def _box(name: str) -> str:
101-
return [
102-
"┌┴" + "─" * (len(name) + 1) + "┐",
103-
f"│ {name} │",
104-
"└┬" + "─" * (len(name) + 1) + "┘",
105-
]
106-
107-
ret = [
108-
*(f" ├─ {stream}" for stream in self.inputs.values()),
109-
*_box(self.__class__.__name__),
110-
*(f" ├─ {stream}" for stream in self.outputs.values()),
111-
]
112-
113-
return "\n".join(ret)
171+
# global setting
172+
Module = DaskModule

0 commit comments

Comments
 (0)