@@ -302,7 +302,7 @@ def __getitem__(self, key):
302302 except Exception as err :
303303 raise TypeError ('{}' .format (str (err )))
304304 if isinstance (key , _np .ndarray ) and key .dtype == _np .bool_ :
305- key = array (key , dtype = 'bool' )
305+ key = array (key , dtype = 'bool' , ctx = self . ctx )
306306 if isinstance (key , ndarray ) and key .dtype == _np .bool_ : # boolean indexing
307307 key_shape = key .shape
308308 key_ndim = len (key_shape )
@@ -364,6 +364,8 @@ def __setitem__(self, key, value):
364364 """
365365 if isinstance (value , NDArray ) and not isinstance (value , ndarray ):
366366 raise TypeError ('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray' )
367+
368+ # handle basic and advanced indexing
367369 if self .ndim == 0 :
368370 if not isinstance (key , tuple ) or len (key ) != 0 :
369371 raise IndexError ('scalar tensor can only accept `()` as index' )
@@ -753,7 +755,7 @@ def detach(self):
753755 check_call (_LIB .MXNDArrayDetach (self .handle , ctypes .byref (hdl )))
754756 return _np_ndarray_cls (hdl )
755757
756- def astype (self , dtype , * args , * *kwargs ): # pylint: disable=arguments-differ,unused-argument
758+ def astype (self , dtype , ** kwargs ): # pylint: disable=arguments-differ,unused-argument
757759 """
758760 Copy of the array, cast to a specified type.
759761
@@ -1237,7 +1239,14 @@ def tile(self, *args, **kwargs):
12371239
12381240 def transpose (self , * axes ): # pylint: disable=arguments-differ
12391241 """Permute the dimensions of an array."""
1240- return _mx_np_op .transpose (self , axes = axes if len (axes ) != 0 else None )
1242+ if len (axes ) == 0 :
1243+ axes = None
1244+ elif len (axes ) == 1 :
1245+ if isinstance (axes [0 ], (tuple , list )):
1246+ axes = axes [0 ]
1247+ elif axes [0 ] is None :
1248+ axes = None
1249+ return _mx_np_op .transpose (self , axes = axes )
12411250
12421251 def flip (self , * args , ** kwargs ):
12431252 """Convenience fluent method for :py:func:`flip`.
@@ -3401,11 +3410,11 @@ def logical_not(x, out=None, **kwargs):
34013410 --------
34023411 >>> x= np.array([True, False, 0, 1])
34033412 >>> np.logical_not(x)
3404- array([0., 1., 1., 0. ])
3413+ array([False, True, True, False ])
34053414
34063415 >>> x = np.arange(5)
34073416 >>> np.logical_not(x<3)
3408- array([0., 0., 0., 1., 1. ])
3417+ array([False, False, False, True, True ])
34093418 """
34103419 return _mx_nd_np .logical_not (x , out = out , ** kwargs )
34113420
0 commit comments