1515# specific language governing permissions and limitations
1616# under the License.
1717
18- import mxnet as mx
1918import numpy as np
19+ import mxnet as mx
2020from mxnet .test_utils import rand_ndarray , assert_almost_equal
2121from mxnet import gluon , nd
2222from tests .python .unittest .common import with_seed
@@ -74,7 +74,7 @@ def test_ndarray_random_randint():
7474 # check if randint can generate value greater than 2**32 (large)
7575 low_large_value = 2 ** 32
7676 high_large_value = 2 ** 34
77- a = nd .random .randint (low_large_value ,high_large_value , dtype = np .int64 )
77+ a = nd .random .randint (low_large_value , high_large_value , dtype = np .int64 )
7878 low = mx .nd .array ([low_large_value ], dtype = 'int64' )
7979 high = mx .nd .array ([high_large_value ], dtype = 'int64' )
8080 assert a .__gt__ (low ) and a .__lt__ (high )
@@ -130,13 +130,6 @@ def test_clip():
130130 assert np .sum (res [- 1 ].asnumpy () == 1000 ) == a .shape [1 ]
131131
132132
133- def test_take ():
134- a = nd .ones (shape = (LARGE_X , SMALL_Y ))
135- idx = nd .arange (LARGE_X - 1000 , LARGE_X )
136- res = nd .take (a , idx )
137- assert np .sum (res [- 1 ].asnumpy () == 1 ) == res .shape [1 ]
138-
139-
140133def test_split ():
141134 a = nd .arange (0 , LARGE_X * SMALL_Y ).reshape (LARGE_X , SMALL_Y )
142135 outs = nd .split (a , num_outputs = SMALL_Y , axis = 1 )
@@ -216,8 +209,8 @@ def test_where():
216209
217210def test_pick ():
218211 a = mx .nd .ones (shape = (256 * 35 , 1024 * 1024 ))
219- b = mx .nd .ones (shape = (256 * 35 ,))
220- res = mx .nd .pick (a ,b )
212+ b = mx .nd .ones (shape = (256 * 35 , ))
213+ res = mx .nd .pick (a , b )
221214 assert res .shape == b .shape
222215
223216
@@ -252,11 +245,9 @@ def numpy_space_to_depth(x, blocksize):
252245 output = mx .nd .space_to_depth (data , 2 )
253246 assert_almost_equal (output .asnumpy (), expected , atol = 1e-3 , rtol = 1e-3 )
254247
255-
248+ @ with_seed ()
256249def test_diag ():
257- h = np .random .randint (2 ,9 )
258- w = np .random .randint (2 ,9 )
259- a_np = np .random .random ((LARGE_X , 64 )).astype (np .float32 )
250+ a_np = np .random .random ((LARGE_X , SMALL_Y )).astype (np .float32 )
260251 a = mx .nd .array (a_np )
261252
262253 # k == 0
@@ -274,11 +265,26 @@ def test_diag():
274265 assert_almost_equal (r .asnumpy (), np .diag (a_np , k = k ))
275266
276267 # random k
277- k = np .random .randint (- min (LARGE_X , 64 ) + 1 , min (h , w ))
268+ k = np .random .randint (- min (LARGE_X , SMALL_Y ) + 1 , min (LARGE_X , SMALL_Y ))
278269 r = mx .nd .diag (a , k = k )
279270 assert_almost_equal (r .asnumpy (), np .diag (a_np , k = k ))
280271
281272
273+ def test_ravel_multi_index ():
274+ indices_2d = [[LARGE_X - 1 , LARGE_X - 100 , 6 ], [SMALL_Y - 1 , SMALL_Y - 10 , 1 ]]
275+ idx = mx .nd .ravel_multi_index (mx .nd .array (indices_2d , dtype = np .int64 ), shape = (LARGE_X , SMALL_Y ))
276+ idx_numpy = np .ravel_multi_index (indices_2d , (LARGE_X , SMALL_Y ))
277+ assert np .sum (1 for i in range (idx .size ) if idx [i ] == idx_numpy [i ]) == 3
278+
279+
280+ def test_unravel_index ():
281+ original_2d_indices = [[LARGE_X - 1 , LARGE_X - 100 , 6 ], [SMALL_Y - 1 , SMALL_Y - 10 , 1 ]]
282+ idx = mx .nd .ravel_multi_index (mx .nd .array (original_2d_indices , dtype = np .int64 ), shape = (LARGE_X , SMALL_Y ))
283+ idx_numpy = np .ravel_multi_index (original_2d_indices , (LARGE_X , SMALL_Y ))
284+ indices_2d = mx .nd .unravel_index (mx .nd .array (idx , dtype = np .int64 ), shape = (LARGE_X , SMALL_Y ))
285+ assert (indices_2d .asnumpy () == np .array (original_2d_indices )).all ()
286+
287+
282288if __name__ == '__main__' :
283289 import nose
284290 nose .runmodule ()
0 commit comments