@@ -427,9 +427,15 @@ def hybrid_forward(self, F, seq):
         assert_almost_equal(output1.asnumpy(), output2.asnumpy())


-def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
-    layer.collect_params().initialize()
+def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):
+    layer.collect_params().initialize(ctx=ctx)
+    inputs = inputs.as_in_context(ctx)
     inputs.attach_grad()
+    if states is not None:
+        if isinstance(states, (list, tuple)):
+            states = [s.as_in_context(ctx) for s in states]
+        else:
+            states = states.as_in_context(ctx)
     with mx.autograd.record():
         if states is None:
             out = layer(inputs)
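
A quick note on the new states handling: a single begin-state NDArray (RNN/GRU) and a list of two (LSTM) are both moved onto the target context, so hypothetical calls such as the ones below exercise both branches; mx.gpu(0) and the state shapes are assumptions for illustration, not part of the patch:

    # GRU carries a single state tensor (the else branch above), LSTM an [h, c] pair (the list branch).
    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)), ctx=mx.gpu(0))
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)),
                            [mx.nd.ones((2, 3, 10)), mx.nd.ones((2, 3, 10))], ctx=mx.gpu(0))
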
@@ -467,47 +473,76 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)


-@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
-def test_rnn_layers():
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
-
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5),
-                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5),
-                            mx.nd.ones((8, 3, 20)),
-                            [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
-                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
+
+def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
+
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype),
+                            [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)

     net = gluon.nn.Sequential()
-    net.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
-    net.collect_params().initialize()
+    net.collect_params().initialize(ctx=ctx)
+    net.cast(dtype)
     with mx.autograd.record():
-        net(mx.nd.ones((2, 3, 10))).backward()
+        out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()

     net2 = gluon.nn.HybridSequential()
-    net2.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
     net2.add(gluon.nn.BatchNorm(axis=2))
     net2.add(gluon.nn.Flatten())
     net2.add(gluon.nn.Dense(3, activation='relu'))
     net2.hybridize()
-    net2.collect_params().initialize()
+    net2.collect_params().initialize(ctx=ctx)
+    net2.cast(dtype)
+    with mx.autograd.record():
+        out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
+
+    net3 = gluon.nn.HybridSequential()
+    net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype))
+    net3.add(gluon.nn.BatchNorm(axis=2))
+    net3.add(gluon.nn.Flatten())
+    net3.add(gluon.nn.Dense(3, activation='relu'))
+    net3.hybridize()
+    net3.collect_params().initialize(ctx=ctx)
+    net3.cast(dtype2)
     with mx.autograd.record():
-        net2(mx.nd.ones((2, 3, 10))).backward()
+        out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
+
+def test_rnn_layers_fp32():
+    run_rnn_layers('float32', 'float32')
+
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+@unittest.skipIf(mx.context.num_gpus() == 0, "RNN FP16 only implemented for GPU for now")
+def test_rnn_layers_fp16():
+    run_rnn_layers('float16', 'float32', mx.gpu())
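
The mixed-precision coverage above leans on Block.cast: parameters are initialized in one dtype and then cast in place, so a network built with float32 LSTM weights can run a float16 forward/backward pass on GPU (per the skip condition, RNN fp16 is GPU-only for now). A standalone sketch of that pattern; the context fallback, conditional dtype, and layer sizes are illustrative assumptions, not part of the patch:

    import mxnet as mx
    from mxnet import gluon

    # Pick a device; RNN float16 needs a GPU (cuDNN), so fall back to float32 on CPU.
    ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
    dtype = 'float16' if ctx.device_type == 'gpu' else 'float32'

    net = gluon.nn.HybridSequential()
    net.add(gluon.rnn.LSTM(10, bidirectional=True))   # parameters are created as float32 by default
    net.add(gluon.nn.Dense(3, activation='relu'))
    net.hybridize()
    net.collect_params().initialize(ctx=ctx)
    net.cast(dtype)                                    # casts the initialized parameters in place

    with mx.autograd.record():
        out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
    out.backward()
    print(out.dtype, out.shape)
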


 def test_rnn_unroll_variant_length():
@@ -590,8 +625,6 @@ def test_cell_fill_shape():
     check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
     assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]

-
-@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     layer.hybridize()
@@ -603,6 +636,7 @@ def test_layer_fill_shape():
 def test_bidirectional_unroll_valid_length():
     # Test BidirectionalCell.
     # In the 1.3.1 release, after hybridize(), BidirectionalCell would fail when valid_length was passed to unroll().
+
     class BiLSTM(gluon.nn.HybridBlock):
         def __init__(self, rnn_size, time_step, **kwargs):
             super(BiLSTM, self).__init__(**kwargs)