|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | +/*! |
| 20 | + * Copyright (c) 2018 by Contributors |
| 21 | + * \file boolean_mask.cu |
| 22 | +*/ |
| 23 | + |
| 24 | +#include "./boolean_mask-inl.h" |
| 25 | +#include <cub/cub.cuh> |
| 26 | + |
| 27 | +namespace mxnet { |
| 28 | +namespace op { |
| 29 | + |
| 30 | +struct BooleanMaskForwardKernel { |
| 31 | + template<typename DType> |
| 32 | + static void MSHADOW_XINLINE Map(int i, |
| 33 | + DType* out, |
| 34 | + const DType* data, |
| 35 | + const int32_t* idx, |
| 36 | + const size_t col_size) { |
| 37 | + int row_id = i / col_size; |
| 38 | + int col_id = i % col_size; |
| 39 | + int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1]; |
| 40 | + int32_t curr = idx[row_id]; |
| 41 | + if (prev != curr) { |
| 42 | + out[prev * col_size + col_id] = data[i]; |
| 43 | + } |
| 44 | + } |
| 45 | +}; |
| 46 | + |
| 47 | +struct BooleanMaskBackwardKernel { |
| 48 | + template<typename DType> |
| 49 | + static void MSHADOW_XINLINE Map(int i, |
| 50 | + DType* igrad, |
| 51 | + const DType* ograd, |
| 52 | + const int32_t* idx, |
| 53 | + const size_t col_size) { |
| 54 | + int row_id = i / col_size; |
| 55 | + int col_id = i % col_size; |
| 56 | + int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1]; |
| 57 | + int32_t curr = idx[row_id]; |
| 58 | + if (prev != curr) { |
| 59 | + igrad[i] = ograd[prev * col_size + col_id]; |
| 60 | + } |
| 61 | + } |
| 62 | +}; |
| 63 | + |
| 64 | +template<> |
| 65 | +inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs, |
| 66 | + const OpContext &ctx, |
| 67 | + const std::vector<NDArray> &inputs, |
| 68 | + const std::vector<OpReqType> &req, |
| 69 | + const std::vector<NDArray> &outputs) { |
| 70 | + using namespace mshadow; |
| 71 | + CHECK_EQ(inputs.size(), 2U); |
| 72 | + CHECK_EQ(outputs.size(), 1U); |
| 73 | + const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed); |
| 74 | + const int axis = param.axis; |
| 75 | + const NDArray &data = inputs[0]; |
| 76 | + const NDArray &idx = inputs[1]; |
| 77 | + const NDArray &out = outputs[0]; |
| 78 | + CHECK_EQ(axis, 0) << "Not supported yet"; |
| 79 | + CHECK_EQ(data.shape()[axis], idx.shape()[0]); |
| 80 | + CHECK_EQ(idx.shape().ndim(), 1U); |
| 81 | + Stream<gpu>* s = ctx.get_stream<gpu>(); |
| 82 | + // count the number of 1s in `idx`, so that we could know the output dimension |
| 83 | + size_t idx_size = idx.shape()[0]; |
| 84 | + int32_t valid_num = 0; |
| 85 | + int32_t* prefix_sum = nullptr; |
| 86 | + void* d_temp_storage = nullptr; |
| 87 | + size_t temp_storage_bytes = 0; |
| 88 | + cub::DeviceScan::InclusiveSum(d_temp_storage, |
| 89 | + temp_storage_bytes, |
| 90 | + prefix_sum, |
| 91 | + prefix_sum, |
| 92 | + idx_size, |
| 93 | + Stream<gpu>::GetStream(s)); |
| 94 | + size_t buffer_size = idx_size * sizeof(int32_t); |
| 95 | + temp_storage_bytes += buffer_size; |
| 96 | + Tensor<gpu, 1, char> workspace = |
| 97 | + ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s); |
| 98 | + prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_); |
| 99 | + d_temp_storage = workspace.dptr_ + buffer_size; |
| 100 | + MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { |
| 101 | + mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch( |
| 102 | + s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>()); |
| 103 | + }); |
| 104 | + cub::DeviceScan::InclusiveSum(d_temp_storage, |
| 105 | + temp_storage_bytes, |
| 106 | + prefix_sum, |
| 107 | + prefix_sum, |
| 108 | + idx_size, |
| 109 | + Stream<gpu>::GetStream(s)); |
| 110 | + CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t), |
| 111 | + cudaMemcpyDeviceToHost)); |
| 112 | + CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0"; |
| 113 | + // Set the output shape forcefully |
| 114 | + TShape data_shape = data.shape(); |
| 115 | + data_shape[axis] = valid_num; |
| 116 | + const_cast<NDArray &>(out).Init(data_shape); |
| 117 | + size_t input_size = data.shape().Size(); |
| 118 | + size_t col_size = input_size / idx.shape()[0]; |
| 119 | + // do the copy |
| 120 | + MSHADOW_TYPE_SWITCH(out.dtype(), DType, { |
| 121 | + mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch( |
| 122 | + s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size); |
| 123 | + }); |
| 124 | +} |
| 125 | + |
| 126 | +template<> |
| 127 | +inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs, |
| 128 | + const OpContext &ctx, |
| 129 | + const std::vector<NDArray> &inputs, |
| 130 | + const std::vector<OpReqType> &req, |
| 131 | + const std::vector<NDArray> &outputs) { |
| 132 | + using namespace mshadow; |
| 133 | + CHECK_EQ(inputs.size(), 3U); |
| 134 | + CHECK_EQ(outputs.size(), 2U); |
| 135 | + // inputs: {ograd, data, idx} |
| 136 | + // outputs: {igrad_data, igrad_idx} |
| 137 | + const NDArray& ograd = inputs[0]; |
| 138 | + const NDArray& idx = inputs[2]; |
| 139 | + const NDArray& igrad_data = outputs[0]; |
| 140 | + Stream<gpu>* s = ctx.get_stream<gpu>(); |
| 141 | + // count the number of 1s in `idx`, so that we could know the output dimension |
| 142 | + size_t idx_size = idx.shape()[0]; |
| 143 | + int32_t* prefix_sum = nullptr; |
| 144 | + void* d_temp_storage = nullptr; |
| 145 | + size_t temp_storage_bytes = 0; |
| 146 | + cub::DeviceScan::InclusiveSum(d_temp_storage, |
| 147 | + temp_storage_bytes, |
| 148 | + prefix_sum, |
| 149 | + prefix_sum, |
| 150 | + idx_size, |
| 151 | + Stream<gpu>::GetStream(s)); |
| 152 | + size_t buffer_size = idx_size * sizeof(int32_t); |
| 153 | + temp_storage_bytes += buffer_size; |
| 154 | + Tensor<gpu, 1, char> workspace = |
| 155 | + ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s); |
| 156 | + prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_); |
| 157 | + d_temp_storage = workspace.dptr_ + buffer_size; |
| 158 | + MSHADOW_TYPE_SWITCH(idx.dtype(), IType, { |
| 159 | + mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch( |
| 160 | + s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>()); |
| 161 | + }); |
| 162 | + cub::DeviceScan::InclusiveSum(d_temp_storage, |
| 163 | + temp_storage_bytes, |
| 164 | + prefix_sum, |
| 165 | + prefix_sum, |
| 166 | + idx_size, |
| 167 | + Stream<gpu>::GetStream(s)); |
| 168 | + size_t input_size = igrad_data.shape().Size(); |
| 169 | + size_t col_size = input_size / idx_size; |
| 170 | + MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, { |
| 171 | + mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch( |
| 172 | + s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(), prefix_sum, col_size); |
| 173 | + }); |
| 174 | +} |
| 175 | + |
| 176 | +NNVM_REGISTER_OP(_contrib_boolean_mask) |
| 177 | +.set_attr<FResourceRequest>("FResourceRequest", |
| 178 | + [](const NodeAttrs& attrs) { |
| 179 | + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; |
| 180 | + }) |
| 181 | +.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskForward<gpu>); |
| 182 | + |
| 183 | +NNVM_REGISTER_OP(_backward_contrib_boolean_mask) |
| 184 | +.set_attr<FResourceRequest>("FResourceRequest", |
| 185 | + [](const NodeAttrs& attrs) { |
| 186 | + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; |
| 187 | + }) |
| 188 | +.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskBackward<gpu>); |
| 189 | + |
| 190 | +} // namespace op |
| 191 | +} // namespace mxnet |
0 commit comments