@@ -65,70 +65,24 @@ struct SequenceMaskParam : public dmlc::Parameter<SequenceMaskParam> {
6565 }
6666};
6767
68- // (seqlen, batch, rest) case: for entry b, applies KERNEL_ASSIGN(..., req, value) to every element at sequence positions s in [idx[b], max_s_len)
69- template <int req>
70- struct SequenceMask0Kernel {
71- template <typename DType, typename IType>
72- MSHADOW_XINLINE static void Map (int b, DType *in, const IType *idx, // b selects which idx entry / batch slot this call handles
73- index_t max_s_len, index_t batch_size,
74- index_t restsize, DType value) {
75- const index_t seqpos = static_cast <int >(idx[b]); // first position to overwrite for entry b; NOTE(review): narrowed through int and not range-checked in this kernel - confirm lengths are validated by the caller
76- #pragma unroll
77- for (index_t s = seqpos; s < max_s_len; ++s) {
78- index_t incr = (s * batch_size * restsize) + (b * restsize); // flat offset of element (s, b, 0) in the (seqlen, batch, rest) layout
79- #pragma unroll
80- for (index_t r = 0 ; r < restsize; ++r) // visit all restsize trailing elements of slot (s, b)
81- KERNEL_ASSIGN (in[incr + r], req, value);
82- }
83- }
84- };
85-
86- // (batch, seqlen, rest) case: for entry b, applies KERNEL_ASSIGN(..., req, value) to every element at sequence positions s in [idx[b], max_s_len)
87- template <int req>
88- struct SequenceMask1Kernel {
89- template <typename DType, typename IType>
90- MSHADOW_XINLINE static void Map (int b, DType *in, const IType *idx, // b selects which idx entry / batch slot this call handles
91- index_t max_s_len, index_t batch_size, // batch_size is accepted but not read in this layout's offset computation
92- index_t restsize, DType value) {
93- const index_t seqpos = static_cast <int >(idx[b]); // first position to overwrite for entry b; NOTE(review): narrowed through int and not range-checked in this kernel - confirm lengths are validated by the caller
94- #pragma unroll
95- for (index_t s = seqpos; s < max_s_len; ++s) {
96- index_t incr = (b * max_s_len * restsize) + (s * restsize); // flat offset of element (b, s, 0) in the (batch, seqlen, rest) layout
97- #pragma unroll
98- for (index_t r = 0 ; r < restsize; ++r) // visit all restsize trailing elements of slot (b, s)
99- KERNEL_ASSIGN (in[incr + r], req, value);
100- }
101- }
102- };
68+ template <typename DType, typename IType> // CPU overload: declaration only, the definition is not part of this view
69+ void SequenceMaskExec (const mshadow::Tensor<cpu, 3 , DType> &data, // 3-D tensor passed by const reference; call sites in this view pass the operator's output / gradient tensor
70+ const mshadow::Tensor<cpu, 1 , IType> &indices, // 1-D tensor; call sites in this view pass the sequence-length input
71+ const OpReqType req, mshadow::Stream<cpu> *const s,
72+ int axis, DType val); // call sites in this view pass param_.axis and the fill value (param_.value or 0)
73+ #ifdef __CUDACC__ // GPU overload is only declared when this macro is defined (presumably a CUDA compiler pass - confirm against the build setup)
74+ template <typename DType, typename IType> // declaration only, the definition is not part of this view
75+ void SequenceMaskExec (const mshadow::Tensor<gpu, 3 , DType> &data, // same parameter list as the cpu overload, with gpu tensors
76+ const mshadow::Tensor<gpu, 1 , IType> &indices,
77+ const OpReqType req, mshadow::Stream<gpu> *const s, // stream type matches the gpu device tag
78+ int axis, DType val);
79+ #endif // __CUDACC__
10380
10481template <typename xpu, typename DType, typename IType>
10582class SequenceMaskOp : public Operator {
10683 public:
10784 explicit SequenceMaskOp (SequenceMaskParam p) { this ->param_ = p; }
10885
109- void sequence_mask (const mshadow::Tensor<xpu, 3 , DType> &data, // launches one of the two SequenceMask kernels over data, chosen by param_.axis
110- const mshadow::Tensor<xpu, 1 , IType> &indices, // per-entry start positions, read by the kernels as idx[b]
111- const OpReqType req, mshadow::Stream<xpu> *const s,
112- DType val) { // val is forwarded to the kernels as the value to assign
113- using namespace mshadow ;
114- using namespace mshadow ::expr;
115-
116- index_t batch = indices.size (0 ); // number of entries in indices; used both as the launch count and as batch_size
117- index_t max_seq_len = data.size (param_.axis ); // extent of data along param_.axis, passed to the kernel as max_s_len
118- index_t restsize = data.size (2 ); // extent of the last dimension of data
119-
120- MXNET_ASSIGN_REQ_SWITCH (req, req_type, { // makes req available to the kernels as the compile-time template argument req_type
121- if (param_.axis == 1 ) // axis == 1 picks the (batch, seqlen, rest) kernel
122- mxnet_op::Kernel<SequenceMask1Kernel<req_type>, xpu>::Launch (
123- s, batch, data.dptr_ , indices.dptr_ , max_seq_len, batch, restsize,
124- val);
125- else // any other axis value picks the (seqlen, batch, rest) kernel
126- mxnet_op::Kernel<SequenceMask0Kernel<req_type>, xpu>::Launch (
127- s, batch, data.dptr_ , indices.dptr_ , max_seq_len, batch, restsize,
128- val);
129- });
130- }
131-
13286 virtual void Forward (const OpContext &ctx, const std::vector<TBlob> &in_data,
13387 const std::vector<OpReqType> &req,
13488 const std::vector<TBlob> &out_data,
@@ -155,8 +109,8 @@ class SequenceMaskOp : public Operator {
155109 if (param_.use_sequence_length ) {
156110 Tensor<xpu, 1 , IType> indices =
157111 in_data[seq_mask::kSequenceLength ].get <xpu, 1 , IType>(s);
158- sequence_mask (out, indices, req[seq_mask::kOut ], s,
159- static_cast <DType>(param_.value ));
112+ SequenceMaskExec<DType, IType> (out, indices, req[seq_mask::kOut ], s,
113+ param_. axis , static_cast <DType>(param_.value ));
160114 }
161115 }
162116
@@ -198,11 +152,12 @@ class SequenceMaskOp : public Operator {
198152 s3, s);
199153 out_g_temp = F<mshadow_op::identity>(out_g);
200154 out_g = out_g_temp;
201- sequence_mask (out_g, indices, kWriteInplace , s, DType (0 .));
155+ SequenceMaskExec<DType, IType> (out_g, indices, kWriteInplace , s, param_. axis , DType (0 .));
202156 Assign (data_g, kAddTo , F<mshadow_op::identity>(out_g));
203157 } else {
204158 Assign (data_g, req[seq_mask::kData ], F<mshadow_op::identity>(out_g));
205- sequence_mask (data_g, indices, req[seq_mask::kData ], s, DType (0 .));
159+ SequenceMaskExec<DType, IType>(
160+ data_g, indices, req[seq_mask::kData ], s, param_.axis , DType (0 .));
206161 }
207162 }
208163 }
0 commit comments