@@ -138,7 +138,7 @@ def load_from_mmar(
138138 # loading with `torch.load`
139139 model_dict = torch .load (model_file , map_location = map_location )
140140 if weights_only :
141- return model_dict .get ("model" , model_dict ) # model_dict["model" ] or model_dict directly
141+ return model_dict .get (model_key , model_dict ) # model_dict[model_key ] or model_dict directly
142142
143143 # 1. search `model_dict['train_config]` for model config spec.
144144 model_config = _get_val (dict (model_dict ).get ("train_conf" , {}), key = model_key , default = {})
@@ -168,7 +168,11 @@ def load_from_mmar(
168168 model_module , model_name = model_config .get ("path" , "." ).rsplit ("." , 1 )
169169 model_cls , has_cls = optional_import (module = model_module , name = model_name )
170170 if not has_cls :
171- raise ValueError (f"Could not load model config { model_config .get ('path' , '' )} ." )
171+ raise ValueError (
172+ f"Could not load MMAR model config { model_config .get ('path' , '' )} , "
173+ f"Please make sure MMAR's sub-folders in '{ model_dir } ' is on the PYTHONPATH."
174+ "See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
175+ )
172176 else :
173177 raise ValueError (f"Could not load model config { model_config } ." )
174178
@@ -180,7 +184,7 @@ def load_from_mmar(
180184 else :
181185 model_inst = model_cls ()
182186 if pretrained :
183- model_inst .load_state_dict (model_dict .get ("model" , model_dict ))
187+ model_inst .load_state_dict (model_dict .get (model_key , model_dict ))
184188 print ("\n ---" )
185189 print (f"For more information, please visit { item [Keys .DOC ]} \n " )
186190 return model_inst