
Commit 09daf22

haojin2 authored and eric-haibin-lin committed
speedup SequenceMask on GPU (#14445)
1 parent 5d2a451 commit 09daf22

3 files changed: +140 additions, -62 deletions

src/operator/sequence_mask-inl.h

Lines changed: 17 additions & 62 deletions
@@ -65,70 +65,24 @@ struct SequenceMaskParam : public dmlc::Parameter<SequenceMaskParam> {
   }
 };
 
-// (seqlen, batch, rest) case
-template <int req>
-struct SequenceMask0Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<int>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (s * batch_size * restsize) + (b * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
-
-// (batch, seqlen, rest) case
-template <int req>
-struct SequenceMask1Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<int>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (b * max_s_len * restsize) + (s * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
+template<typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<cpu, 3, DType> &data,
+                      const mshadow::Tensor<cpu, 1, IType> &indices,
+                      const OpReqType req, mshadow::Stream<cpu> *const s,
+                      int axis, DType val);
+#ifdef __CUDACC__
+template<typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<gpu, 3, DType> &data,
+                      const mshadow::Tensor<gpu, 1, IType> &indices,
+                      const OpReqType req, mshadow::Stream<gpu> *const s,
+                      int axis, DType val);
+#endif
 
 template <typename xpu, typename DType, typename IType>
 class SequenceMaskOp : public Operator {
  public:
   explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }
 
-  void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
-                     const mshadow::Tensor<xpu, 1, IType> &indices,
-                     const OpReqType req, mshadow::Stream<xpu> *const s,
-                     DType val) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    index_t batch = indices.size(0);
-    index_t max_seq_len = data.size(param_.axis);
-    index_t restsize = data.size(2);
-
-    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-      if (param_.axis == 1)
-        mxnet_op::Kernel<SequenceMask1Kernel<req_type>, xpu>::Launch(
-            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
-            val);
-      else
-        mxnet_op::Kernel<SequenceMask0Kernel<req_type>, xpu>::Launch(
-            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
-            val);
-    });
-  }
-
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
@@ -155,8 +109,8 @@ class SequenceMaskOp : public Operator {
     if (param_.use_sequence_length) {
       Tensor<xpu, 1, IType> indices =
           in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
-      sequence_mask(out, indices, req[seq_mask::kOut], s,
-                    static_cast<DType>(param_.value));
+      SequenceMaskExec<DType, IType>(out, indices, req[seq_mask::kOut], s,
+                                     param_.axis, static_cast<DType>(param_.value));
     }
   }
 
@@ -198,11 +152,12 @@ class SequenceMaskOp : public Operator {
                 s3, s);
         out_g_temp = F<mshadow_op::identity>(out_g);
         out_g = out_g_temp;
-        sequence_mask(out_g, indices, kWriteInplace, s, DType(0.));
+        SequenceMaskExec<DType, IType>(out_g, indices, kWriteInplace, s, param_.axis, DType(0.));
         Assign(data_g, kAddTo, F<mshadow_op::identity>(out_g));
       } else {
         Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
-        sequence_mask(data_g, indices, req[seq_mask::kData], s, DType(0.));
+        SequenceMaskExec<DType, IType>(
+            data_g, indices, req[seq_mask::kData], s, param_.axis, DType(0.));
       }
     }
   }
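
Note: after this hunk, sequence_mask-inl.h carries no kernel bodies at all; it only declares a device-specific `SequenceMaskExec` entry point (the `gpu` overload guarded by `#ifdef __CUDACC__`), and the definitions move into sequence_mask.cc and sequence_mask.cu below. The following is a minimal, self-contained sketch of that declare-in-the-header / define-per-device pattern; the `cpu`/`gpu` tags and `Tensor` here are simplified placeholders, not the real mshadow types.

```cpp
// Sketch only: simplified stand-ins for mshadow's device tags and tensor type.
struct cpu {};
struct gpu {};

template <typename Device, int dim, typename DType>
struct Tensor {
  DType *dptr_;  // placeholder for the real mshadow::Tensor
};

// Shared header: declarations only, one overload per device tag.
template <typename DType, typename IType>
void SequenceMaskExec(const Tensor<cpu, 3, DType> &data,
                      const Tensor<cpu, 1, IType> &indices,
                      int req, int axis, DType val);

#ifdef __CUDACC__
// Only visible when the header is compiled by nvcc, mirroring the real guard.
template <typename DType, typename IType>
void SequenceMaskExec(const Tensor<gpu, 3, DType> &data,
                      const Tensor<gpu, 1, IType> &indices,
                      int req, int axis, DType val);
#endif
// sequence_mask.cc supplies the cpu definition; sequence_mask.cu the gpu one,
// so each device can choose its own launch granularity.
```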

src/operator/sequence_mask.cc

Lines changed: 64 additions & 0 deletions
@@ -27,6 +27,70 @@
 
 namespace mxnet {
 namespace op {
+
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0CPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int batch, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<int>(idx[batch]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (s * batch_size * restsize) + (batch * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1CPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int batch, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<int>(idx[batch]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (batch * max_s_len * restsize) + (s * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+template<typename DType, typename IType>
+void SequenceMaskExec(
+    const mshadow::Tensor<cpu, 3, DType> &data,
+    const mshadow::Tensor<cpu, 1, IType> &indices,
+    const OpReqType req, mshadow::Stream<cpu> *const s,
+    int axis, DType val) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+
+  index_t batch = indices.size(0);
+  index_t max_seq_len = data.size(axis);
+  index_t restsize = data.size(2);
+
+  MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+    if (axis == 1) {
+      Kernel<SequenceMask1CPUKernel<req_type>, cpu>::Launch(
+          s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    } else {
+      Kernel<SequenceMask0CPUKernel<req_type>, cpu>::Launch(
+          s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    }
+  });
+}
+
 template <>
 Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = nullptr;
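
The CPU implementation keeps the original per-batch strategy: one `Kernel` work item per batch element, each walking the masked sequence positions and the trailing `rest` dimension. The standalone sketch below (plain C++ with `std::vector`, hypothetical helper names, no mshadow) reproduces that index arithmetic for the (seqlen, batch, rest) layout, i.e. what `SequenceMask0CPUKernel` does per batch element.

```cpp
#include <cstdio>
#include <vector>

// Mask everything at sequence positions >= idx[b] for one batch element,
// data laid out as (seqlen, batch, rest): element (s, b, r) lives at
// s * batch_size * restsize + b * restsize + r.
void mask_axis0_one_batch(int b, std::vector<float> *in, const std::vector<int> &idx,
                          int max_s_len, int batch_size, int restsize, float value) {
  const int seqpos = idx[b];
  for (int s = seqpos; s < max_s_len; ++s) {
    const int incr = (s * batch_size * restsize) + (b * restsize);
    for (int r = 0; r < restsize; ++r)
      (*in)[incr + r] = value;  // the real kernel routes this through KERNEL_ASSIGN(req)
  }
}

int main() {
  const int max_s_len = 4, batch_size = 2, restsize = 3;
  std::vector<float> data(max_s_len * batch_size * restsize, 1.0f);
  const std::vector<int> lengths = {2, 3};  // valid length per batch element
  for (int b = 0; b < batch_size; ++b)      // one "work item" per batch element
    mask_axis0_one_batch(b, &data, lengths, max_s_len, batch_size, restsize, 0.0f);
  std::printf("data[(2,0,0)] = %g, data[(1,0,0)] = %g\n",
              data[2 * batch_size * restsize], data[1 * batch_size * restsize]);
  return 0;  // prints 0 for the masked position, 1 for the kept one
}
```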

src/operator/sequence_mask.cu

Lines changed: 59 additions & 0 deletions
@@ -29,6 +29,65 @@
 namespace mxnet {
 namespace op {
 
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0GPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    index_t batch = i / restsize % batch_size;
+    const index_t seqpos = static_cast<int>(idx[batch]);
+    index_t seq = i / restsize / batch_size;
+    if (seq >= seqpos) {
+      KERNEL_ASSIGN(in[i], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1GPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    index_t batch = i / restsize / max_s_len;
+    const index_t seqpos = static_cast<int>(idx[batch]);
+    index_t seq = i / restsize % max_s_len;
+    if (seq >= seqpos) {
+      KERNEL_ASSIGN(in[i], req, value);
+    }
+  }
+};
+
+template<typename DType, typename IType>
+void SequenceMaskExec(
+    const mshadow::Tensor<gpu, 3, DType> &data,
+    const mshadow::Tensor<gpu, 1, IType> &indices,
+    const OpReqType req, mshadow::Stream<gpu> *const s,
+    int axis, DType val) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+
+  index_t batch = indices.size(0);
+  index_t max_seq_len = data.size(axis);
+  index_t restsize = data.size(2);
+
+  MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+    if (axis == 1) {
+      Kernel<SequenceMask1GPUKernel<req_type>, gpu>::Launch(
+          s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    } else {
+      Kernel<SequenceMask0GPUKernel<req_type>, gpu>::Launch(
+          s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    }
+  });
+}
+
 template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = NULL;
   MSHADOW_TYPE_SWITCH(dtype, DType, {
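
The GPU path is where the speedup comes from: the launch now covers `data.shape_.Size()` elements rather than only `batch` work items, so large tensors no longer serialize the (seq, rest) loops inside a handful of threads. Each element recovers its (seq, batch) coordinates from its flat index and masks itself when `seq >= idx[batch]`. A small standalone sketch of that decomposition (plain C++, hypothetical helper names) for both layouts:

```cpp
#include <cstdio>

// (seqlen, batch, rest) layout: i = (s * batch_size + b) * restsize + r
void decompose_axis0(int i, int batch_size, int restsize, int *s, int *b) {
  *b = i / restsize % batch_size;
  *s = i / restsize / batch_size;
}

// (batch, seqlen, rest) layout: i = (b * max_s_len + s) * restsize + r
void decompose_axis1(int i, int max_s_len, int restsize, int *s, int *b) {
  *b = i / restsize / max_s_len;
  *s = i / restsize % max_s_len;
}

int main() {
  const int max_s_len = 4, batch_size = 2, restsize = 3;
  // Element (s=2, b=1, r=0) in the (seqlen, batch, rest) layout.
  const int i = (2 * batch_size + 1) * restsize + 0;
  int s, b;
  decompose_axis0(i, batch_size, restsize, &s, &b);
  std::printf("flat %d -> seq %d, batch %d\n", i, s, b);  // flat 15 -> seq 2, batch 1
  // The kernel then masks iff s >= idx[b]: one branch per element, no inner loops.
  return 0;
}
```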
