Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 10e5e02

Browse files
committed
Update test_numpy_ndarray.py
1 parent bebb165 commit 10e5e02

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

tests/python/unittest/test_numpy_ndarray.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -642,13 +642,18 @@ def test_getitem(np_array, index):
642642
)
643643
np_indexed_array = np_array[np_index]
644644
mx_np_array = np.array(np_array, dtype=np_array.dtype)
645-
try:
646-
mx_indexed_array = mx_np_array[index]
647-
except Exception as e:
648-
print('Failed with index = {}'.format(index))
649-
raise e
650-
mx_indexed_array = mx_indexed_array.asnumpy()
651-
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
645+
for autograd in [True, False]:
646+
try:
647+
if autograd:
648+
with mx.autograd.record():
649+
mx_indexed_array = mx_np_array[index]
650+
else:
651+
mx_indexed_array = mx_np_array[index]
652+
except Exception as e:
653+
print('Failed with index = {}'.format(index))
654+
raise e
655+
mx_indexed_array = mx_indexed_array.asnumpy()
656+
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
652657

653658
def test_setitem(np_array, index):
654659
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):

0 commit comments

Comments
 (0)