Skip to content

Commit c2ebe78

Browse files
committed
update based on comments
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent fc40e4e commit c2ebe78

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

monai/apps/mmars/mmars.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def load_from_mmar(
138138
# loading with `torch.load`
139139
model_dict = torch.load(model_file, map_location=map_location)
140140
if weights_only:
141-
return model_dict.get("model", model_dict) # model_dict["model"] or model_dict directly
141+
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly
142142

143143
# 1. search `model_dict['train_config]` for model config spec.
144144
model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
@@ -168,7 +168,11 @@ def load_from_mmar(
168168
model_module, model_name = model_config.get("path", ".").rsplit(".", 1)
169169
model_cls, has_cls = optional_import(module=model_module, name=model_name)
170170
if not has_cls:
171-
raise ValueError(f"Could not load model config {model_config.get('path', '')}.")
171+
raise ValueError(
172+
f"Could not load MMAR model config {model_config.get('path', '')}, "
173+
f"Please make sure MMAR's sub-folders in '{model_dir}' is on the PYTHONPATH."
174+
"See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
175+
)
172176
else:
173177
raise ValueError(f"Could not load model config {model_config}.")
174178

@@ -180,7 +184,7 @@ def load_from_mmar(
180184
else:
181185
model_inst = model_cls()
182186
if pretrained:
183-
model_inst.load_state_dict(model_dict.get("model", model_dict))
187+
model_inst.load_state_dict(model_dict.get(model_key, model_dict))
184188
print("\n---")
185189
print(f"For more information, please visit {item[Keys.DOC]}\n")
186190
return model_inst

0 commit comments

Comments
 (0)