
Commit 2616275

huangzhiyuan authored and TaoLv committed
Add mkldnn OP for slice (#13730)
* add mkldnn slice
* fix lint
* fix lint
* mv SliceEx to matrix_op.cc
* fix lint
* optimize dispatch_mode
* retrigger ci
* fix indent
1 parent 5b011b3 commit 2616275

File tree: 5 files changed, +292 −16 lines

src/operator/nn/mkldnn/mkldnn_slice-inl.h
src/operator/nn/mkldnn/mkldnn_slice.cc
src/operator/tensor/matrix_op-inl.h
src/operator/tensor/matrix_op.cc
src/operator/tensor/slice-inl.h
src/operator/nn/mkldnn/mkldnn_slice-inl.h

Lines changed: 66 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mkldnn_slice-inl.h
 * \brief
 * \author Zhiyuan Huang
 */

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <utility>
#include "../../operator_common.h"
#include "../../tensor/slice-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

class MKLDNNSliceFwd {
 public:
  MKLDNNSliceFwd(const SliceParam &param,
                 const NDArray &in,
                 const NDArray &out);
  void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output);
  const mkldnn::reorder &GetPd() const;

 private:
  std::shared_ptr<mkldnn::memory> data_;
  std::shared_ptr<mkldnn::memory> out_;
  std::shared_ptr<mkldnn::reorder> fwd_;
};

typedef ParamOpSign<SliceParam> MKLDNNSliceSignature;
MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
                                const NDArray &in_data, const NDArray &out_data);

void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
                 const NDArray &in, OpReqType req, const NDArray &out);

}  // namespace op
}  // namespace mxnet
#endif  // MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
src/operator/nn/mkldnn/mkldnn_slice.cc

Lines changed: 104 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mkldnn_slice.cc
 * \brief
 * \author Zhiyuan Huang
 */

#if MXNET_USE_MKLDNN == 1

#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "./mkldnn_slice-inl.h"

namespace mxnet {
namespace op {

MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam &param,
                               const NDArray &in,
                               const NDArray &out) {
  const TShape ishape = in.shape();
  const TShape oshape = out.shape();
  uint32_t N = ishape.ndim();
  mkldnn::memory::dims dims(N);
  mkldnn::memory::dims offsets(N);
  for (uint32_t i = 0; i < N; ++i) {
    int s = 0;
    if (param.begin[i]) {
      s = *param.begin[i];
      if (s < 0) s += ishape[i];
    }
    dims[i] = oshape[i];
    offsets[i] = s;
  }
  auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc();
  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
  auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets);
  auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd);
  this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
  this->out_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
  this->fwd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *this->data_, *this->out_);
}

void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) {
  this->data_->set_data_handle(input.get_data_handle());
  this->out_->set_data_handle(output.get_data_handle());
}

const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const {
  return *fwd_;
}

MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
                                const NDArray &in_data, const NDArray &out_data) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
#endif
  MKLDNNSliceSignature key(param);
  key.AddSign(is_train);
  key.AddSign(in_data);
  key.AddSign(out_data);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    MKLDNNSliceFwd fwd(param, in_data, out_data);
    it = AddToCache(&fwds, key, fwd);
  }
  return it->second;
}

void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
                 const NDArray &in, OpReqType req, const NDArray &out) {
  MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out);
  auto in_mem = in.GetMKLDNNData();
  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
  auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req);
  fwd.SetNewMem(*in_mem, *out_mem.second);
  MKLDNNStream::Get()->RegisterPrim(fwd.GetPd());
  CommitOutput(out, out_mem);
  MKLDNNStream::Get()->Submit();
}

}  // namespace op
}  // namespace mxnet
#endif  // MXNET_USE_MKLDNN == 1
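
A note on the index arithmetic above: MKLDNNSliceFwd's constructor turns the slice parameters into a per-axis offset and extent that define the MKLDNN view to be reordered into the output, normalizing negative begin indices against the input shape. Below is a minimal standalone sketch of just that computation; std::optional<int> and std::vector<int64_t> are stand-ins for dmlc::optional<int>, TShape and mkldnn::memory::dims, and the names SliceRegion/ComputeSliceRegion are illustrative, not part of the commit.

// Standalone sketch of the offsets/dims computation performed in
// MKLDNNSliceFwd's constructor (stand-in types, illustrative only).
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct SliceRegion {
  std::vector<int64_t> dims;     // extent of the slice per axis (= output shape)
  std::vector<int64_t> offsets;  // starting element per axis in the input
};

SliceRegion ComputeSliceRegion(const std::vector<int64_t>& ishape,
                               const std::vector<int64_t>& oshape,
                               const std::vector<std::optional<int>>& begin) {
  const size_t N = ishape.size();
  SliceRegion r{std::vector<int64_t>(N), std::vector<int64_t>(N)};
  for (size_t i = 0; i < N; ++i) {
    int64_t s = 0;
    if (i < begin.size() && begin[i].has_value()) {
      s = *begin[i];
      if (s < 0) s += ishape[i];  // negative begin counts back from the end of the axis
    }
    r.dims[i] = oshape[i];  // the view has the output's extent on this axis...
    r.offsets[i] = s;       // ...and starts at the normalized begin index
  }
  return r;
}

int main() {
  // Slicing a (4, 6) input with begin = (1, -4) and output shape (2, 3):
  // offsets become (1, 2), extents (2, 3).
  SliceRegion r = ComputeSliceRegion({4, 6}, {2, 3}, {1, -4});
  for (size_t i = 0; i < r.dims.size(); ++i)
    std::cout << "axis " << i << ": offset " << r.offsets[i]
              << ", extent " << r.dims[i] << "\n";
  return 0;
}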

src/operator/tensor/matrix_op-inl.h

Lines changed: 22 additions & 15 deletions

@@ -37,6 +37,7 @@
 #include "broadcast_reduce_op.h"
 #include "./init_op.h"
 #include "../../common/static_array.h"
+#include "./slice-inl.h"
 
 #if MXNET_USE_CUDA
 #include <thrust/device_vector.h>

@@ -398,19 +399,15 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-struct SliceParam : public dmlc::Parameter<SliceParam> {
-  nnvm::Tuple<dmlc::optional<int>> begin, end;
-  nnvm::Tuple<dmlc::optional<int>> step;
-  DMLC_DECLARE_PARAMETER(SliceParam) {
-    DMLC_DECLARE_FIELD(begin)
-    .describe("starting indices for the slice operation, supports negative indices.");
-    DMLC_DECLARE_FIELD(end)
-    .describe("ending indices for the slice operation, supports negative indices.");
-    DMLC_DECLARE_FIELD(step)
-    .set_default(nnvm::Tuple<dmlc::optional<int>>())
-    .describe("step for the slice operation, supports negative values.");
+// Currently MKLDNN only supports step = 1 or step has no value
+inline bool SupportMKLDNNSlice(const SliceParam& param) {
+  if (param.step.ndim() == 0U) return true;
+  for (uint32_t i = 0; i < param.step.ndim(); ++i) {
+    if (param.step[i].has_value() && param.step[i].value() != 1)
+      return false;
   }
-};
+  return true;
+}
 
 inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,

@@ -432,9 +429,19 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
       && (!param.step[0].has_value() || param.step[0].value() == 1)) {
     trivial_step = true;
   }
-  if (!dispatched && in_stype == kDefaultStorage) {
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+
+  if (in_stype == kDefaultStorage) {
+#if MXNET_USE_MKLDNN == 1
+    if (dev_mask == Context::kCPU && MKLDNNEnvSet()
+        && SupportMKLDNNSlice(param)) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                       dispatch_mode, dispatch_ex);
+    }
+#endif
+    if (!dispatched) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                       dispatch_mode, DispatchMode::kFCompute);
+    }
   }
 
   if (!dispatched && in_stype == kCSRStorage && trivial_step) {
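
The storage-type inference above only routes dense inputs to the MKLDNN dispatch when SupportMKLDNNSlice accepts the parameters, i.e. when step is absent or equal to 1 on every axis; anything else keeps the plain FCompute path. A self-contained sketch of that predicate, with std::optional<int> and std::vector standing in for dmlc::optional<int> and nnvm::Tuple (the name SupportsUnitStepOnly is illustrative, not from the commit):

#include <iostream>
#include <optional>
#include <vector>

// Simplified stand-in for SliceParam::step (nnvm::Tuple<dmlc::optional<int>>).
using StepTuple = std::vector<std::optional<int>>;

// Mirrors the logic of SupportMKLDNNSlice: accept an empty step tuple, or one
// whose every defined entry equals 1.
bool SupportsUnitStepOnly(const StepTuple& step) {
  if (step.empty()) return true;
  for (const auto& s : step) {
    if (s.has_value() && s.value() != 1) return false;
  }
  return true;
}

int main() {
  std::cout << SupportsUnitStepOnly({}) << "\n";                    // 1: no step given
  std::cout << SupportsUnitStepOnly({1, std::nullopt, 1}) << "\n";  // 1: unit or unset steps
  std::cout << SupportsUnitStepOnly({2, 1}) << "\n";                // 0: would fall back
  return 0;
}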

src/operator/tensor/matrix_op.cc

Lines changed: 29 additions & 1 deletion

@@ -27,6 +27,7 @@
 #include "./elemwise_unary_op.h"
 #include "../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../nn/mkldnn/mkldnn_base-inl.h"
+#include "../nn/mkldnn/mkldnn_slice-inl.h"
 
 namespace mxnet {
 namespace op {

@@ -420,6 +421,30 @@ will return a new array with shape ``(2,1,3,4)``.
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(ExpandDimParam::__FIELDS__());
 
+void SliceExCPU(const nnvm::NodeAttrs& attrs,
+                const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1);
+  CHECK_EQ(outputs.size(), 1);
+  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  auto in_stype = inputs[0].storage_type();
+  if (in_stype == kCSRStorage) {
+    SliceCsrImpl<cpu>(param, ctx, inputs[0], req[0], outputs[0]);
+#if MXNET_USE_MKLDNN == 1
+  } else if (in_stype == kDefaultStorage) {
+    if (SupportMKLDNN(inputs[0])) {
+      MKLDNNSlice(param, ctx, inputs[0], req[0], outputs[0]);
+    } else {
+      FallBackCompute(SliceOpForward<cpu>, attrs, ctx, inputs, req, outputs);
+    }
+#endif
+  } else {
+    LOG(FATAL) << "Slice not implemented for storage type" << in_stype;
+  }
+}
+
 NNVM_REGISTER_OP(slice)
 MXNET_ADD_SPARSE_OP_ALIAS(slice)
 .add_alias("crop")

@@ -478,7 +503,10 @@ Example::
 .set_attr<FInferStorageType>("FInferStorageType", SliceForwardInferStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
 .set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", SliceEx<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SliceExCPU)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(SliceParam::__FIELDS__());
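
The new SliceExCPU entry point picks an implementation per storage type: CSR inputs go to the sparse kernel, dense inputs try the MKLDNN path and otherwise fall back to the plain CPU kernel. A toy sketch of that dispatch-with-fallback shape follows; the enum values and handler names are illustrative stand-ins, not part of the commit.

#include <iostream>
#include <stdexcept>

enum class Storage { kDefault, kCSR, kRowSparse };

void SliceCsr()      { std::cout << "sparse CSR slice\n"; }
void SliceMkldnn()   { std::cout << "MKLDNN slice\n"; }
void SliceFallback() { std::cout << "fallback FCompute slice\n"; }

// Dispatch mirroring the structure of SliceExCPU: branch on storage type,
// with a capability check and fallback on the dense branch.
void DispatchSlice(Storage stype, bool mkldnn_supported) {
  if (stype == Storage::kCSR) {
    SliceCsr();
  } else if (stype == Storage::kDefault) {
    if (mkldnn_supported) SliceMkldnn();
    else SliceFallback();
  } else {
    throw std::runtime_error("Slice not implemented for this storage type");
  }
}

int main() {
  DispatchSlice(Storage::kCSR, false);      // sparse path
  DispatchSlice(Storage::kDefault, true);   // MKLDNN path
  DispatchSlice(Storage::kDefault, false);  // dense fallback path
  return 0;
}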

src/operator/tensor/slice-inl.h

Lines changed: 71 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file slice-inl.h
 * \brief
 * \author Zhiyuan Huang
 */

#ifndef MXNET_OPERATOR_TENSOR_SLICE_INL_H_
#define MXNET_OPERATOR_TENSOR_SLICE_INL_H_

#include <utility>
#include <vector>
#include <string>

namespace mxnet {
namespace op {

struct SliceParam : public dmlc::Parameter<SliceParam> {
  nnvm::Tuple<dmlc::optional<int>> begin, end;
  nnvm::Tuple<dmlc::optional<int>> step;
  DMLC_DECLARE_PARAMETER(SliceParam) {
    DMLC_DECLARE_FIELD(begin)
    .describe("starting indices for the slice operation, supports negative indices.");
    DMLC_DECLARE_FIELD(end)
    .describe("ending indices for the slice operation, supports negative indices.");
    DMLC_DECLARE_FIELD(step)
    .set_default(nnvm::Tuple<dmlc::optional<int>>())
    .describe("step for the slice operation, supports negative values.");
  }
  bool operator==(const SliceParam& other) const {
    return this->begin == other.begin &&
           this->end == other.end &&
           this->step == other.step;
  }
};

}  // namespace op
}  // namespace mxnet

namespace std {
template<>
struct hash<mxnet::op::SliceParam> {
  size_t operator()(const mxnet::op::SliceParam& val) {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.begin);
    ret = dmlc::HashCombine(ret, val.end);
    ret = dmlc::HashCombine(ret, val.step);
    return ret;
  }
};
}  // namespace std

#endif  // MXNET_OPERATOR_TENSOR_SLICE_INL_H_
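
Moving SliceParam into its own header together with operator== and a std::hash specialization is what lets it serve as part of the key of the thread-local fwds cache built in GetSliceForward (via ParamOpSign/OpHash). A self-contained sketch of that keying pattern follows; ToyParam and the ad-hoc hash mix stand in for SliceParam and dmlc::HashCombine and are not part of the commit.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified stand-in for SliceParam: equality plus a std::hash specialization
// is all unordered_map needs to key cached forward primitives by parameters.
struct ToyParam {
  int begin;
  int end;
  bool operator==(const ToyParam& other) const {
    return begin == other.begin && end == other.end;
  }
};

namespace std {
template <>
struct hash<ToyParam> {
  size_t operator()(const ToyParam& p) const {
    // Simple combine in place of dmlc::HashCombine; any reasonable mix works.
    size_t ret = std::hash<int>()(p.begin);
    ret ^= std::hash<int>()(p.end) + 0x9e3779b9 + (ret << 6) + (ret >> 2);
    return ret;
  }
};
}  // namespace std

int main() {
  // Analogue of the fwds cache in GetSliceForward: look up by parameters,
  // create and insert on a miss, reuse the cached entry on later calls.
  std::unordered_map<ToyParam, std::string> cache;
  ToyParam key{1, 4};
  auto it = cache.find(key);
  if (it == cache.end()) it = cache.emplace(key, "created forward").first;
  std::cout << cache.find(ToyParam{1, 4})->second << "\n";  // hit: "created forward"
  return 0;
}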
