|
21 | 21 | import threading |
22 | 22 | import time |
23 | 23 | from enum import Enum |
24 | | -from typing import Optional |
| 24 | +from typing import Callable, Optional |
25 | 25 |
|
26 | 26 | from dimos.core import Module, In, Out, rpc |
27 | 27 | from dimos.msgs.geometry_msgs import PoseStamped |
@@ -72,8 +72,9 @@ class BehaviorTreeNavigator(Module): |
72 | 72 |
|
73 | 73 | def __init__( |
74 | 74 | self, |
75 | | - local_planner: BaseLocalPlanner, |
76 | 75 | publishing_frequency: float = 1.0, |
| 76 | + reset_local_planner: Callable[[], None] = None, |
| 77 | + check_goal_reached: Callable[[], bool] = None, |
77 | 78 | **kwargs, |
78 | 79 | ): |
79 | 80 | """Initialize the Navigator. |
@@ -108,10 +109,13 @@ def __init__( |
108 | 109 | self.control_thread: Optional[threading.Thread] = None |
109 | 110 | self.stop_event = threading.Event() |
110 | 111 |
|
111 | | - self.local_planner = local_planner |
112 | 112 | # TF listener |
113 | 113 | self.tf = TF() |
114 | 114 |
|
| 115 | + # Local planner |
| 116 | + self.reset_local_planner = reset_local_planner |
| 117 | + self.check_goal_reached = check_goal_reached |
| 118 | + |
115 | 119 | # Recovery server for stuck detection |
116 | 120 | self.recovery_server = RecoveryServer(stuck_duration=5.0) |
117 | 121 |
|
@@ -284,7 +288,7 @@ def _control_loop(self): |
284 | 288 | self.cancel_goal() |
285 | 289 |
|
286 | 290 | # Check if goal is reached |
287 | | - if self.local_planner.is_goal_reached(): |
| 291 | + if self.check_goal_reached(): |
288 | 292 | reached_msg = Bool() |
289 | 293 | reached_msg.data = True |
290 | 294 | self.goal_reached.publish(reached_msg) |
@@ -317,7 +321,7 @@ def stop(self): |
317 | 321 | with self.state_lock: |
318 | 322 | self.state = NavigatorState.IDLE |
319 | 323 |
|
320 | | - self.local_planner.reset() |
| 324 | + self.reset_local_planner() |
321 | 325 | self.recovery_server.reset() # Reset recovery server when stopping |
322 | 326 |
|
323 | 327 | logger.info("Navigator stopped") |
0 commit comments