
Commit 6aa8c27

ThomasDelteil authored and eric-haibin-lin committed
[MXNET-1327] Allow RNN Layers to be initialized to fp16 (#14219)
* update rnn for fp16
* fix typo in test
* fix tests
* fix tests
* fix gpu tests
* Update test_gluon_rnn.py
* Update test_gluon_rnn.py
* trigger
* try removing checks for unix
1 parent 66c74cc commit 6aa8c27

File tree: 2 files changed (+101, -56 lines)

python/mxnet/gluon/rnn/rnn_layer.py
tests/python/unittest/test_gluon_rnn.py

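The change threads a new dtype argument through gluon.rnn.RNN, LSTM, and GRU so that their parameters and default hidden states can be created in float16 instead of float32. As a minimal usage sketch (illustration only, not code from this commit; the GPU context follows the skip message added in the tests, which notes that RNN fp16 is only implemented for GPU for now):

import mxnet as mx
from mxnet import gluon

ctx = mx.gpu()  # per the new test, RNN fp16 is only implemented for GPU for now

# Parameters (and the default initial states) are created directly in float16.
layer = gluon.rnn.LSTM(10, num_layers=2, dtype='float16')
layer.collect_params().initialize(ctx=ctx)

x = mx.nd.ones((8, 3, 20), dtype='float16', ctx=ctx)  # TNC layout: (seq_len, batch, features)
out = layer(x)   # hidden states default to float16 as well
print(out.dtype)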

python/mxnet/gluon/rnn/rnn_layer.py (35 additions, 24 deletions)

@@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
                  i2h_bias_initializer, h2h_bias_initializer,
                  mode, projection_size, h2r_weight_initializer,
                  lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
-                 **kwargs):
+                 dtype, **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
@@ -57,6 +57,7 @@ def __init__(self, hidden_size, num_layers, layout,
         self._lstm_state_clip_min = lstm_state_clip_min
         self._lstm_state_clip_max = lstm_state_clip_max
         self._lstm_state_clip_nan = lstm_state_clip_nan
+        self._dtype = dtype
 
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
@@ -66,41 +67,41 @@ def __init__(self, hidden_size, num_layers, layout,
                 for j in ['l', 'r'][:self._dir]:
                     self._register_param('{}{}_i2h_weight'.format(j, i),
                                          shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer)
+                                         init=i2h_weight_initializer, dtype=dtype)
                     self._register_param('{}{}_h2h_weight'.format(j, i),
                                          shape=(ng*nh, nh),
-                                         init=h2h_weight_initializer)
+                                         init=h2h_weight_initializer, dtype=dtype)
                     self._register_param('{}{}_i2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=i2h_bias_initializer)
+                                         init=i2h_bias_initializer, dtype=dtype)
                     self._register_param('{}{}_h2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=h2h_bias_initializer)
+                                         init=h2h_bias_initializer, dtype=dtype)
                 ni = nh * self._dir
         else:
             np = self._projection_size
             for i in range(num_layers):
                 for j in ['l', 'r'][:self._dir]:
                     self._register_param('{}{}_i2h_weight'.format(j, i),
                                          shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer)
+                                         init=i2h_weight_initializer, dtype=dtype)
                     self._register_param('{}{}_h2h_weight'.format(j, i),
                                          shape=(ng*nh, np),
-                                         init=h2h_weight_initializer)
+                                         init=h2h_weight_initializer, dtype=dtype)
                     self._register_param('{}{}_i2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=i2h_bias_initializer)
+                                         init=i2h_bias_initializer, dtype=dtype)
                     self._register_param('{}{}_h2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=h2h_bias_initializer)
+                                         init=h2h_bias_initializer, dtype=dtype)
                     self._register_param('{}{}_h2r_weight'.format(j, i),
                                          shape=(np, nh),
-                                         init=h2r_weight_initializer)
+                                         init=h2r_weight_initializer, dtype=dtype)
                 ni = np * self._dir
 
-    def _register_param(self, name, shape, init):
+    def _register_param(self, name, shape, init, dtype):
         p = self.params.get(name, shape=shape, init=init,
-                            allow_deferred_init=True)
+                            allow_deferred_init=True, dtype=dtype)
         setattr(self, name, p)
         return p
 
@@ -179,6 +180,10 @@ def _unfuse(self):
 
         return stack
 
+    def cast(self, dtype):
+        super(_RNNLayer, self).cast(dtype)
+        self._dtype = dtype
+
     def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
         """Initial state for this cell.
 
@@ -317,6 +322,8 @@ class RNN(_RNNLayer):
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
+    dtype : str, default 'float32'
+        Type to initialize the parameters and default states to
     prefix : str or None
         Prefix of this `Block`.
     params : ParameterDict or None
@@ -357,17 +364,17 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
                  layout='TNC', dropout=0, bidirectional=False,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 input_size=0, **kwargs):
+                 input_size=0, dtype='float32', **kwargs):
         super(RNN, self).__init__(hidden_size, num_layers, layout,
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
                                   'rnn_'+activation, None, None, None, None, False,
-                                  **kwargs)
+                                  dtype, **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                 '__layout__': 'LNC'}]
+                 '__layout__': 'LNC', 'dtype': self._dtype}]
 
 
 class LSTM(_RNNLayer):
@@ -432,6 +439,8 @@ class LSTM(_RNNLayer):
     state_clip_nan : boolean, default False
         Whether to stop NaN from propagating in state by clipping it to min/max.
         If the clipping range is not specified, this option is ignored.
+    dtype : str, default 'float32'
+        Type to initialize the parameters and default states to
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -477,26 +486,26 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  projection_size=None, h2r_weight_initializer=None,
                  state_clip_min=None, state_clip_max=None, state_clip_nan=False,
-                 **kwargs):
+                 dtype='float32', **kwargs):
         super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                    dropout, bidirectional, input_size,
                                    i2h_weight_initializer, h2h_weight_initializer,
                                    i2h_bias_initializer, h2h_bias_initializer,
                                    'lstm', projection_size, h2r_weight_initializer,
                                    state_clip_min, state_clip_max, state_clip_nan,
-                                   **kwargs)
+                                   dtype, **kwargs)
 
     def state_info(self, batch_size=0):
         if self._projection_size is None:
             return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                     '__layout__': 'LNC'},
+                     '__layout__': 'LNC', 'dtype': self._dtype},
                     {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                     '__layout__': 'LNC'}]
+                     '__layout__': 'LNC', 'dtype': self._dtype}]
         else:
             return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
-                     '__layout__': 'LNC'},
+                     '__layout__': 'LNC', 'dtype': self._dtype},
                     {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                     '__layout__': 'LNC'}]
+                     '__layout__': 'LNC', 'dtype': self._dtype}]
 
 
 class GRU(_RNNLayer):
@@ -544,6 +553,8 @@ class GRU(_RNNLayer):
         Initializer for the bias vector.
     h2h_bias_initializer : str or Initializer
         Initializer for the bias vector.
+    dtype : str, default 'float32'
+        Type to initialize the parameters and default states to
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -586,14 +597,14 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                  dropout=0, bidirectional=False, input_size=0,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 **kwargs):
+                 dtype='float32', **kwargs):
         super(GRU, self).__init__(hidden_size, num_layers, layout,
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
                                   'gru', None, None, None, None, False,
-                                  **kwargs)
+                                  dtype, **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                 '__layout__': 'LNC'}]
+                 '__layout__': 'LNC', 'dtype': self._dtype}]
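The cast() override above keeps self._dtype in sync when a whole network is converted after construction, so state_info() reports the matching dtype for the initial states. A small sketch mirroring the Sequential network used in the updated tests (illustration only, not code from this commit; it assumes a GPU is available for the fp16 path):

import mxnet as mx
from mxnet import gluon

ctx = mx.gpu()  # the fp16 RNN path currently needs a GPU

# Build in the default float32, then cast the whole block to float16 afterwards.
net = gluon.nn.Sequential()
net.add(gluon.rnn.LSTM(10, bidirectional=True))
net.add(gluon.nn.BatchNorm(axis=2))
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(3, activation='relu'))
net.collect_params().initialize(ctx=ctx)
net.cast('float16')  # _RNNLayer.cast also updates self._dtype, so the default states match

with mx.autograd.record():
    out = net(mx.nd.ones((2, 3, 10), dtype='float16', ctx=ctx))
    out.backward()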

tests/python/unittest/test_gluon_rnn.py (66 additions, 32 deletions)

@@ -427,9 +427,15 @@ def hybrid_forward(self, F, seq):
     assert_almost_equal(output1.asnumpy(), output2.asnumpy())
 
 
-def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
-    layer.collect_params().initialize()
+def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):
+    layer.collect_params().initialize(ctx=ctx)
+    inputs = inputs.as_in_context(ctx)
     inputs.attach_grad()
+    if states is not None:
+        if isinstance(states, (list, tuple)):
+            states = [s.as_in_context(ctx) for s in states]
+        else:
+            states = states.as_in_context(ctx)
     with mx.autograd.record():
         if states is None:
             out = layer(inputs)
@@ -467,47 +473,76 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
 
-@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
-def test_rnn_layers():
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
-
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5),
-                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5),
-                            mx.nd.ones((8, 3, 20)),
-                            [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
-                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
+
+def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
+
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype),
+                            [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
 
     net = gluon.nn.Sequential()
-    net.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
-    net.collect_params().initialize()
+    net.collect_params().initialize(ctx=ctx)
+    net.cast(dtype)
     with mx.autograd.record():
-        net(mx.nd.ones((2, 3, 10))).backward()
+        out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
 
     net2 = gluon.nn.HybridSequential()
-    net2.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
     net2.add(gluon.nn.BatchNorm(axis=2))
     net2.add(gluon.nn.Flatten())
     net2.add(gluon.nn.Dense(3, activation='relu'))
     net2.hybridize()
-    net2.collect_params().initialize()
+    net2.collect_params().initialize(ctx=ctx)
+    net2.cast(dtype)
+    with mx.autograd.record():
+        out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
+
+    net3 = gluon.nn.HybridSequential()
+    net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype))
+    net3.add(gluon.nn.BatchNorm(axis=2))
+    net3.add(gluon.nn.Flatten())
+    net3.add(gluon.nn.Dense(3, activation='relu'))
+    net3.hybridize()
+    net3.collect_params().initialize(ctx=ctx)
+    net3.cast(dtype2)
     with mx.autograd.record():
-        net2(mx.nd.ones((2, 3, 10))).backward()
+        out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
+
+def test_rnn_layers_fp32():
+    run_rnn_layers('float32', 'float32')
+
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+@unittest.skipIf(mx.context.num_gpus() == 0, "RNN FP16 only implemented for GPU for now")
+def test_rnn_layers_fp16():
+    run_rnn_layers('float16', 'float32', mx.gpu())
 
 
 def test_rnn_unroll_variant_length():
@@ -590,8 +625,6 @@ def test_cell_fill_shape():
     check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
     assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
 
-
-@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     layer.hybridize()
@@ -603,6 +636,7 @@ def test_layer_fill_shape():
 def test_bidirectional_unroll_valid_length():
     # Test BidirectionalCell.
     # In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ).
+
     class BiLSTM(gluon.nn.HybridBlock):
         def __init__(self, rnn_size, time_step, **kwargs):
             super(BiLSTM, self).__init__(**kwargs)
