Skip to content

Commit 76c822f

Browse files
committed
massive improvements and bug fixes
1 parent 9e68120 commit 76c822f

File tree

9 files changed

+144
-40
lines changed

9 files changed

+144
-40
lines changed

dimos/manipulation/visual_servoing/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
compose_transforms,
3131
yaw_towards_point,
3232
get_distance,
33-
retract_distance,
33+
offset_distance,
3434
)
3535

3636

@@ -261,7 +261,7 @@ def update_target_grasp_pose(
261261
updated_pose = Pose(target_pos, target_orientation)
262262

263263
if grasp_distance > 0.0:
264-
return retract_distance(updated_pose, grasp_distance)
264+
return offset_distance(updated_pose, grasp_distance)
265265
else:
266266
return updated_pose
267267

dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from dimos.msgs.nav_msgs import OccupancyGrid, CostValues
3333
from dimos.utils.logging_config import setup_logger
3434
from dimos_lcm.std_msgs import Bool
35+
from dimos.utils.transform_utils import get_distance
3536

3637
logger = setup_logger("dimos.robot.unitree.frontier_exploration")
3738

@@ -100,7 +101,8 @@ def __init__(
100101
self,
101102
min_frontier_perimeter: float = 0.5,
102103
occupancy_threshold: int = 99,
103-
safe_distance: float = 2.0,
104+
safe_distance: float = 3.0,
105+
lookahead_distance: float = 5.0,
104106
max_explored_distance: float = 10.0,
105107
info_gain_threshold: float = 0.03,
106108
num_no_gain_attempts: int = 4,
@@ -122,6 +124,7 @@ def __init__(
122124
self.occupancy_threshold = occupancy_threshold
123125
self.safe_distance = safe_distance
124126
self.max_explored_distance = max_explored_distance
127+
self.lookahead_distance = lookahead_distance
125128
self.info_gain_threshold = info_gain_threshold
126129
self.num_no_gain_attempts = num_no_gain_attempts
127130
self._cache = FrontierCache()
@@ -496,35 +499,43 @@ def _compute_comprehensive_frontier_score(
496499
) -> float:
497500
"""Compute comprehensive score considering multiple criteria."""
498501

499-
# 1. Information gain (frontier size)
502+
# 1. Distance from robot (preference for moderate distances)
503+
robot_distance = get_distance(frontier, robot_pose)
504+
505+
# Distance score: prefer moderate distances (not too close, not too far)
506+
# Score lies in (0, 1]: 1.0 when the frontier is exactly lookahead_distance away, decaying as it gets closer or farther
507+
distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance))
508+
509+
# 2. Information gain (frontier size)
500510
# Normalize by a reasonable max frontier size
501511
max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10
502512
info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0)
503513

504-
# 2. Distance to explored goals (bonus for being far from explored areas)
514+
# 3. Distance to explored goals (bonus for being far from explored areas)
505515
# Normalize by a reasonable max distance (e.g., 10 meters)
506516
explored_goals_distance = self._compute_distance_to_explored_goals(frontier)
507517
explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0)
508518

509-
# 3. Distance to obstacles (score based on safety)
519+
# 4. Distance to obstacles (score based on safety)
510520
# 0 = too close to obstacles, 1 = at or beyond safe distance
511521
obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap)
512522
if obstacles_distance >= self.safe_distance:
513523
obstacles_score = 1.0 # Fully safe
514524
else:
515525
obstacles_score = obstacles_distance / self.safe_distance # Linear penalty
516526

517-
# 4. Direction momentum (already in 0-1 range from dot product)
527+
# 5. Direction momentum (already in 0-1 range from dot product)
518528
momentum_score = self._compute_direction_momentum_score(frontier, robot_pose)
519529

520530
logger.info(
521-
f"Info gain score: {info_gain_score}, Explored goals score: {explored_goals_score}, Obstacles score: {obstacles_score}, Momentum score: {momentum_score}"
531+
f"Distance score: {distance_score:.2f}, Info gain: {info_gain_score:.2f}, Explored goals: {explored_goals_score:.2f}, Obstacles: {obstacles_score:.2f}, Momentum: {momentum_score:.2f}"
522532
)
523533

524-
# Combine scores with consistent scaling (no arbitrary multipliers)
534+
# Combine scores with consistent scaling
525535
total_score = (
526-
0.5 * info_gain_score # 30% information gain
536+
0.3 * info_gain_score # 30% information gain
527537
+ 0.3 * explored_goals_score # 30% distance from explored goals
538+
+ 0.2 * distance_score # 20% distance optimization
528539
+ 0.15 * obstacles_score # 15% distance from obstacles
529540
+ 0.05 * momentum_score # 5% direction momentum
530541
)

