
Commit ba9ec22

Add GPU version of boolean_mask op
1 parent 26ca37c commit ba9ec22

File tree

4 files changed (+275 lines, -70 lines):

  src/operator/contrib/boolean_mask-inl.h
  src/operator/contrib/boolean_mask.cc
  src/operator/contrib/boolean_mask.cu
  tests/python/unittest/test_operator.py

src/operator/contrib/boolean_mask-inl.h

Lines changed: 2 additions & 66 deletions
@@ -55,78 +55,14 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
-                               const std::vector<NDArray> &outputs) {
-  // TODO(@junrushao1994): This implementation is a proof-of-concept,
-  // hence very slow actually. Performance should be improved in the future.
-  CHECK_EQ(inputs.size(), 2U);
-  CHECK_EQ(outputs.size(), 1U);
-  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
-  const int axis = param.axis;
-  const NDArray &data = inputs[0];
-  const NDArray &idx = inputs[1];
-  const NDArray &out = outputs[0];
-  CHECK_EQ(axis, 0) << "Not supported yet";
-  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
-  CHECK_EQ(idx.shape().ndim(), 1U);
-  // count the number of 1s in `idx`, so that we could know the output dimension
-  size_t valid_num = 0;
-  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
-    DType* idx_dptr = idx.data().dptr<DType>();
-    int length = idx.shape()[0];
-    for (int i = 0; i < length; i++) {
-      if (idx_dptr[i]) {
-        ++valid_num;
-      }
-    }
-  });
-  // set the output shape forcefully
-  TShape s = data.shape();
-  s[axis] = valid_num;
-  const_cast<NDArray &>(out).Init(s);
-  // do the copy
-  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
-    DType* idx_dptr = idx.data().dptr<DType>();
-    int length = idx.shape()[0];
-    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
-    for (int i = 0, j = 0; i < length; ++i) {
-      if (idx_dptr[i]) {
-        NDArray src = data.At(i);
-        NDArray dst = out.At(j++);
-        CHECK(src.shape() == dst.shape());
-        mxnet_op::copy(stream, dst.data(), src.data());
-      }
-    }
-  });
-}
+                               const std::vector<NDArray> &outputs);
 
 template<typename xpu>
 inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
                                 const OpContext &ctx,
                                 const std::vector<NDArray> &inputs,
                                 const std::vector<OpReqType> &req,
-                                const std::vector<NDArray> &outputs) {
-  CHECK_EQ(inputs.size(), 3U);
-  CHECK_EQ(outputs.size(), 2U);
-  // inputs: {ograd, data, idx}
-  // outputs: {igrad_data, igrad_idx}
-  const NDArray& ograd = inputs[0];
-  const NDArray& idx = inputs[2];
-  const NDArray& igrad_data = outputs[0];
-  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
-    DType* idx_dptr = idx.data().dptr<DType>();
-    int length = idx.shape()[0];
-    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
-    Fill<false>(stream, igrad_data.data(), req[0], 0);
-    for (int i = 0, j = 0; i < length; ++i) {
-      if (idx_dptr[i]) {
-        NDArray src = ograd.At(j++);
-        NDArray dst = igrad_data.At(i);
-        CHECK(src.shape() == dst.shape());
-        mxnet_op::copy(stream, dst.data(), src.data());
-      }
-    }
-  });
-}
+                                const std::vector<NDArray> &outputs);
 
 }  // namespace op
 }  // namespace mxnet
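
The bodies removed from this header (and reinstated below as the <cpu> specializations in boolean_mask.cc) implement a straightforward row-copy. For orientation, the forward and backward passes behave roughly like the following NumPy sketch; this is an illustration only, assuming axis 0 and hypothetical helper names, not MXNet code:

import numpy as np

def boolean_mask_forward(data, idx):
    # Keep the rows of `data` whose mask entry is non-zero; the count of
    # non-zero entries determines the output's first dimension.
    keep = [i for i in range(idx.shape[0]) if idx[i]]
    return data[keep]

def boolean_mask_backward(ograd, idx, data_shape):
    # Gradient w.r.t. `data`: zero-fill, then scatter the rows of `ograd`
    # back to the kept positions (mirrors Fill<false> plus the copy loop).
    igrad = np.zeros(data_shape, dtype=ograd.dtype)
    j = 0
    for i in range(idx.shape[0]):
        if idx[i]:
            igrad[i] = ograd[j]
            j += 1
    return igrad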

src/operator/contrib/boolean_mask.cc

Lines changed: 82 additions & 2 deletions
@@ -28,7 +28,6 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(BooleanMaskParam);
 
-
 bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
@@ -75,9 +74,86 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+template<>
+inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  // TODO(@junrushao1994): This implementation is a proof-of-concept,
+  // hence very slow actually. Performance should be improved in the future.
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  // count the number of 1s in `idx`, so that we could know the output dimension
+  size_t valid_num = 0;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    int length = idx.shape()[0];
+    for (int i = 0; i < length; i++) {
+      if (idx_dptr[i]) {
+        ++valid_num;
+      }
+    }
+  });
+  // set the output shape forcefully
+  TShape s = data.shape();
+  s[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
+  // do the copy
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    int length = idx.shape()[0];
+    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+    for (int i = 0, j = 0; i < length; ++i) {
+      if (idx_dptr[i]) {
+        NDArray src = data.At(i);
+        NDArray dst = out.At(j++);
+        CHECK(src.shape() == dst.shape());
+        mxnet_op::copy(stream, dst.data(), src.data());
+      }
+    }
+  });
+}
+
+template<>
+inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    int length = idx.shape()[0];
+    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+    Fill<false>(stream, igrad_data.data(), req[0], 0);
+    for (int i = 0, j = 0; i < length; ++i) {
+      if (idx_dptr[i]) {
+        NDArray src = ograd.At(j++);
+        NDArray dst = igrad_data.At(i);
+        CHECK(src.shape() == dst.shape());
+        mxnet_op::copy(stream, dst.data(), src.data());
+      }
+    }
+  });
+}
+
 NNVM_REGISTER_OP(_contrib_boolean_mask)
 .describe(R"code(
-Experimental CPU-only support for boolean masking.
 Given an n-d NDArray data, and a 1-d NDArray index,
 the operator produces an un-predeterminable shaped n-d NDArray out,
 which stands for the rows in x where the corresonding element in index is non-zero.
@@ -94,6 +170,10 @@ which stands for the rows in x where the corresonding element in index is non-zero.
 .set_attr_parser(ParamParser<BooleanMaskParam>)
 .set_num_inputs(2)
 .set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "index"};
+  })
 .set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
 .set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
 .set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
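
The describe string above gives the user-facing behavior: given an n-d data array and a 1-d 0/1 index along axis 0, the operator returns only the rows whose index entry is non-zero. A minimal usage sketch from Python follows; it assumes the usual exposure of the registered _contrib_boolean_mask op as mx.nd.contrib.boolean_mask, matching the unit test further below:

import mxnet as mx

data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = mx.nd.array([0, 1, 0])
# Only the second row has a non-zero mask entry, so the output is 1 x 3.
out = mx.nd.contrib.boolean_mask(data, index)
print(out.asnumpy())  # [[4. 5. 6.]]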

src/operator/contrib/boolean_mask.cu

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file boolean_mask.cu
+ */
+
+#include "./boolean_mask-inl.h"
+#include <cub/cub.cuh>
+
+namespace mxnet {
+namespace op {
+
+struct BooleanMaskForwardKernel {
+  template<typename DType>
+  static void MSHADOW_XINLINE Map(int i,
+                                  DType* out,
+                                  const DType* data,
+                                  const int32_t* idx,
+                                  const size_t col_size) {
+    int row_id = i / col_size;
+    int col_id = i % col_size;
+    int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
+    int32_t curr = idx[row_id];
+    if (prev != curr) {
+      out[prev * col_size + col_id] = data[i];
+    }
+  }
+};
+
+struct BooleanMaskBackwardKernel {
+  template<typename DType>
+  static void MSHADOW_XINLINE Map(int i,
+                                  DType* igrad,
+                                  const DType* ograd,
+                                  const int32_t* idx,
+                                  const size_t col_size) {
+    int row_id = i / col_size;
+    int col_id = i % col_size;
+    int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
+    int32_t curr = idx[row_id];
+    if (prev != curr) {
+      igrad[i] = ograd[prev * col_size + col_id];
+    }
+  }
+};
+
+template<>
+inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  // count the number of 1s in `idx`, so that we could know the output dimension
+  size_t idx_size = idx.shape()[0];
+  int32_t valid_num = 0;
+  int32_t* prefix_sum = nullptr;
+  void* d_temp_storage = nullptr;
+  size_t temp_storage_bytes = 0;
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                Stream<gpu>::GetStream(s));
+  size_t buffer_size = idx_size * sizeof(int32_t);
+  temp_storage_bytes += buffer_size;
+  Tensor<gpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
+  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
+  d_temp_storage = workspace.dptr_ + buffer_size;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch(
+      s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>());
+  });
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                Stream<gpu>::GetStream(s));
+  CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
+                       cudaMemcpyDeviceToHost));
+  CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
+  // Set the output shape forcefully
+  TShape data_shape = data.shape();
+  data_shape[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(data_shape);
+  size_t input_size = data.shape().Size();
+  size_t col_size = input_size / idx.shape()[0];
+  // do the copy
+  MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
+    mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
+      s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+  });
+}
+
+template<>
+inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  // count the number of 1s in `idx`, so that we could know the output dimension
+  size_t idx_size = idx.shape()[0];
+  int32_t* prefix_sum = nullptr;
+  void* d_temp_storage = nullptr;
+  size_t temp_storage_bytes = 0;
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                Stream<gpu>::GetStream(s));
+  size_t buffer_size = idx_size * sizeof(int32_t);
+  temp_storage_bytes += buffer_size;
+  Tensor<gpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
+  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
+  d_temp_storage = workspace.dptr_ + buffer_size;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch(
+      s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>());
+  });
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                Stream<gpu>::GetStream(s));
+  size_t input_size = igrad_data.shape().Size();
+  size_t col_size = input_size / idx_size;
+  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
+    mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
+      s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(), prefix_sum, col_size);
+  });
+}
+
+NNVM_REGISTER_OP(_contrib_boolean_mask)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
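
Unlike the CPU path, the GPU implementation avoids a per-row host loop. It casts the mask to int32, runs cub::DeviceScan::InclusiveSum over it (the first call with null pointers only queries temp_storage_bytes, which is the standard CUB pattern; one workspace then holds both the prefix-sum buffer and CUB's scratch space), copies the last scan entry back to the host to obtain the number of kept rows (and hence the output shape), and finally launches one thread per element: wherever the scan value changes at a row, the exclusive prefix sum at that row is the destination row. The backward kernel applies the same indexing in reverse. A NumPy sketch of the same logic, looping over rows where the CUDA kernels use one thread per element:

import numpy as np

data = np.arange(12, dtype=np.float32).reshape(4, 3)
mask = np.array([1, 0, 1, 1], dtype=np.int32)
ograd = np.ones((int(mask.sum()), data.shape[1]), dtype=data.dtype)

prefix = np.cumsum(mask)                 # inclusive scan: [1, 1, 2, 3]
valid_num = int(prefix[-1])              # number of kept rows -> output shape
out = np.zeros((valid_num, data.shape[1]), dtype=data.dtype)
igrad = np.zeros_like(data)

for row in range(data.shape[0]):
    prev = 0 if row == 0 else prefix[row - 1]
    if prefix[row] != prev:              # mask[row] was non-zero
        out[prev] = data[row]            # forward copy (BooleanMaskForwardKernel)
        igrad[row] = ograd[prev]         # backward scatter (BooleanMaskBackwardKernel)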

tests/python/unittest/test_operator.py

Lines changed: 0 additions & 2 deletions
@@ -4861,8 +4861,6 @@ def test_index_copy():
 
 @with_seed()
 def test_boolean_mask():
-    if default_context().device_type != 'cpu':
-        return
     data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
     index = mx.nd.array([0, 1, 0])
     data.attach_grad()
