Skip to content

Commit 4662fbd

Browse files
authored
feat(replay): add --replay-dir to select replay dataset (#1519)
Adds --replay-dir flag to select which data/ directory to replay from: dimos --replay run ... # go2_sf_office (default) dimos --replay --replay-dir unitree_go2_bigoffice run ... # big office dataset dimos --replay --replay-dir <any_data_dir> run ... # any dataset Changes: - GlobalConfig: add replay_dir field (default 'go2_sf_office') - ReplayConnection: accept dataset param, use directly as data/ dir name - --replay flag unchanged (backward compatible)
1 parent df8f325 commit 4662fbd

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

dimos/core/global_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class GlobalConfig(BaseSettings):
3232
robot_ips: str | None = None
3333
simulation: bool = False
3434
replay: bool = False
35+
replay_dir: str = "go2_sf_office"
3536
new_memory: bool = False
3637
viewer: ViewerBackend = "rerun"
3738
n_workers: int = 2

dimos/robot/unitree/go2/connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def make_connection(ip: str | None, cfg: GlobalConfig) -> Go2ConnectionProtocol:
8787
connection_type = cfg.unitree_connection_type
8888

8989
if ip in ("fake", "mock", "replay") or connection_type == "replay":
90-
return ReplayConnection()
90+
dataset = cfg.replay_dir
91+
return ReplayConnection(dataset=dataset)
9192
elif ip == "mujoco" or connection_type == "mujoco":
9293
from dimos.robot.unitree.mujoco_connection import MujocoConnection
9394

@@ -98,13 +99,13 @@ def make_connection(ip: str | None, cfg: GlobalConfig) -> Go2ConnectionProtocol:
9899

99100

100101
class ReplayConnection(UnitreeWebRTCConnection):
101-
dir_name = "go2_sf_office"
102-
103102
# we don't want UnitreeWebRTCConnection to init
104103
def __init__( # type: ignore[no-untyped-def]
105104
self,
105+
dataset: str = "go2_sf_office",
106106
**kwargs,
107107
) -> None:
108+
self.dir_name = dataset
108109
get_data(self.dir_name)
109110
self.replay_config = {
110111
"loop": kwargs.get("loop", True),

0 commit comments

Comments
 (0)