Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5e632a0

Browse files
committed
ONNX export: Test for fully connected
1 parent c93fbc5 commit 5e632a0

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/python-pytest/onnx/export/mxnet_export_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,52 @@ def test_square():
260260
npt.assert_almost_equal(result, numpy_op)
261261

262262

263+
@with_seed()
264+
def test_fully_connected():
265+
def random_arrays(*shapes):
266+
"""Generate some random numpy arrays."""
267+
arrays = [np.random.randn(*s).astype("float32")
268+
for s in shapes]
269+
if len(arrays) == 1:
270+
return arrays[0]
271+
return arrays
272+
273+
data_names = ['x', 'w', 'b']
274+
275+
dim_in, dim_out = (3, 4)
276+
input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))
277+
278+
ipsym = []
279+
data_shapes = []
280+
data_forward = []
281+
for idx in range(len(data_names)):
282+
val = input_data[idx]
283+
data_shapes.append((data_names[idx], np.shape(val)))
284+
data_forward.append(mx.nd.array(val))
285+
ipsym.append(mx.sym.Variable(data_names[idx]))
286+
287+
op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')
288+
289+
model = mx.mod.Module(op, data_names=data_names, label_names=None)
290+
model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
291+
292+
model.init_params()
293+
294+
args, auxs = model.get_params()
295+
params = {}
296+
params.update(args)
297+
params.update(auxs)
298+
299+
converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")
300+
301+
sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
302+
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
303+
304+
numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]
305+
306+
npt.assert_almost_equal(result, numpy_op)
307+
308+
263309
def test_softmax():
264310
input1 = np.random.rand(1000, 1000).astype("float32")
265311
label1 = np.random.rand(1000)

0 commit comments

Comments
 (0)