1515# specific language governing permissions and limitations
1616# under the License.
1717
18- import mxnet as mx
1918import numpy as np
19+ import mxnet as mx
2020from mxnet.test_utils import rand_ndarray, assert_almost_equal
2121from mxnet import gluon, nd
2222from tests.python.unittest.common import with_seed
@@ -74,7 +74,7 @@ def test_ndarray_random_randint():
7474 # check if randint can generate value greater than 2**32 (large)
7575 low_large_value = 2**32
7676 high_large_value = 2**34
77- a = nd.random.randint(low_large_value,high_large_value, dtype=np.int64)
77+ a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64)
7878 low = mx.nd.array([low_large_value], dtype='int64')
7979 high = mx.nd.array([high_large_value], dtype='int64')
8080 assert a.__gt__(low) and a.__lt__(high)
@@ -130,13 +130,6 @@ def test_clip():
130130 assert np.sum(res[-1].asnumpy() == 1000) == a.shape[1]
131131
132132
133- def test_take():
134- a = nd.ones(shape=(LARGE_X, SMALL_Y))
135- idx = nd.arange(LARGE_X-1000, LARGE_X)
136- res = nd.take(a, idx)
137- assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
138-
139-
140133def test_split():
141134 a = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y)
142135 outs = nd.split(a, num_outputs=SMALL_Y, axis=1)
@@ -215,9 +208,9 @@ def test_where():
215208
216209
217210def test_pick():
218- a = mx.nd.ones(shape=(256* 35, 1024* 1024))
219- b = mx.nd.ones(shape=(256* 35,))
220- res = mx.nd.pick(a,b)
211+ a = mx.nd.ones(shape=(256 * 35, 1024 * 1024))
212+ b = mx.nd.ones(shape=(256 * 35, ))
213+ res = mx.nd.pick(a, b)
221214 assert res.shape == b.shape
222215
223216
@@ -252,11 +245,9 @@ def numpy_space_to_depth(x, blocksize):
252245 output = mx.nd.space_to_depth(data, 2)
253246 assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
254247
255-
248+ @with_seed()
256249def test_diag():
257- h = np.random.randint(2,9)
258- w = np.random.randint(2,9)
259- a_np = np.random.random((LARGE_X, 64)).astype(np.float32)
250+ a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
260251 a = mx.nd.array(a_np)
261252
262253 # k == 0
@@ -274,11 +265,40 @@ def test_diag():
274265 assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
275266
276267 # random k
277- k = np.random.randint(-min(LARGE_X, 64 ) + 1, min(h, w ))
268+ k = np.random.randint(-min(LARGE_X, SMALL_Y ) + 1, min(LARGE_X, SMALL_Y ))
278269 r = mx.nd.diag(a, k=k)
279270 assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
280271
281272
273+ def random_2d_coordinates(x_low, x_high, y_low, y_high):
274+ x = np.random.randint(x_low, x_high, dtype=np.int64)
275+ y = np.random.randint(y_low, y_high, dtype=np.int64)
276+ return x, y
277+
278+
279+ @with_seed()
280+ def test_ravel_multi_index():
281+ x1, y1 = random_2d_coordinates((LARGE_X - 100), LARGE_X, 10, SMALL_Y)
282+ x2, y2 = random_2d_coordinates((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
283+ x3, y3 = random_2d_coordinates((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
284+ indices_2d = [[x1, x2, x3], [y1, y2, y3]]
285+ idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
286+ idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
287+ assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3
288+
289+
290+ @with_seed()
291+ def test_unravel_index():
292+ x1, y1 = random_2d_coordinates((LARGE_X - 100), LARGE_X, 10, SMALL_Y)
293+ x2, y2 = random_2d_coordinates((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
294+ x3, y3 = random_2d_coordinates((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
295+ original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
296+ #idx = mx.np.ravel_multi_index(mx.nd.array(original_2d_indices, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
297+ idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
298+ indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
299+ assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()
300+
301+
282302if __name__ == '__main__':
283303 import nose
284304 nose.runmodule()
0 commit comments