Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8f10d55

Browse files
sxjsciencehaojin2
authored andcommitted
[Numpy] Fix imperative basic indexing in numpy (#16902)
* fix bug add test case fix Update test_numpy_ndarray.py * revise function name
1 parent d2d4876 commit 8f10d55

File tree

4 files changed

+71
-39
lines changed

4 files changed

+71
-39
lines changed

python/mxnet/ndarray/ndarray.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -847,26 +847,32 @@ def _basic_indexing_slice_is_contiguous(slc_key, shape):
847847
"""Whether indexing with the given key results in a contiguous array.
848848
849849
The rule is: From right to left, if in an axis, a slice produces a
850-
proper subset, no later axis can produce a proper subset or use
851-
a step different from 1.
850+
proper subset, the later slice must have <=1 elements.
852851
853852
The ``slc_key`` sequence must have the same length as ``shape`` and
854853
only contain `slice` objects.
855854
"""
856855
assert len(slc_key) == len(shape)
857-
subset = False
856+
is_subset = False
857+
total_sliced_elements = np.prod([_get_slice_len(slc, n)
858+
for slc, n in zip(slc_key, shape)])
859+
if total_sliced_elements in (0, 1):
860+
return True
858861
for idx, n in zip(reversed(slc_key), reversed(shape)):
859-
start, stop, step = idx.indices(n)
860-
if step > 0:
861-
num = int(np.ceil(max(stop - start, 0) / step))
862-
else:
863-
num = int(np.ceil(min(stop - start, 0) / step))
864-
865-
if num != 1 and (subset or step != 1):
862+
_, _, step = idx.indices(n)
863+
num_elements = _get_slice_len(idx, n)
864+
if num_elements == 0:
865+
return True
866+
elif num_elements > 1 and (step > 1 or step < 0):
867+
# We do not support the case of reverse slicing of multiple elements and
868+
# forward slicing of #elements > 1 and step > 1
866869
return False
867-
if num != n:
868-
subset = True
869-
870+
elif is_subset:
871+
if num_elements > 1:
872+
return False
873+
else:
874+
if num_elements < n:
875+
is_subset = True
870876
return True
871877
# pylint: enable=invalid-name
872878

@@ -875,30 +881,27 @@ def _basic_indexing_sliced_shape(slc_key, shape):
875881
"""Return the shape after slicing with the given key."""
876882
assert len(slc_key) == len(shape)
877883
sliced_shape = []
878-
for idx, n in zip(slc_key, shape):
879-
start, stop, step = idx.indices(n)
880-
if step > 0:
881-
num = int(np.ceil(max(stop - start, 0) / step))
882-
else:
883-
num = int(np.ceil(min(stop - start, 0) / step))
884-
sliced_shape.append(num)
885-
884+
for slc, n in zip(slc_key, shape):
885+
num_elements = _get_slice_len(slc, n)
886+
sliced_shape.append(num_elements)
886887
return tuple(sliced_shape)
887888

888889
# pylint: disable=invalid-name
889890
@staticmethod
890891
def _basic_indexing_contiguous_flat_begin_end(slc_key, shape):
891892
"""Return the flat indices of begin and end for contiguous slicing."""
892893
assert len(slc_key) == len(shape)
893-
begin, end, _ = slc_key[0].indices(shape[0])
894-
flat_begin, flat_end = begin, end - 1
895-
for idx, n in zip(slc_key[1:], shape[1:]):
894+
flat_begin, flat_end = 0, 0
895+
for slc, n in zip(slc_key, shape):
896896
flat_begin *= n
897897
flat_end *= n
898-
begin, end, _ = idx.indices(n)
899-
flat_begin += begin
900-
flat_end += end - 1
901-
898+
begin, _, _ = slc.indices(n)
899+
num_elements = _get_slice_len(slc, n)
900+
if num_elements == 0:
901+
return 0, 0
902+
else:
903+
flat_begin += begin
904+
flat_end += begin + num_elements - 1
902905
return flat_begin, flat_end + 1
903906
# pylint: enable=invalid-name
904907

@@ -1062,7 +1065,7 @@ def _get_nd_basic_indexing(self, key):
10621065
for ax in new_axes: # pylint: disable=invalid-name
10631066
final_shape.insert(ax, 1)
10641067

1065-
if final_shape == []:
1068+
if len(final_shape) == 0:
10661069
# Override for single element indexing
10671070
final_shape = [1]
10681071
return sliced.reshape(final_shape)
@@ -3125,6 +3128,26 @@ def _get_dim_size(start, stop, step):
31253128
return dim_size
31263129

31273130

3131+
def _get_slice_len(slc, seq_length):
3132+
"""Given a python slice object and the length of the sequence, calculate the number of elements
3133+
in the slice.
3134+
3135+
Parameters
3136+
----------
3137+
slc : py_slice
3138+
The slice object
3139+
seq_length : int
3140+
The length of the object you are going to apply the slice on
3141+
3142+
Returns
3143+
-------
3144+
ret : int
3145+
Total number of elements in the slice
3146+
"""
3147+
start, stop, step = slc.indices(seq_length)
3148+
return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)
3149+
3150+
31283151
def _get_broadcast_shape(shape1, shape2):
31293152
"""Given two shapes that are not identical, find the shape
31303153
that both input shapes can broadcast to."""

src/ndarray/ndarray.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
293293
NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
294294
NDArray ret = this->Slice(begin, end);
295295
if (!Imperative::Get()->is_recording()) return ret;
296-
// fake a slice_axis op
296+
// fake a slice op
297297
nnvm::NodeAttrs attrs;
298-
attrs.op = nnvm::Op::Get("slice_axis");
299-
attrs.dict.insert({"axis", "0"});
298+
attrs.op = nnvm::Op::Get("slice");
300299
attrs.dict.insert({"begin", std::to_string(begin)});
301300
attrs.dict.insert({"end", std::to_string(end)});
302301
attrs.op->attr_parser(&attrs);

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ static inline bool SupportStorageMKLDNN(int stype) {
125125

126126
static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
127127
int ndim = shape.ndim();
128+
if (ndim == 0 || shape.Size() == 0) {
129+
// MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
130+
return false;
131+
}
128132
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
129133
}
130134

tests/python/unittest/test_numpy_ndarray.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -642,13 +642,18 @@ def test_getitem(np_array, index):
642642
)
643643
np_indexed_array = np_array[np_index]
644644
mx_np_array = np.array(np_array, dtype=np_array.dtype)
645-
try:
646-
mx_indexed_array = mx_np_array[index]
647-
except Exception as e:
648-
print('Failed with index = {}'.format(index))
649-
raise e
650-
mx_indexed_array = mx_indexed_array.asnumpy()
651-
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
645+
for autograd in [True, False]:
646+
try:
647+
if autograd:
648+
with mx.autograd.record():
649+
mx_indexed_array = mx_np_array[index]
650+
else:
651+
mx_indexed_array = mx_np_array[index]
652+
except Exception as e:
653+
print('Failed with index = {}'.format(index))
654+
raise e
655+
mx_indexed_array = mx_indexed_array.asnumpy()
656+
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
652657

653658
def test_setitem(np_array, index):
654659
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
@@ -768,6 +773,7 @@ def test_setitem_autograd(np_array, index):
768773
np_int(slice(1, 5), np.int32),
769774
np_int(slice(1, 5), np.int64),
770775
slice(1, 5, 2),
776+
slice(1, 2, 2),
771777
np_int(slice(1, 5, 2), np.int32),
772778
np_int(slice(1, 5, 2), np.int64),
773779
slice(7, 0, -1),

0 commit comments

Comments
 (0)