Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a3f8e51

Browse files
committed
ONNX export: Logical operators
1 parent 38eeb0c commit a3f8e51

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,3 +1545,31 @@ def convert_sum(node, **kwargs):
15451545
name=name
15461546
)
15471547
return [node]
1548+
1549+
@mx_op.register("broadcast_logical_and")
1550+
def convert_broadcast_logical_and(node, **kwargs):
1551+
"""Map MXNet's broadcast logical and operator attributes to onnx's Add operator
1552+
and return the created node.
1553+
"""
1554+
return create_basic_op_node('And', node, kwargs)
1555+
1556+
@mx_op.register("broadcast_logical_or")
1557+
def convert_broadcast_logical_or(node, **kwargs):
1558+
"""Map MXNet's broadcast logical or operator attributes to onnx's Or operator
1559+
and return the created node.
1560+
"""
1561+
return create_basic_op_node('Or', node, kwargs)
1562+
1563+
@mx_op.register("broadcast_logical_xor")
1564+
def convert_broadcast_logical_xor(node, **kwargs):
1565+
"""Map MXNet's broadcast logical xor operator attributes to onnx's Xor operator
1566+
and return the created node.
1567+
"""
1568+
return create_basic_op_node('Xor', node, kwargs)
1569+
1570+
@mx_op.register("logical_not")
1571+
def convert_logical_not(node, **kwargs):
1572+
"""Map MXNet's logical not operator attributes to onnx's Not operator
1573+
and return the created node.
1574+
"""
1575+
return create_basic_op_node('Not', node, kwargs)

tests/python-pytest/onnx/export/mxnet_export_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,48 @@ def test_square():
238238

239239
npt.assert_almost_equal(result, numpy_op)
240240

241+
242+
def get_int_inputs(interval, shape):
243+
"""Helper to get integer input of given shape and range"""
244+
assert len(interval) == len(shape)
245+
inputs = []
246+
input_tensors = []
247+
for idx in range(len(interval)):
248+
low, high = interval[idx]
249+
inputs.append(np.random.randint(low, high, size=shape[idx]).astype("float32"))
250+
input_tensors.append(helper.make_tensor_value_info("input"+str(idx+1),
251+
TensorProto.FLOAT, shape=shape[idx]))
252+
253+
return inputs, input_tensors
254+
255+
256+
@with_seed()
257+
def test_logical_ops():
258+
"""Test for logical and, or, not, xor operators"""
259+
def test_ops(op_name, inputs, input_tensors, numpy_op):
260+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
261+
262+
nodes = [helper.make_node(op_name, ["input"+str(i+1) for i in range(len(inputs))], ["output"])]
263+
graph = helper.make_graph(nodes,
264+
op_name + "_test",
265+
input_tensors,
266+
outputs)
267+
model = helper.make_model(graph)
268+
bkd_rep = backend.prepare(model)
269+
output = bkd_rep.run(inputs)
270+
npt.assert_almost_equal(output[0], numpy_op)
271+
272+
input_data, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
273+
test_ops("And", input_data, input_tensor,
274+
np.logical_and(input_data[0], input_data[1]).astype(np.float32))
275+
test_ops("Or", input_data, input_tensor,
276+
np.logical_or(input_data[0], input_data[1]).astype(np.float32))
277+
test_ops("Xor", input_data, input_tensor,
278+
np.logical_xor(input_data[0], input_data[1]).astype(np.float32))
279+
test_ops("Not", [input_data[0]], [input_tensor[0]],
280+
np.logical_not(input_data[0]).astype(np.float32))
281+
282+
241283
if __name__ == '__main__':
242284
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
243285
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))

tests/python-pytest/onnx/import/test_cases.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
'test_argmax',
5656
'test_argmin',
5757
'test_min',
58-
'test_logical_',
5958
# enabling partial test cases for matmul
6059
'test_matmul_3d',
6160
'test_matmul_4d',

0 commit comments

Comments
 (0)