Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 372cb0a

Browse files
committed
ONNX export: Add Flatten before Gemm
1 parent f838bb5 commit 372cb0a

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):
232232

233233
fcnode = []
234234

235+
op_name = "flatten_" + str(kwargs["idx"])
236+
flatten_node = onnx.helper.make_node(
237+
'Flatten',
238+
inputs=[input_nodes[0]],
239+
outputs=[op_name],
240+
name=op_name
241+
)
242+
243+
input_nodes[0] = op_name
244+
fcnode.append(flatten_node)
245+
235246
if no_bias:
236247
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
237248
bias_name = "bias" + str(kwargs["idx"])

python/mxnet/contrib/onnx/mx2onnx/export_onnx.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
231231
# Determine output shape
232232
output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
233233

234+
output_suffix = '_output'
235+
output_names = [
236+
o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]
237+
234238
weights = MXNetGraph.convert_weights_to_numpy(params)
235239

236240
mx_graph = json.loads(sym.tojson())["nodes"]
@@ -294,26 +298,25 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
294298
# If converted node is NodeProto, add it in processed nodes list
295299
elif isinstance(converted_node, NodeProto):
296300
onnx_processed_nodes.append(converted_node)
297-
if idx == (len(mx_graph) - 1):
298-
# If converted node doesnt have name, use it from output field
299-
if not converted_node.name:
300-
onnx_processed_outputs.append(
301-
make_tensor_value_info(
302-
name=converted_node.output[0],
303-
elem_type=in_type,
304-
shape=output_shape
305-
)
301+
# If converted node doesnt have name, use it from output field
302+
if not converted_node.name and idx == (len(mx_graph) - 1):
303+
onnx_processed_outputs.append(
304+
make_tensor_value_info(
305+
name=converted_node.output[0],
306+
elem_type=in_type,
307+
shape=output_shape
306308
)
307-
else:
308-
onnx_processed_outputs.append(
309-
make_tensor_value_info(
310-
name=converted_node.name,
311-
elem_type=in_type,
312-
shape=output_shape
313-
)
309+
)
310+
elif converted_node.name in output_names:
311+
onnx_processed_outputs.append(
312+
make_tensor_value_info(
313+
name=converted_node.name,
314+
elem_type=in_type,
315+
shape=output_shape
314316
)
315-
if verbose:
316-
logging.info("Output node is: %s", converted_node.name)
317+
)
318+
if verbose:
319+
logging.info("Output node is: %s", converted_node.name)
317320
elif isinstance(converted_node, TensorProto):
318321
raise ValueError("Did not expect TensorProto")
319322
else:

0 commit comments

Comments (0)