dimos/perception/object_tracker.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ def __init__(
9090
self.orb = cv2.ORB_create()
9191
self.bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
9292
self.original_des = None # Store original ORB descriptors
93+
self.original_kps = None # Store original ORB keypoints
9394
self.reid_fail_count = 0 # Counter for consecutive re-id failures
95+
self.last_good_matches = [] # Store good matches for visualization
96+
self.last_roi_kps = None # Store last ROI keypoints for visualization
97+
self.last_roi_bbox = None # Store last ROI bbox for visualization
98+
self.reid_confirmed = False # Store current reid confirmation state
9499

95100
# For tracking latest frame data
96101
self._latest_rgb_frame: Optional[np.ndarray] = None
@@ -182,7 +187,7 @@ def track(
182187
# Extract initial features
183188
roi = self._latest_rgb_frame[y1:y2, x1:x2]
184189
if roi.size > 0:
185-
_, self.original_des = self.orb.detectAndCompute(roi, None)
190+
self.original_kps, self.original_des = self.orb.detectAndCompute(roi, None)
186191
if self.original_des is None:
187192
logger.warning("No ORB features found in initial ROI.")
188193
self.stop_track()
@@ -217,23 +222,31 @@ def reid(self, frame, current_bbox) -> bool:
217222
if roi.size == 0:
218223
return False # Empty ROI cannot match
219224

220-
_, des_current = self.orb.detectAndCompute(roi, None)
225+
kps_current, des_current = self.orb.detectAndCompute(roi, None)
221226
if des_current is None or len(des_current) < 2:
222227
return False # Need at least 2 descriptors for knnMatch
223228

229+
# Store ROI keypoints and bbox for visualization
230+
self.last_roi_kps = kps_current
231+
self.last_roi_bbox = [x1, y1, x2, y2]
232+
224233
# Handle case where original_des has only 1 descriptor (cannot use knnMatch with k=2)
225234
if len(self.original_des) < 2:
226235
matches = self.bf.match(self.original_des, des_current)
236+
self.last_good_matches = matches # Store all matches for visualization
227237
good_matches = len(matches)
228238
else:
229239
matches = self.bf.knnMatch(self.original_des, des_current, k=2)
230240
# Apply Lowe's ratio test robustly
241+
good_matches_list = []
231242
good_matches = 0
232243
for match_pair in matches:
233244
if len(match_pair) == 2:
234245
m, n = match_pair
235246
if m.distance < 0.75 * n.distance:
247+
good_matches_list.append(m)
236248
good_matches += 1
249+
self.last_good_matches = good_matches_list # Store good matches for visualization
237250

238251
return good_matches >= self.reid_threshold
239252

@@ -261,7 +274,12 @@ def _reset_tracking_state(self):
261274
self.tracking_bbox = None
262275
self.tracking_initialized = False
263276
self.original_des = None
277+
self.original_kps = None
264278
self.reid_fail_count = 0 # Reset counter
279+
self.last_good_matches = []
280+
self.last_roi_kps = None
281+
self.last_roi_bbox = None
282+
self.reid_confirmed = False # Reset reid confirmation state
265283

266284
# Publish empty detections to clear any visualizations
267285
empty_2d = Detection2DArray(detections_length=0, header=Header(), detections=[])
@@ -298,6 +316,16 @@ def stop_track(self) -> bool:
298316
logger.info("Tracking stopped")
299317
return True
300318

319+
@rpc
320+
def is_tracking(self) -> bool:
321+
"""
322+
Check if the tracker is currently tracking an object successfully.
323+
324+
Returns:
325+
bool: True if tracking is active and REID is confirmed, False otherwise
326+
"""
327+
return self.tracking_initialized and self.reid_confirmed
328+
301329
def _process_tracking(self):
302330
"""Process current frame for tracking and publish detections."""
303331
if self._latest_rgb_frame is None or self.tracker is None or not self.tracking_initialized:
@@ -316,11 +344,14 @@ def _process_tracking(self):
316344
current_bbox_x1y1x2y2 = [x, y, x + w, y + h]
317345
# Perform re-ID check
318346
reid_confirmed_this_frame = self.reid(frame, current_bbox_x1y1x2y2)
347+
self.reid_confirmed = reid_confirmed_this_frame # Store for is_tracking() RPC
319348

320349
if reid_confirmed_this_frame:
321350
self.reid_fail_count = 0
322351
else:
323352
self.reid_fail_count += 1
353+
else:
354+
self.reid_confirmed = False # No tracking if tracker failed
324355

325356
# Determine final success
326357
if tracker_succeeded:
@@ -480,10 +511,53 @@ def _process_tracking(self):
480511
self._latest_rgb_frame, detections_3d, show_coordinates=True, bboxes_2d=bbox_2d
481512
)
482513

514+
# Overlay REID feature matches if available
515+
if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox:
516+
viz_image = self._draw_reid_matches(viz_image)
517+
483518
# Convert to Image message and publish
484519
viz_msg = Image.from_numpy(viz_image)
485520
self.tracked_overlay.publish(viz_msg)
486521

522+
def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray:
523+
"""Draw REID feature matches on the image."""
524+
viz_image = image.copy()
525+
526+
x1, y1, x2, y2 = self.last_roi_bbox
527+
528+
# Draw keypoints from current ROI in green
529+
for kp in self.last_roi_kps:
530+
pt = (int(kp.pt[0] + x1), int(kp.pt[1] + y1))
531+
cv2.circle(viz_image, pt, 3, (0, 255, 0), -1)
532+
533+
for match in self.last_good_matches:
534+
current_kp = self.last_roi_kps[match.trainIdx]
535+
pt_current = (int(current_kp.pt[0] + x1), int(current_kp.pt[1] + y1))
536+
537+
# Draw a larger circle for matched points in yellow
538+
cv2.circle(viz_image, pt_current, 5, (0, 255, 255), 2) # Yellow for matched points
539+
540+
# Draw match strength indicator (smaller circle with intensity based on distance)
541+
# Lower distance = better match = brighter color
542+
intensity = int(255 * (1.0 - min(match.distance / 100.0, 1.0)))
543+
cv2.circle(viz_image, pt_current, 2, (intensity, intensity, 255), -1)
544+
545+
text = f"REID Matches: {len(self.last_good_matches)}/{len(self.last_roi_kps) if self.last_roi_kps else 0}"
546+
cv2.putText(viz_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
547+
548+
if len(self.last_good_matches) >= self.reid_threshold:
549+
status_text = "REID: CONFIRMED"
550+
status_color = (0, 255, 0) # Green
551+
else:
552+
status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})"
553+
status_color = (0, 165, 255) # Orange
554+
555+
cv2.putText(
556+
viz_image, status_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2
557+
)
558+
559+
return viz_image
560+
487561
def _get_depth_from_bbox(self, bbox: List[int]) -> Optional[float]:
488562
"""Calculate depth from bbox using the 25th percentile of closest points."""
489563
if self._latest_depth_frame is None:
@@ -504,8 +578,6 @@ def _get_depth_from_bbox(self, bbox: List[int]) -> Optional[float]:
504578
valid_depths = roi_depth[np.isfinite(roi_depth) & (roi_depth > 0)]
505579

