@@ -231,6 +231,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
231231 # Determine output shape
232232 output_shape = MXNetGraph .infer_output_shape (sym , params , in_shape , output_label )
233233
234+ output_suffix = '_output'
235+ output_names = [
236+ o [:- len (output_suffix )] for o in sym .list_outputs () if o .endswith (output_suffix )]
237+
234238 weights = MXNetGraph .convert_weights_to_numpy (params )
235239
236240 mx_graph = json .loads (sym .tojson ())["nodes" ]
@@ -294,26 +298,25 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
294298 # If converted node is NodeProto, add it in processed nodes list
295299 elif isinstance (converted_node , NodeProto ):
296300 onnx_processed_nodes .append (converted_node )
297- if idx == (len (mx_graph ) - 1 ):
298- # If converted node doesnt have name, use it from output field
299- if not converted_node .name :
300- onnx_processed_outputs .append (
301- make_tensor_value_info (
302- name = converted_node .output [0 ],
303- elem_type = in_type ,
304- shape = output_shape
305- )
301+ # If converted node doesnt have name, use it from output field
302+ if not converted_node .name and idx == (len (mx_graph ) - 1 ):
303+ onnx_processed_outputs .append (
304+ make_tensor_value_info (
305+ name = converted_node .output [0 ],
306+ elem_type = in_type ,
307+ shape = output_shape
306308 )
307- else :
308- onnx_processed_outputs . append (
309- make_tensor_value_info (
310- name = converted_node . name ,
311- elem_type = in_type ,
312- shape = output_shape
313- )
309+ )
310+ elif converted_node . name in output_names :
311+ onnx_processed_outputs . append (
312+ make_tensor_value_info (
313+ name = converted_node . name ,
314+ elem_type = in_type ,
315+ shape = output_shape
314316 )
315- if verbose :
316- logging .info ("Output node is: %s" , converted_node .name )
317+ )
318+ if verbose :
319+ logging .info ("Output node is: %s" , converted_node .name )
317320 elif isinstance (converted_node , TensorProto ):
318321 raise ValueError ("Did not expect TensorProto" )
319322 else :
0 commit comments