This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Making transformer encoder fully hybridized for export #789

@eric-haibin-lin

Currently, BERT can only be exported with static-length support. BERTEncoder inherits from the Transformer encoder, which contains a few .shape API calls. These calls prevent the transformer encoder from being fully hybridized and cause issues during export. As a result, the exported model only supports a static sequence length. The calls are located here:
https://github.com/dmlc/gluon-nlp/blob/master/src/gluonnlp/model/transformer.py#L450-L463

We need to remove these calls to export a model that supports variable length.

handling arange

In particular, we have

length = inputs.shape[1]
arange = mx.nd.arange(length, ctx=valid_length.context, dtype=valid_length.dtype)

To remove these .shape calls, we have 2 options:

contrib.arange_like op with ndarray input

Instead of mx.nd.arange(inputs.shape[1], ...), we can introduce an arange_like op:

  • input: arr with shape (x,) and arbitrary data
  • output: an array with the same shape as arr and values [0, 1, 2, ..., size(arr) - 1].

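With this op, we just need to slice the inputs along axis 1 and pass the result to the arange_like op:

# Keep one element per position on axis 1, then flatten to shape (length,)
arr = inputs.slice(begin=(0, 0, 0), end=(1, None, 1)).reshape((-1,))
# arange_like is the op proposed above
arange = F.contrib.arange_like(arr)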
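
To make the intended semantics concrete, here is a minimal NumPy sketch of the proposed op and the slicing step. NumPy only stands in for MXNet here. The `arange_like` function below is a hypothetical reference, not an existing API, and the slice bounds and final reshape are assumptions about how one row along axis 1 would be extracted.

```python
import numpy as np

def arange_like(arr):
    """Hypothetical reference for the proposed op.

    Returns an array with the same shape and dtype as `arr`, filled with
    0, 1, ..., arr.size - 1. The contents of `arr` are ignored.
    """
    return np.arange(arr.size, dtype=arr.dtype).reshape(arr.shape)

# inputs uses the layout (batch, length, units); the values are arbitrary.
inputs = np.random.rand(2, 5, 8).astype("float32")

# Keep one element per position on axis 1, then flatten to shape (length,).
arr = inputs[0:1, :, 0:1].reshape(-1)

# The sequence length never appears as a Python integer in user code.
arange = arange_like(arr)
```

The point of the design is that the length is carried by the shape of a tensor in the graph rather than by a Python integer read from .shape, so the same exported graph could serve inputs of any length.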

control flow op

Alternatively, we can use a control flow op (either foreach or a while loop) to loop N times, where N = inputs.shape[1]. Iteration i writes the value i into the output "arange" array.

However, this may have high overhead when N is large (e.g. 512), since it runs one iteration per position.
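
The control flow option can be sketched in plain Python as follows. The `foreach` function here is a hypothetical stand-in that mimics a foreach-style op iterating over the first axis of its data; it is not MXNet's implementation, and the `step` body and the axis swap are assumptions made for illustration.

```python
import numpy as np

def foreach(body, data, init_states):
    """Plain-Python stand-in for a foreach-style control flow op.

    Calls `body(item, states)` once per slice along axis 0 of `data`,
    threading the loop states through and stacking the per-step outputs.
    """
    states, outputs = init_states, []
    for item in data:
        out, states = body(item, states)
        outputs.append(out)
    return np.stack(outputs), states

def step(_, states):
    # Emit the current counter and pass counter + 1 to the next iteration.
    counter = states[0]
    return counter, [counter + 1]

# inputs uses the layout (batch, length, units); the values are arbitrary.
inputs = np.random.rand(2, 5, 8).astype("float32")

# Move axis 1 to the front so the loop runs once per sequence position.
per_position = np.swapaxes(inputs, 0, 1)

arange, final_states = foreach(step, per_position, [np.float32(0)])
```

The loop count is determined by the data being iterated over, so no .shape call is needed, but the work is strictly sequential: a length of 512 means 512 iterations of the loop body.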

handling other .shape calls

  • mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=1), axis=1, size=length) can be replaced with a broadcast_mul op applied to the expanded mask and ones_like(arr)
  • inputs * math.sqrt(inputs.shape[-1]) can be replaced with a shape_nd op
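
As a sanity check on both replacements, the NumPy sketch below compares the shape-dependent and shape-free variants. NumPy stands in for the MXNet ops. The (1, length, 1) reshape of the ones array and the use of a shape vector in place of a shape op are assumptions, not code from the repository.

```python
import math
import numpy as np

batch, length, units = 2, 4, 8
inputs = np.random.rand(batch, length, units).astype("float32")
valid_length = np.array([4, 2], dtype="float32")

# One element per position on axis 1, flattened to shape (length,).
arr = inputs[0:1, :, 0:1].reshape(-1)

# mask[b, j] = 1 where position j is within valid_length[b]; shape (batch, length).
steps = np.arange(arr.size, dtype="float32").reshape(1, -1)
mask = (steps < valid_length.reshape(-1, 1)).astype("float32")

# Shape-dependent version: broadcast axis 1 to an explicit size.
mask_ref = np.broadcast_to(mask[:, None, :], (batch, length, length))

# Shape-free version: multiply by ones derived from arr and let broadcasting
# expand (batch, 1, length) * (1, length, 1) to (batch, length, length).
ones = np.ones_like(arr).reshape(1, -1, 1)
mask_new = mask[:, None, :] * ones

# Scaling: read the last dimension from a shape tensor instead of a Python int.
shape_vec = np.array(inputs.shape, dtype="float32")  # stand-in for a shape op
scaled_ref = inputs * math.sqrt(inputs.shape[-1])
scaled_new = inputs * np.sqrt(shape_vec[-1])
```

In both cases the result matches the original computation, while the size information stays inside tensors that a hybridized graph can recompute for each input.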

@TaoLv
