@@ -1526,17 +1526,32 @@ def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
15261526 assert_allclose (arr [name ].asnumpy ()* root_scale ** 2 * scale ** (2 * k ), arr_grad [name ].asnumpy (), rtol = 1e-4 )
15271527
15281528
1529- def check_bilinear_upsampling_with_shape (shapes , scale , root_scale ):
1530- arr = {'arg_%d' % i : mx .random .uniform (- 10.0 , 10.0 , shape , ctx = mx .cpu ()).copyto (default_context ()) for i , shape in zip (range (len (shapes )), shapes )}
1531- arr_grad = {'arg_%d' % i : mx .nd .zeros (shape ) for i , shape in zip (range (len (shapes )), shapes )}
1532-
1533- up = mx .sym .UpSampling (* [mx .sym .Variable ('arg_%d' % i ) for i in range (len (shapes ))], sample_type = 'bilinear' , scale = root_scale )
1529+ def check_bilinear_upsampling_with_shape (data_shape , weight_shape , scale , root_scale , num_filter ):
1530+ def _init_bilinear (arr , f ):
1531+ weight = np .zeros (np .prod (arr .shape ), dtype = 'float32' )
1532+ shape = arr .shape
1533+ c = (2 * f - 1 - f % 2 ) / (2. * f )
1534+ for i in range (np .prod (shape )):
1535+ x = i % shape [3 ]
1536+ y = (i // shape [3 ]) % shape [2 ]
1537+ weight [i ] = (1 - abs (x / f - c )) * (1 - abs (y / f - c ))
1538+ arr [:] = weight .reshape (shape )
1539+ return arr
1540+
1541+ up = mx .sym .UpSampling (mx .sym .Variable ("data" ),
1542+ mx .sym .Variable ('weight' ), sample_type = 'bilinear' , scale = root_scale ,
1543+ num_filter = num_filter , num_args = 2 )
1544+ arg_shapes , out_shapes , _ = up .infer_shape (data = data_shape )
1545+ arr = {'data' : mx .random .uniform (- 5 , 5 , data_shape , ctx = mx .cpu ()).copyto (default_context ()),
1546+ 'weight' : mx .nd .array (_init_bilinear (mx .ndarray .empty (arg_shapes [1 ]).asnumpy (), root_scale ))}
1547+
1548+ arr_grad = [mx .nd .empty (s ) for s in arg_shapes ]
15341549 exe = up .bind (default_context (), args = arr , args_grad = arr_grad )
15351550 exe .forward (is_train = True )
1551+ out = exe .outputs [0 ].asnumpy ()
15361552 exe .backward (exe .outputs )
1537- for k in range (len (shapes )):
1538- name = 'arg_%d' % k
1539- assert_allclose (arr [name ].asnumpy ()* root_scale ** 2 * scale ** (2 * k ), arr_grad [name ].asnumpy (), rtol = 1e-4 )
1553+ target_shape = (data_shape [2 ] * root_scale , data_shape [3 ] * root_scale )
1554+ assert out .shape == data_shape [:2 ] + target_shape
15401555
15411556
15421557@with_seed ()
@@ -1549,6 +1564,22 @@ def test_nearest_upsampling():
15491564 check_nearest_upsampling_with_shape (shapes , scale , root_scale )
15501565
15511566
1567+ @with_seed ()
1568+ def test_bilinear_upsampling ():
1569+ rootscale = [2 ,3 ]
1570+ scales = [1 ,2 ,3 ]
1571+ filters = [1 ,2 ,3 ]
1572+ bases = [1 ,2 ,3 ]
1573+ for params in itertools .product (rootscale , scales , filters , bases ):
1574+ root_scale , scale , num_filter , base = params
1575+ # bilinear upsampling takes only 1 data and 1 weight
1576+ # multi input mode is not applicable
1577+ dimension = base * root_scale * scale
1578+ kernel = 2 * root_scale - root_scale % 2
1579+ data_shape = (1 , num_filter , dimension , dimension )
1580+ weight_shape = (1 , num_filter , kernel , kernel )
1581+ check_bilinear_upsampling_with_shape (data_shape , weight_shape , scale , root_scale , num_filter )
1582+
15521583@with_seed ()
15531584def test_batchnorm_training ():
15541585 def check_batchnorm_training (stype ):
0 commit comments