Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7898c5d

Browse files
committed
ONNX export: Logical operators
1 parent 0bea50e commit 7898c5d

File tree

3 files changed

+105
-1
lines changed

3 files changed

+105
-1
lines changed

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,3 +1545,31 @@ def convert_sum(node, **kwargs):
15451545
name=name
15461546
)
15471547
return [node]
1548+
1549+
@mx_op.register("broadcast_logical_and")
1550+
def convert_logical_and(node, **kwargs):
1551+
"""Map MXNet's logical and operator attributes to onnx's Add operator
1552+
and return the created node.
1553+
"""
1554+
return create_basic_op_node('And', node, kwargs)
1555+
1556+
@mx_op.register("broadcast_logical_or")
1557+
def convert_logical_or(node, **kwargs):
1558+
"""Map MXNet's logical or operator attributes to onnx's Or operator
1559+
and return the created node.
1560+
"""
1561+
return create_basic_op_node('Or', node, kwargs)
1562+
1563+
@mx_op.register("broadcast_logical_xor")
1564+
def convert_logical_xor(node, **kwargs):
1565+
"""Map MXNet's logical xor operator attributes to onnx's Xor operator
1566+
and return the created node.
1567+
"""
1568+
return create_basic_op_node('Xor', node, kwargs)
1569+
1570+
@mx_op.register("logical_not")
1571+
def convert_logical_not(node, **kwargs):
1572+
"""Map MXNet's logical not operator attributes to onnx's Not operator
1573+
and return the created node.
1574+
"""
1575+
return create_basic_op_node('Not', node, kwargs)

tests/python-pytest/onnx/export/mxnet_export_test.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@
5656
'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v2.tar.gz'
5757
}
5858

59+
def get_int_inputs(interval, shape):
60+
"""Helper to get integer input of given shape and range"""
61+
assert len(interval) == len(shape)
62+
inputs = []
63+
input_tensors = []
64+
for idx in range(len(interval)):
65+
low, high = interval[idx]
66+
inputs.append(np.random.randint(low, high, size=shape[idx]).astype("float32"))
67+
input_tensors.append(helper.make_tensor_value_info("input"+str(idx+1),
68+
TensorProto.FLOAT, shape=shape[idx]))
69+
70+
return inputs, input_tensors
71+
5972
def get_test_files(name):
6073
"""Extract tar file and returns model path and input, output data"""
6174
tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
@@ -238,6 +251,70 @@ def test_square():
238251

239252
npt.assert_almost_equal(result, numpy_op)
240253

254+
@with_seed()
255+
def test_logical_and():
256+
"""Test for logical and in onnx operators."""
257+
inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
258+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
259+
nodes = [helper.make_node("And", ["input1", "input2"], ["output"])]
260+
graph = helper.make_graph(nodes,
261+
"and_test",
262+
input_tensor,
263+
outputs)
264+
model = helper.make_model(graph)
265+
bkd_rep = backend.prepare(model)
266+
output = bkd_rep.run([inputs[0], inputs[1]])
267+
numpy_op = np.logical_and(inputs[0], inputs[1]).astype(np.float32)
268+
npt.assert_almost_equal(output[0], numpy_op)
269+
270+
@with_seed()
271+
def test_logical_or():
272+
"""Test for logical or in onnx operators."""
273+
inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
274+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
275+
nodes = [helper.make_node("Or", ["input1", "input2"], ["output"])]
276+
graph = helper.make_graph(nodes,
277+
"or_test",
278+
input_tensor,
279+
outputs)
280+
model = helper.make_model(graph)
281+
bkd_rep = backend.prepare(model)
282+
output = bkd_rep.run([inputs[0], inputs[1]])
283+
numpy_op = np.logical_or(inputs[0], inputs[1]).astype(np.float32)
284+
npt.assert_almost_equal(output[0], numpy_op)
285+
286+
@with_seed()
287+
def test_logical_not():
288+
"""Test for logical not in onnx operators."""
289+
inputs, input_tensor = get_int_inputs([(0, 2)], [(3, 4, 5)])
290+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
291+
nodes = [helper.make_node("Not", ["input1"], ["output"])]
292+
graph = helper.make_graph(nodes,
293+
"not_test",
294+
input_tensor,
295+
outputs)
296+
model = helper.make_model(graph)
297+
bkd_rep = backend.prepare(model)
298+
output = bkd_rep.run([inputs[0]])
299+
numpy_op = np.logical_not(inputs[0]).astype(np.float32)
300+
npt.assert_almost_equal(output[0], numpy_op)
301+
302+
@with_seed()
303+
def test_logical_xor():
304+
"""Test for logical xor in onnx operators."""
305+
inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
306+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
307+
nodes = [helper.make_node("Xor", ["input1", "input2"], ["output"])]
308+
graph = helper.make_graph(nodes,
309+
"xor_test",
310+
input_tensor,
311+
outputs)
312+
model = helper.make_model(graph)
313+
bkd_rep = backend.prepare(model)
314+
output = bkd_rep.run([inputs[0], inputs[1]])
315+
numpy_op = np.logical_xor(inputs[0], inputs[1]).astype(np.float32)
316+
npt.assert_almost_equal(output[0], numpy_op)
317+
241318
if __name__ == '__main__':
242319
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
243320
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))

tests/python-pytest/onnx/import/test_cases.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
'test_argmax',
5656
'test_argmin',
5757
'test_min',
58-
'test_logical_',
5958
# enabling partial test cases for matmul
6059
'test_matmul_3d',
6160
'test_matmul_4d',

0 commit comments

Comments
 (0)