Skip to content

Commit 0a73435

Browse files
add replacement for dask (#1111)
* Still uses Dask by default.
* Start without Dask with:
  ```bash
  uv run dimos --no-dask --simulation run unitree-go2-agentic
  ```
* Startup time improvement so far: 60 seconds before, 45 seconds now. (This is measured from running the command until I can instruct it to "start exploring" and it starts moving.)
* There are a lot of potential improvements we can make. This is the first step.
1 parent 9a12b06 commit 0a73435

11 files changed

Lines changed: 537 additions & 23 deletions

dimos/core/blueprints.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,15 @@ def _verify_no_name_conflicts(self) -> None:
202202
def _deploy_all_modules(
203203
self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig
204204
) -> None:
205+
module_specs: list[tuple[type[Module], tuple[Any, ...], dict[str, Any]]] = []
205206
for blueprint in self.blueprints:
206207
kwargs = {**blueprint.kwargs}
207208
sig = inspect.signature(blueprint.module.__init__)
208209
if "global_config" in sig.parameters:
209210
kwargs["global_config"] = global_config
210-
module_coordinator.deploy(blueprint.module, *blueprint.args, **kwargs)
211+
module_specs.append((blueprint.module, blueprint.args, kwargs))
212+
213+
module_coordinator.deploy_parallel(module_specs)
211214

212215
def _connect_transports(self, module_coordinator: ModuleCoordinator) -> None:
213216
# Gather all the In/Out connections with remapping applied.

dimos/core/global_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class GlobalConfig(BaseSettings):
4848
robot_rotation_diameter: float = 0.6
4949
planner_strategy: NavigationStrategy = "simple"
5050
planner_robot_speed: float | None = None
51+
dask: bool = True
5152

5253
model_config = SettingsConfigDict(
5354
env_file=".env",

dimos/core/module.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@
2626
overload,
2727
)
2828

29+
from typing_extensions import TypeVar as TypeVarExtension
30+
2931
if TYPE_CHECKING:
3032
from dimos.core.introspection.module import ModuleInfo
3133

34+
from typing import TypeVar
35+
3236
from dask.distributed import Actor, get_worker
3337
from reactivex.disposable import CompositeDisposable
34-
from typing_extensions import TypeVar
3538

3639
from dimos.core import colors
3740
from dimos.core.core import T, rpc
@@ -82,7 +85,7 @@ class ModuleConfig:
8285
frame_id: str | None = None
8386

8487

85-
ModuleConfigT = TypeVar("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig)
88+
ModuleConfigT = TypeVarExtension("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig)
8689

8790

8891
class ModuleBase(Configurable[ModuleConfigT], SkillContainer, Resource):
@@ -355,7 +358,7 @@ def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: # type
355358
return result[0] if len(result) == 1 else result
356359

357360

358-
class DaskModule(ModuleBase[ModuleConfigT]):
361+
class Module(ModuleBase[ModuleConfigT]):
359362
ref: Actor
360363
worker: int
361364

@@ -454,5 +457,4 @@ def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]) ->
454457
getattr(self, output_name).transport.dask_register_subscriber(subscriber)
455458

456459

457-
# global setting
458-
Module = DaskModule
460+
ModuleT = TypeVar("ModuleT", bound="Module")

dimos/core/module_coordinator.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from concurrent.futures import ThreadPoolExecutor
1516
import time
16-
from typing import TypeVar
17+
from typing import Any
1718

1819
from dimos import core
19-
from dimos.core import DimosCluster, Module
20+
from dimos.core import DimosCluster
2021
from dimos.core.global_config import GlobalConfig
22+
from dimos.core.module import Module, ModuleT
2123
from dimos.core.resource import Resource
22-
23-
T = TypeVar("T", bound="Module")
24+
from dimos.core.rpc_client import RPCClient
25+
from dimos.core.worker_manager import WorkerManager
2426

2527

2628
class ModuleCoordinator(Resource):
27-
_client: DimosCluster | None = None
29+
_client: DimosCluster | WorkerManager | None = None
30+
_global_config: GlobalConfig
2831
_n: int | None = None
2932
_memory_limit: str = "auto"
30-
_deployed_modules: dict[type[Module], Module] = {}
33+
_deployed_modules: dict[type[Module], RPCClient] = {}
3134

3235
def __init__(
3336
self,
@@ -37,29 +40,55 @@ def __init__(
3740
cfg = global_config or GlobalConfig()
3841
self._n = n if n is not None else cfg.n_dask_workers
3942
self._memory_limit = cfg.memory_limit
43+
self._global_config = cfg
4044

4145
def start(self) -> None:
42-
self._client = core.start(self._n, self._memory_limit)
46+
if self._global_config.dask:
47+
self._client = core.start(self._n, self._memory_limit)
48+
else:
49+
self._client = WorkerManager()
4350

4451
def stop(self) -> None:
4552
for module in reversed(self._deployed_modules.values()):
4653
module.stop()
4754

4855
self._client.close_all() # type: ignore[union-attr]
4956

50-
def deploy(self, module_class: type[T], *args, **kwargs) -> T: # type: ignore[no-untyped-def]
57+
def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient:
5158
if not self._client:
5259
raise ValueError("Not started")
5360

54-
module = self._client.deploy(module_class, *args, **kwargs) # type: ignore[attr-defined]
61+
module = self._client.deploy(module_class, *args, **kwargs) # type: ignore[union-attr]
5562
self._deployed_modules[module_class] = module
56-
return module # type: ignore[no-any-return]
63+
return module
64+
65+
def deploy_parallel(
66+
self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]]
67+
) -> list[RPCClient]:
68+
if not self._client:
69+
raise ValueError("Not started")
70+
71+
if isinstance(self._client, WorkerManager):
72+
modules = self._client.deploy_parallel(module_specs)
73+
for (module_class, _, _), module in zip(module_specs, modules, strict=True):
74+
self._deployed_modules[module_class] = module
75+
return modules # type: ignore[return-value]
76+
else:
77+
return [
78+
self.deploy(module_class, *args, **kwargs)
79+
for module_class, args, kwargs in module_specs
80+
]
5781

5882
def start_all_modules(self) -> None:
59-
for module in self._deployed_modules.values():
60-
module.start()
83+
modules = list(self._deployed_modules.values())
84+
if isinstance(self._client, WorkerManager):
85+
with ThreadPoolExecutor(max_workers=len(modules)) as executor:
86+
list(executor.map(lambda m: m.start(), modules))
87+
else:
88+
for module in modules:
89+
module.start()
6190

62-
def get_instance(self, module: type[T]) -> T | None:
91+
def get_instance(self, module: type[ModuleT]) -> ModuleT | None:
6392
return self._deployed_modules.get(module) # type: ignore[return-value]
6493

6594
def loop(self) -> None:

dimos/core/test_modules.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,10 @@ def is_module_subclass(
8989
target_classes = {
9090
"Module",
9191
"ModuleBase",
92-
"DaskModule",
9392
"dimos.core.Module",
9493
"dimos.core.ModuleBase",
95-
"dimos.core.DaskModule",
9694
"dimos.core.module.Module",
9795
"dimos.core.module.ModuleBase",
98-
"dimos.core.module.DaskModule",
9996
}
10097

10198
def find_qualified_name(base: str, context_module: str | None = None) -> str:
@@ -291,7 +288,7 @@ def get_all_module_subclasses():
291288
filtered_results = []
292289
for class_name, filepath, has_start, has_stop, forbidden_methods in results:
293290
# Skip base module classes themselves
294-
if class_name in ("Module", "ModuleBase", "DaskModule", "SkillModule"):
291+
if class_name in ("Module", "ModuleBase", "SkillModule"):
295292
continue
296293

297294
# Skip test-only modules (those defined in test_ files)

dimos/core/test_worker.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2026 Dimensional Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from dimos.core import In, Module, Out, rpc
18+
from dimos.core.worker_manager import WorkerManager
19+
from dimos.msgs.geometry_msgs import Vector3
20+
21+
22+
class SimpleModule(Module):
23+
output: Out[Vector3]
24+
input: In[Vector3]
25+
26+
counter: int = 0
27+
28+
@rpc
29+
def start(self) -> None:
30+
pass
31+
32+
@rpc
33+
def increment(self) -> int:
34+
self.counter += 1
35+
return self.counter
36+
37+
@rpc
38+
def get_counter(self) -> int:
39+
return self.counter
40+
41+
42+
class AnotherModule(Module):
43+
value: int = 100
44+
45+
@rpc
46+
def start(self) -> None:
47+
pass
48+
49+
@rpc
50+
def add(self, n: int) -> int:
51+
self.value += n
52+
return self.value
53+
54+
@rpc
55+
def get_value(self) -> int:
56+
return self.value
57+
58+
59+
class ThirdModule(Module):
60+
multiplier: int = 1
61+
62+
@rpc
63+
def start(self) -> None:
64+
pass
65+
66+
@rpc
67+
def multiply(self, n: int) -> int:
68+
self.multiplier *= n
69+
return self.multiplier
70+
71+
@rpc
72+
def get_multiplier(self) -> int:
73+
return self.multiplier
74+
75+
76+
@pytest.fixture
77+
def worker_manager():
78+
manager = WorkerManager()
79+
try:
80+
yield manager
81+
finally:
82+
manager.close_all()
83+
84+
85+
@pytest.mark.integration
86+
def test_worker_manager_basic(worker_manager):
87+
module = worker_manager.deploy(SimpleModule)
88+
module.start()
89+
90+
result = module.increment()
91+
assert result == 1
92+
93+
result = module.increment()
94+
assert result == 2
95+
96+
result = module.get_counter()
97+
assert result == 2
98+
99+
module.stop()
100+
101+
102+
@pytest.mark.integration
103+
def test_worker_manager_multiple_different_modules(worker_manager):
104+
module1 = worker_manager.deploy(SimpleModule)
105+
module2 = worker_manager.deploy(AnotherModule)
106+
107+
module1.start()
108+
module2.start()
109+
110+
# Each module has its own state
111+
module1.increment()
112+
module1.increment()
113+
module2.add(10)
114+
115+
assert module1.get_counter() == 2
116+
assert module2.get_value() == 110
117+
118+
# Stop modules to clean up threads
119+
module1.stop()
120+
module2.stop()
121+
122+
123+
@pytest.mark.integration
124+
def test_worker_manager_parallel_deployment(worker_manager):
125+
modules = worker_manager.deploy_parallel(
126+
[
127+
(SimpleModule, (), {}),
128+
(AnotherModule, (), {}),
129+
(ThirdModule, (), {}),
130+
]
131+
)
132+
133+
assert len(modules) == 3
134+
module1, module2, module3 = modules
135+
136+
# Start all modules
137+
module1.start()
138+
module2.start()
139+
module3.start()
140+
141+
# Each module has its own state
142+
module1.increment()
143+
module2.add(50)
144+
module3.multiply(5)
145+
146+
assert module1.get_counter() == 1
147+
assert module2.get_value() == 150
148+
assert module3.get_multiplier() == 5
149+
150+
# Stop modules
151+
module1.stop()
152+
module2.stop()
153+
module3.stop()

0 commit comments

Comments
 (0)