Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2bc1dae

Browse files
committed
ONNX export: Logical operators
1 parent be9ca1b commit 2bc1dae

File tree

3 files changed

+174
-1
lines changed

3 files changed

+174
-1
lines changed

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2265,3 +2265,100 @@ def convert_sum(node, **kwargs):
22652265
name=name
22662266
)
22672267
return [node]
2268+
2269+
@mx_op.register("broadcast_logical_and")
2270+
def convert_logical_and(node, **kwargs):
2271+
"""Map MXNet's logical and operator attributes to onnx's Add operator
2272+
and return the created node.
2273+
"""
2274+
onnx = import_onnx_modules()
2275+
name = node["name"]
2276+
proc_nodes = kwargs["proc_nodes"]
2277+
inputs = node["inputs"]
2278+
2279+
input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
2280+
input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
2281+
2282+
input_node_a = proc_nodes[input_node_a_id].name
2283+
input_node_b = proc_nodes[input_node_b_id].name
2284+
2285+
and_node = onnx.helper.make_node(
2286+
"And",
2287+
[input_node_a, input_node_b],
2288+
[name],
2289+
name=name,
2290+
)
2291+
2292+
return [and_node]
2293+
2294+
@mx_op.register("broadcast_logical_or")
2295+
def convert_logical_or(node, **kwargs):
2296+
"""Map MXNet's logical or operator attributes to onnx's Or operator
2297+
and return the created node.
2298+
"""
2299+
onnx = import_onnx_modules()
2300+
name = node["name"]
2301+
proc_nodes = kwargs["proc_nodes"]
2302+
inputs = node["inputs"]
2303+
2304+
input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
2305+
input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
2306+
2307+
input_node_a = proc_nodes[input_node_a_id].name
2308+
input_node_b = proc_nodes[input_node_b_id].name
2309+
2310+
or_node = onnx.helper.make_node(
2311+
"Or",
2312+
[input_node_a, input_node_b],
2313+
[name],
2314+
name=name,
2315+
)
2316+
2317+
return [or_node]
2318+
2319+
@mx_op.register("broadcast_logical_xor")
2320+
def convert_logical_xor(node, **kwargs):
2321+
"""Map MXNet's logical xor operator attributes to onnx's Xor operator
2322+
and return the created node.
2323+
"""
2324+
onnx = import_onnx_modules()
2325+
name = node["name"]
2326+
proc_nodes = kwargs["proc_nodes"]
2327+
inputs = node["inputs"]
2328+
2329+
input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
2330+
input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
2331+
2332+
input_node_a = proc_nodes[input_node_a_id].name
2333+
input_node_b = proc_nodes[input_node_b_id].name
2334+
2335+
xor_node = onnx.helper.make_node(
2336+
"Xor",
2337+
[input_node_a, input_node_b],
2338+
[name],
2339+
name=name,
2340+
)
2341+
2342+
return [xor_node]
2343+
2344+
@mx_op.register("logical_not")
2345+
def convert_logical_not(node, **kwargs):
2346+
"""Map MXNet's logical not operator attributes to onnx's Not operator
2347+
and return the created node.
2348+
"""
2349+
onnx = import_onnx_modules()
2350+
name = node["name"]
2351+
proc_nodes = kwargs["proc_nodes"]
2352+
inputs = node["inputs"]
2353+
2354+
input_node_id = kwargs["index_lookup"][inputs[0][0]]
2355+
input_node = proc_nodes[input_node_id].name
2356+
2357+
node = onnx.helper.make_node(
2358+
"Not",
2359+
[input_node],
2360+
[name],
2361+
name=name
2362+
)
2363+
2364+
return [node]

tests/python-pytest/onnx/export/mxnet_export_test.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@
5656
'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v2.tar.gz'
5757
}
5858

59+
def get_int_inputs(interval, shape):
60+
"""Helper to get integer input of given shape and range"""
61+
assert len(interval) == len(shape)
62+
inputs = []
63+
input_tensors = []
64+
for idx in range(len(interval)):
65+
low, high = interval[idx]
66+
inputs.append(np.random.randint(low, high, size=shape[idx]).astype("float32"))
67+
input_tensors.append(helper.make_tensor_value_info("input"+str(idx+1),
68+
TensorProto.FLOAT, shape=shape[idx]))
69+
70+
return inputs, input_tensors
71+
5972
def get_test_files(name):
6073
"""Extract tar file and returns model path and input, output data"""
6174
tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
@@ -238,6 +251,70 @@ def test_square():
238251

239252
npt.assert_almost_equal(result, numpy_op)
240253

254+
@with_seed()
255+
def test_logical_and():
256+
"""Test for logical and in onnx operators."""
257+
inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
258+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
259+
nodes = [helper.make_node("And", ["input1", "input2"], ["output"])]
260+
graph = helper.make_graph(nodes,
261+
"and_test",
262+
input_tensor,
263+
outputs)
264+
model = helper.make_model(graph)
265+
bkd_rep = backend.prepare(model)
266+
output = bkd_rep.run([inputs[0], inputs[1]])
267+
numpy_op = np.logical_and(inputs[0], inputs[1]).astype(np.float32)
268+
npt.assert_almost_equal(output[0], numpy_op)
269+
270+
@with_seed()
271+
def test_logical_or():
272+
"""Test for logical or in onnx operators."""
273+
inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
274+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
275+
nodes = [helper.make_node("Or", ["input1", "input2"], ["output"])]
276+
graph = helper.make_graph(nodes,
277+
"or_test",
278+
input_tensor,
279+
outputs)
280+
model = helper.make_model(graph)
281+
bkd_rep = backend.prepare(model)
282+
output = bkd_rep.run([inputs[0], inputs[1]])
283+
numpy_op = np.logical_or(inputs[0], inputs[1]).astype(np.float32)
284+
npt.assert_almost_equal(output[0], numpy_op)
285+
286+
@with_seed()
287+
def test_logical_not():
288+
"""Test for logical not in onnx operators."""
289+
inputs, input_tensor = get_int_inputs([(0, 2)], [(3, 4, 5)])
290+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
291+
nodes = [helper.make_node("Not", ["input1"], ["output"])]
292+
graph = helper.make_graph(nodes,
293+
"not_test",
294+
input_tensor,
295+
outputs)
296+
model = helper.make_model(graph)
297+
bkd_rep = backend.prepare(model)
298+
output = bkd_rep.run([inputs[0]])
299+
numpy_op = np.logical_not(inputs[0]).astype(np.float32)
300+
npt.assert_almost_equal(output[0], numpy_op)
301+
302+
@with_seed()
303+
def test_logical_xor():
304+
"""Test for logical xor in onnx operators."""
305+
inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
306+
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
307+
nodes = [helper.make_node("Xor", ["input1", "input2"], ["output"])]
308+
graph = helper.make_graph(nodes,
309+
"xor_test",
310+
input_tensor,
311+
outputs)
312+
model = helper.make_model(graph)
313+
bkd_rep = backend.prepare(model)
314+
output = bkd_rep.run([inputs[0], inputs[1]])
315+
numpy_op = np.logical_xor(inputs[0], inputs[1]).astype(np.float32)
316+
npt.assert_almost_equal(output[0], numpy_op)
317+
241318
if __name__ == '__main__':
242319
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
243320
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))

tests/python-pytest/onnx/import/test_cases.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
'test_argmax',
5656
'test_argmin',
5757
'test_min',
58-
'test_logical_',
5958
# enabling partial test cases for matmul
6059
'test_matmul_3d',
6160
'test_matmul_4d',

0 commit comments

Comments
 (0)