# See the License for the specific language governing permissions and
# limitations under the License.
1414
15+ from concurrent .futures import ThreadPoolExecutor
1516import time
16- from typing import TypeVar
17+ from typing import Any
1718
1819from dimos import core
19- from dimos .core import DimosCluster , Module
20+ from dimos .core import DimosCluster
2021from dimos .core .global_config import GlobalConfig
22+ from dimos .core .module import Module , ModuleT
2123from dimos .core .resource import Resource
22-
23- T = TypeVar ( "T" , bound = "Module" )
24+ from dimos . core . rpc_client import RPCClient
25+ from dimos . core . worker_manager import WorkerManager
2426
2527
2628class ModuleCoordinator (Resource ):
27- _client : DimosCluster | None = None
29+ _client : DimosCluster | WorkerManager | None = None
30+ _global_config : GlobalConfig
2831 _n : int | None = None
2932 _memory_limit : str = "auto"
30- _deployed_modules : dict [type [Module ], Module ] = {}
33+ _deployed_modules : dict [type [Module ], RPCClient ] = {}
3134
3235 def __init__ (
3336 self ,
@@ -37,29 +40,55 @@ def __init__(
3740 cfg = global_config or GlobalConfig ()
3841 self ._n = n if n is not None else cfg .n_dask_workers
3942 self ._memory_limit = cfg .memory_limit
43+ self ._global_config = cfg
4044
def start(self) -> None:
    """Create the backing client and store it on ``self._client``.

    When the global config's ``dask`` flag is truthy the client comes from
    ``core.start`` (given the configured worker count and memory limit);
    otherwise a ``WorkerManager`` is instantiated.
    """
    use_dask = self._global_config.dask
    self._client = (
        core.start(self._n, self._memory_limit) if use_dask else WorkerManager()
    )
4350
def stop(self) -> None:
    """Stop every registered module, newest first, then close the client.

    Modules are stopped in reverse registration order. The client's
    ``close_all`` is called afterwards even if a module's ``stop`` raises
    (the exception still propagates to the caller). Calling ``stop`` on a
    coordinator that was never started is safe: there is no client to
    close, so only the module loop runs.
    """
    try:
        for deployed in reversed(self._deployed_modules.values()):
            deployed.stop()
    finally:
        # Always release the client so a failing module cannot leak it;
        # skip the call when start() was never invoked.
        if self._client is not None:
            self._client.close_all()
4956
def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient:
    """Deploy one module through the active client and register the result.

    Args:
        module_class: Class to deploy; also used as the registry key.
        *args: Positional arguments forwarded to the client's ``deploy``.
        **kwargs: Keyword arguments forwarded to the client's ``deploy``.

    Returns:
        The object returned by the client's ``deploy`` call.

    Raises:
        ValueError: If no client is set (``start`` has not been called).
    """
    client = self._client
    if not client:
        raise ValueError("Not started")

    deployed = client.deploy(module_class, *args, **kwargs)  # type: ignore[union-attr]
    # Keyed by class: a later deploy of the same class replaces this entry.
    self._deployed_modules[module_class] = deployed
    return deployed
64+
def deploy_parallel(
    self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]]
) -> list[RPCClient]:
    """Deploy several modules, delegating to the client's batch API if present.

    Args:
        module_specs: ``(module_class, args, kwargs)`` triples, one per
            module to deploy.

    Returns:
        One deployed object per spec, in the same order as ``module_specs``.

    Raises:
        ValueError: If no client is set, or (WorkerManager path) if the
            client returns a different number of results than specs given.
    """
    client = self._client
    if not client:
        raise ValueError("Not started")

    if not isinstance(client, WorkerManager):
        # No batch API on this client: fall back to one deploy() per spec.
        return [self.deploy(cls, *pos, **named) for cls, pos, named in module_specs]

    deployed = client.deploy_parallel(module_specs)
    # strict=True turns a spec/result length mismatch into a ValueError.
    for (cls, _pos, _named), proxy in zip(module_specs, deployed, strict=True):
        self._deployed_modules[cls] = proxy
    return deployed  # type: ignore[return-value]
5781
def start_all_modules(self) -> None:
    """Call ``start()`` on every registered module.

    With a ``WorkerManager`` client the modules are started concurrently,
    one thread per module; with any other client they are started one
    after another in registration order. Does nothing when no modules
    have been registered.
    """
    modules = list(self._deployed_modules.values())
    if not modules:
        # ThreadPoolExecutor raises ValueError for max_workers=0, so an
        # empty registry must short-circuit before the pool is built.
        return

    if isinstance(self._client, WorkerManager):
        with ThreadPoolExecutor(max_workers=len(modules)) as executor:
            # Consume the iterator so an exception raised by any start()
            # is re-raised here instead of being silently dropped.
            list(executor.map(lambda m: m.start(), modules))
    else:
        for module in modules:
            module.start()
6190
def get_instance(self, module: type[ModuleT]) -> ModuleT | None:
    """Return the object registered for ``module``, or ``None`` if absent."""
    registry = self._deployed_modules
    try:
        return registry[module]  # type: ignore[return-value]
    except KeyError:
        return None
6493
6594 def loop (self ) -> None :
0 commit comments