@@ -847,26 +847,32 @@ def _basic_indexing_slice_is_contiguous(slc_key, shape):
847847 """Whether indexing with the given key results in a contiguous array.
848848
849849 The rule is: From right to left, if in an axis, a slice produces a
850- proper subset, no later axis can produce a proper subset or use
851- a step different from 1.
850+ proper subset, the later slice must have <=1 elements.
852851
853852 The ``slc_key`` sequence must have the same length as ``shape`` and
854853 only contain `slice` objects.
855854 """
856855 assert len (slc_key ) == len (shape )
857- subset = False
856+ is_subset = False
857+ total_sliced_elements = np .prod ([_get_slice_len (slc , n )
858+ for slc , n in zip (slc_key , shape )])
859+ if total_sliced_elements in (0 , 1 ):
860+ return True
858861 for idx , n in zip (reversed (slc_key ), reversed (shape )):
859- start , stop , step = idx .indices (n )
860- if step > 0 :
861- num = int ( np . ceil ( max ( stop - start , 0 ) / step ))
862- else :
863- num = int ( np . ceil ( min ( stop - start , 0 ) / step ))
864-
865- if num != 1 and ( subset or step != 1 ):
862+ _ , _ , step = idx .indices (n )
863+ num_elements = _get_slice_len ( idx , n )
864+ if num_elements == 0 :
865+ return True
866+ elif num_elements > 1 and ( step > 1 or step < 0 ):
867+ # We do not support the case of reverse slicing of multiple elements and
868+ # forward slicing of #elements > 1 and step > 1
866869 return False
867- if num != n :
868- subset = True
869-
870+ elif is_subset :
871+ if num_elements > 1 :
872+ return False
873+ else :
874+ if num_elements < n :
875+ is_subset = True
870876 return True
871877 # pylint: enable=invalid-name
872878
@@ -875,30 +881,27 @@ def _basic_indexing_sliced_shape(slc_key, shape):
875881 """Return the shape after slicing with the given key."""
876882 assert len (slc_key ) == len (shape )
877883 sliced_shape = []
878- for idx , n in zip (slc_key , shape ):
879- start , stop , step = idx .indices (n )
880- if step > 0 :
881- num = int (np .ceil (max (stop - start , 0 ) / step ))
882- else :
883- num = int (np .ceil (min (stop - start , 0 ) / step ))
884- sliced_shape .append (num )
885-
884+ for slc , n in zip (slc_key , shape ):
885+ num_elements = _get_slice_len (slc , n )
886+ sliced_shape .append (num_elements )
886887 return tuple (sliced_shape )
887888
888889 # pylint: disable=invalid-name
889890 @staticmethod
890891 def _basic_indexing_contiguous_flat_begin_end (slc_key , shape ):
891892 """Return the flat indices of begin and end for contiguous slicing."""
892893 assert len (slc_key ) == len (shape )
893- begin , end , _ = slc_key [0 ].indices (shape [0 ])
894- flat_begin , flat_end = begin , end - 1
895- for idx , n in zip (slc_key [1 :], shape [1 :]):
894+ flat_begin , flat_end = 0 , 0
895+ for slc , n in zip (slc_key , shape ):
896896 flat_begin *= n
897897 flat_end *= n
898- begin , end , _ = idx .indices (n )
899- flat_begin += begin
900- flat_end += end - 1
901-
898+ begin , _ , _ = slc .indices (n )
899+ num_elements = _get_slice_len (slc , n )
900+ if num_elements == 0 :
901+ return 0 , 0
902+ else :
903+ flat_begin += begin
904+ flat_end += begin + num_elements - 1
902905 return flat_begin , flat_end + 1
903906 # pylint: enable=invalid-name
904907
@@ -1062,7 +1065,7 @@ def _get_nd_basic_indexing(self, key):
10621065 for ax in new_axes : # pylint: disable=invalid-name
10631066 final_shape .insert (ax , 1 )
10641067
1065- if final_shape == [] :
1068+ if len ( final_shape ) == 0 :
10661069 # Override for single element indexing
10671070 final_shape = [1 ]
10681071 return sliced .reshape (final_shape )
@@ -3125,6 +3128,26 @@ def _get_dim_size(start, stop, step):
31253128 return dim_size
31263129
31273130
3131+ def _get_slice_len (slc , seq_length ):
3132+ """Given a python slice object and the length of the sequence, calculate the number of elements
3133+ in the slice.
3134+
3135+ Parameters
3136+ ----------
3137+ slc : py_slice
3138+ The slice object
3139+ seq_length : int
3140+ The length of the object you are going to apply the slice on
3141+
3142+ Returns
3143+ -------
3144+ ret : int
3145+ Total number of elements in the slice
3146+ """
3147+ start , stop , step = slc .indices (seq_length )
3148+ return max (0 , (stop - start + (step - (1 if step > 0 else - 1 ))) // step )
3149+
3150+
31283151def _get_broadcast_shape (shape1 , shape2 ):
31293152 """Given two shapes that are not identical, find the shape
31303153 that both input shapes can broadcast to."""
0 commit comments