1515# specific language governing permissions and limitations
1616# under the License.
1717
18- import mxnet as mx
1918import numpy as np
20- from mxnet .test_utils import rand_ndarray , assert_almost_equal
19+ import mxnet as mx
20+ from mxnet .test_utils import rand_ndarray , assert_almost_equal , rand_coord_2d
2121from mxnet import gluon , nd
2222from tests .python .unittest .common import with_seed
2323
@@ -74,7 +74,7 @@ def test_ndarray_random_randint():
7474 # check if randint can generate value greater than 2**32 (large)
7575 low_large_value = 2 ** 32
7676 high_large_value = 2 ** 34
77- a = nd .random .randint (low_large_value ,high_large_value , dtype = np .int64 )
77+ a = nd .random .randint (low_large_value , high_large_value , dtype = np .int64 )
7878 low = mx .nd .array ([low_large_value ], dtype = 'int64' )
7979 high = mx .nd .array ([high_large_value ], dtype = 'int64' )
8080 assert a .__gt__ (low ) and a .__lt__ (high )
@@ -130,13 +130,6 @@ def test_clip():
130130 assert np .sum (res [- 1 ].asnumpy () == 1000 ) == a .shape [1 ]
131131
132132
133- def test_take ():
134- a = nd .ones (shape = (LARGE_X , SMALL_Y ))
135- idx = nd .arange (LARGE_X - 1000 , LARGE_X )
136- res = nd .take (a , idx )
137- assert np .sum (res [- 1 ].asnumpy () == 1 ) == res .shape [1 ]
138-
139-
140133def test_split ():
141134 a = nd .arange (0 , LARGE_X * SMALL_Y ).reshape (LARGE_X , SMALL_Y )
142135 outs = nd .split (a , num_outputs = SMALL_Y , axis = 1 )
@@ -215,9 +208,9 @@ def test_where():
215208
216209
217210def test_pick ():
218- a = mx .nd .ones (shape = (256 * 35 , 1024 * 1024 ))
219- b = mx .nd .ones (shape = (256 * 35 ,))
220- res = mx .nd .pick (a ,b )
211+ a = mx .nd .ones (shape = (256 * 35 , 1024 * 1024 ))
212+ b = mx .nd .ones (shape = (256 * 35 , ))
213+ res = mx .nd .pick (a , b )
221214 assert res .shape == b .shape
222215
223216
@@ -252,11 +245,9 @@ def numpy_space_to_depth(x, blocksize):
252245 output = mx .nd .space_to_depth (data , 2 )
253246 assert_almost_equal (output .asnumpy (), expected , atol = 1e-3 , rtol = 1e-3 )
254247
255-
248+ @ with_seed ()
256249def test_diag ():
257- h = np .random .randint (2 ,9 )
258- w = np .random .randint (2 ,9 )
259- a_np = np .random .random ((LARGE_X , 64 )).astype (np .float32 )
250+ a_np = np .random .random ((LARGE_X , SMALL_Y )).astype (np .float32 )
260251 a = mx .nd .array (a_np )
261252
262253 # k == 0
@@ -274,11 +265,33 @@ def test_diag():
274265 assert_almost_equal (r .asnumpy (), np .diag (a_np , k = k ))
275266
276267 # random k
277- k = np .random .randint (- min (LARGE_X , 64 ) + 1 , min (h , w ))
268+ k = np .random .randint (- min (LARGE_X , SMALL_Y ) + 1 , min (LARGE_X , SMALL_Y ))
278269 r = mx .nd .diag (a , k = k )
279270 assert_almost_equal (r .asnumpy (), np .diag (a_np , k = k ))
280271
281272
273+ @with_seed ()
274+ def test_ravel_multi_index ():
275+ x1 , y1 = rand_coord_2d ((LARGE_X - 100 ), LARGE_X , 10 , SMALL_Y )
276+ x2 , y2 = rand_coord_2d ((LARGE_X - 200 ), LARGE_X , 9 , SMALL_Y )
277+ x3 , y3 = rand_coord_2d ((LARGE_X - 300 ), LARGE_X , 8 , SMALL_Y )
278+ indices_2d = [[x1 , x2 , x3 ], [y1 , y2 , y3 ]]
279+ idx = mx .nd .ravel_multi_index (mx .nd .array (indices_2d , dtype = np .int64 ), shape = (LARGE_X , SMALL_Y ))
280+ idx_numpy = np .ravel_multi_index (indices_2d , (LARGE_X , SMALL_Y ))
281+ assert np .sum (1 for i in range (idx .size ) if idx [i ] == idx_numpy [i ]) == 3
282+
283+
284+ @with_seed ()
285+ def test_unravel_index ():
286+ x1 , y1 = rand_coord_2d ((LARGE_X - 100 ), LARGE_X , 10 , SMALL_Y )
287+ x2 , y2 = rand_coord_2d ((LARGE_X - 200 ), LARGE_X , 9 , SMALL_Y )
288+ x3 , y3 = rand_coord_2d ((LARGE_X - 300 ), LARGE_X , 8 , SMALL_Y )
289+ original_2d_indices = [[x1 , x2 , x3 ], [y1 , y2 , y3 ]]
290+ idx_numpy = np .ravel_multi_index (original_2d_indices , (LARGE_X , SMALL_Y ))
291+ indices_2d = mx .nd .unravel_index (mx .nd .array (idx_numpy , dtype = np .int64 ), shape = (LARGE_X , SMALL_Y ))
292+ assert (indices_2d .asnumpy () == np .array (original_2d_indices )).all ()
293+
294+
282295if __name__ == '__main__' :
283296 import nose
284297 nose .runmodule ()
0 commit comments