506580
if len(valid_depths) > 0:
507-
# Take the 25th percentile of the closest (smallest) depth values
508-
# This helps get a robust depth estimate for the front surface of the object
509581
depth_25th_percentile = float(np.percentile(valid_depths, 25))
510582
return depth_25th_percentile
511583

dimos/perception/spatial_perception.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from dimos.core import In, Module, Out, rpc
3030
from dimos.msgs.sensor_msgs import Image
3131
from dimos.msgs.geometry_msgs import Vector3, Quaternion, Pose, PoseStamped
32-
from dimos.robot.unitree_webrtc.type.odometry import Odometry
3332
from dimos.utils.logging_config import setup_logger
3433
from dimos.agents.memory.spatial_vector_db import SpatialVectorDB
3534
from dimos.agents.memory.image_embedding import ImageEmbeddingProvider
@@ -52,7 +51,7 @@ class SpatialMemory(Module):
5251

5352
# LCM inputs
5453
video: In[Image] = None
55-
odom: In[Odometry] = None
54+
odom: In[PoseStamped] = None
5655

5756
def __init__(
5857
self,
@@ -168,7 +167,7 @@ def __init__(
168167

169168
# Track latest data for processing
170169
self._latest_video_frame: Optional[np.ndarray] = None
171-
self._latest_odom: Optional[Odometry] = None
170+
self._latest_odom: Optional[PoseStamped] = None
172171
self._process_interval = 1
173172

174173
logger.info(f"SpatialMemory initialized with model {embedding_model}")
@@ -185,7 +184,7 @@ def set_video(image_msg: Image):
185184
else:
186185
logger.warning("Received image message without data attribute")
187186

188-
def set_odom(odom_msg: Odometry):
187+
def set_odom(odom_msg: PoseStamped):
189188
self._latest_odom = odom_msg
190189

191190
self.video.subscribe(set_video)

dimos/robot/unitree_webrtc/camera_module.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ class UnitreeCameraModule(Module):
6161
def __init__(
6262
self,
6363
camera_intrinsics: List[float],
64+
world_frame_id: str = "world",
6465
camera_frame_id: str = "camera_link",
6566
base_frame_id: str = "base_link",
66-
gt_depth_scale: float = 2.2,
67+
gt_depth_scale: float = 2.0,
6768
**kwargs,
6869
):
6970
"""
@@ -82,6 +83,7 @@ def __init__(
8283
self.camera_intrinsics = camera_intrinsics
8384
self.camera_frame_id = camera_frame_id
8485
self.base_frame_id = base_frame_id
86+
self.world_frame_id = world_frame_id
8587

8688
# Initialize components
8789
from dimos.models.depth.metric3d import Metric3D
@@ -296,7 +298,7 @@ def _publish_camera_pose(self, header: Header):
296298
try:
297299
# Look up transform from the world frame to camera_link
298300
transform = self.tf.get(
299-
parent_frame=self.base_frame_id,
301+
parent_frame=self.world_frame_id,
300302
child_frame=self.camera_frame_id,
301303
time_point=header.ts,
302304
time_tolerance=1.0,
@@ -306,7 +308,7 @@ def _publish_camera_pose(self, header: Header):
306308
# Create PoseStamped from transform
307309
pose_msg = PoseStamped(
308310
ts=header.ts,
309-
frame_id=self.base_frame_id,
311+
frame_id=self.camera_frame_id,
310312
position=transform.translation,
311313
orientation=transform.rotation,
312314
)

dimos/robot/unitree_webrtc/unitree_go2.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from dimos.utils.data import get_data
5151
from dimos.utils.logging_config import setup_logger
5252
from dimos.utils.testing import TimedSensorReplay
53-
from dimos.utils.transform_utils import retract_distance
53+
from dimos.utils.transform_utils import offset_distance
5454
from dimos.perception.common.utils import extract_pose_from_detection3d
5555
from dimos.perception.object_tracker import ObjectTracking
5656
from dimos_lcm.std_msgs import Bool
@@ -361,8 +361,10 @@ def _deploy_perception(self):
361361
output_dir=self.spatial_memory_dir,
362362
)
363363

364-
self.spatial_memory_module.video.connect(self.connection.video)
365-
self.spatial_memory_module.odom.connect(self.connection.odom)
364+
self.spatial_memory_module.video.transport = core.LCMTransport("/go2/color_image", Image)
365+
self.spatial_memory_module.odom.transport = core.LCMTransport(
366+
"/go2/camera_pose", PoseStamped
367+
)
366368

367369
logger.info("Spatial memory module deployed and connected")
368370

@@ -531,7 +533,7 @@ def get_odom(self) -> PoseStamped:
531533
"""
532534
return self.connection.get_odom()
533535

534-
def navigate_to_object(self, bbox: List[float], distance: float, timeout: float = 30.0):
536+
def navigate_to_object(self, bbox: List[float], distance: float = 0.5, timeout: float = 30.0):
535537
"""Navigate to an object by tracking it and maintaining a specified distance.
536538
537539
Args:
@@ -563,13 +565,18 @@ def navigate_to_object(self, bbox: List[float], distance: float, timeout: float
563565
logger.info("Object tracking goal reached")
564566
return True
565567

568+
if not self.object_tracker.is_tracking():
569+
continue
570+
566571
detection_topic = Topic("/go2/detection3d", Detection3DArray)
567572
detection_msg = self.lcm.wait_for_message(detection_topic, timeout=1.0)
568573

569574
if detection_msg and len(detection_msg.detections) > 0:
570575
target_pose = extract_pose_from_detection3d(detection_msg.detections[0])
571576

572-
retracted_pose = retract_distance(target_pose, distance)
577+
retracted_pose = offset_distance(
578+
target_pose, distance, approach_vector=Vector3(-1, 0, 0)
579+
)
573580

574581
goal_pose = PoseStamped(
575582
frame_id=detection_msg.header.frame_id,
@@ -579,7 +586,7 @@ def navigate_to_object(self, bbox: List[float], distance: float, timeout: float
579586
self.navigator.set_goal(goal_pose)
580587
goal_set = True
581588

582-
time.sleep(0.3)
589+
time.sleep(0.25)
583590

584591
logger.info("Object tracking timed out")
585592
return False

0 commit comments

Comments
 (0)