Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 71046be

Browse files
author
Rohit Kumar Srivastava
committed
[MXNET-1408] Adding test to verify Large Tensor Support for ravel and unravel
1 parent 294a34a commit 71046be

File tree

2 files changed

+37
-18
lines changed

2 files changed

+37
-18
lines changed

python/mxnet/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,12 @@ def rand_shape_nd(num_dim, dim=10):
419419
return tuple(rnd.randint(1, dim+1, size=num_dim))
420420

421421

422+
def rand_coord_2d(x_low, x_high, y_low, y_high):
423+
x = np.random.randint(x_low, x_high, dtype=np.int64)
424+
y = np.random.randint(y_low, y_high, dtype=np.int64)
425+
return x, y
426+
427+
422428
def np_reduce(dat, axis, keepdims, numpy_reduce_func):
423429
"""Compatible reduce for old version of NumPy.
424430

tests/nightly/test_large_array.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import mxnet as mx
1918
import numpy as np
20-
from mxnet.test_utils import rand_ndarray, assert_almost_equal
19+
import mxnet as mx
20+
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
2121
from mxnet import gluon, nd
2222
from tests.python.unittest.common import with_seed
2323

@@ -74,7 +74,7 @@ def test_ndarray_random_randint():
7474
# check if randint can generate value greater than 2**32 (large)
7575
low_large_value = 2**32
7676
high_large_value = 2**34
77-
a = nd.random.randint(low_large_value,high_large_value, dtype=np.int64)
77+
a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64)
7878
low = mx.nd.array([low_large_value], dtype='int64')
7979
high = mx.nd.array([high_large_value], dtype='int64')
8080
assert a.__gt__(low) and a.__lt__(high)
@@ -130,13 +130,6 @@ def test_clip():
130130
assert np.sum(res[-1].asnumpy() == 1000) == a.shape[1]
131131

132132

133-
def test_take():
134-
a = nd.ones(shape=(LARGE_X, SMALL_Y))
135-
idx = nd.arange(LARGE_X-1000, LARGE_X)
136-
res = nd.take(a, idx)
137-
assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
138-
139-
140133
def test_split():
141134
a = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y)
142135
outs = nd.split(a, num_outputs=SMALL_Y, axis=1)
@@ -215,9 +208,9 @@ def test_where():
215208

216209

217210
def test_pick():
218-
a = mx.nd.ones(shape=(256*35, 1024*1024))
219-
b = mx.nd.ones(shape=(256*35,))
220-
res = mx.nd.pick(a,b)
211+
a = mx.nd.ones(shape=(256 * 35, 1024 * 1024))
212+
b = mx.nd.ones(shape=(256 * 35, ))
213+
res = mx.nd.pick(a, b)
221214
assert res.shape == b.shape
222215

223216

@@ -252,11 +245,9 @@ def numpy_space_to_depth(x, blocksize):
252245
output = mx.nd.space_to_depth(data, 2)
253246
assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
254247

255-
248+
@with_seed()
256249
def test_diag():
257-
h = np.random.randint(2,9)
258-
w = np.random.randint(2,9)
259-
a_np = np.random.random((LARGE_X, 64)).astype(np.float32)
250+
a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
260251
a = mx.nd.array(a_np)
261252

262253
# k == 0
@@ -274,11 +265,33 @@ def test_diag():
274265
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
275266

276267
# random k
277-
k = np.random.randint(-min(LARGE_X, 64) + 1, min(h, w))
268+
k = np.random.randint(-min(LARGE_X, SMALL_Y) + 1, min(LARGE_X, SMALL_Y))
278269
r = mx.nd.diag(a, k=k)
279270
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
280271

281272

273+
@with_seed()
274+
def test_ravel_multi_index():
275+
x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, 10, SMALL_Y)
276+
x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
277+
x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
278+
indices_2d = [[x1, x2, x3], [y1, y2, y3]]
279+
idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
280+
idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
281+
assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3
282+
283+
284+
@with_seed()
285+
def test_unravel_index():
286+
x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, 10, SMALL_Y)
287+
x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
288+
x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
289+
original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
290+
idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
291+
indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
292+
assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()
293+
294+
282295
if __name__ == '__main__':
283296
import nose
284297
nose.runmodule()

0 commit comments

Comments
 (0)