This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ddb0c09: Add primitive cache for mkldnn sum
Parent: 5ba285b

2 files changed: +93, -19 lines
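Summary of the change: instead of constructing a fresh mkldnn::sum primitive (plus its input and output mkldnn::memory wrappers) on every call to MKLDNNSumForward, the operator now looks the primitive up in a per-thread cache keyed by an OpSignature of the inputs, and on a hit only rebinds the data pointers.

Below is a minimal, self-contained sketch of that create-on-miss caching pattern. Signature, SignatureHash, SumPrimitive, and GetCachedSum are hypothetical stand-ins for MXNet's OpSignature, OpHash, the wrapped mkldnn::sum, and GetSumForward; the real key is derived from the inputs' shapes, dtypes, and memory formats.

```cpp
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for MXNet's OpSignature: a flat list of fields
// describing the inputs (shapes, dtypes, formats in the real code).
struct Signature {
  std::vector<int> fields;
  bool operator==(const Signature &other) const { return fields == other.fields; }
};

// Hypothetical stand-in for OpHash.
struct SignatureHash {
  std::size_t operator()(const Signature &s) const {
    std::size_t h = 0;
    for (int v : s.fields)
      h = h * 31 + std::hash<int>()(v);  // simple order-dependent combine
    return h;
  }
};

// Placeholder for the expensive-to-build primitive.
struct SumPrimitive {
  explicit SumPrimitive(const Signature &) { /* costly setup happens here */ }
};

// Create-on-miss, reuse-on-hit. One cache per thread (mirroring the
// DMLC_CXX11_THREAD_LOCAL branch in the patch), so no locking is needed.
SumPrimitive &GetCachedSum(const Signature &key) {
  static thread_local std::unordered_map<Signature, SumPrimitive, SignatureHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, SumPrimitive(key)).first;
  return it->second;
}
```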

src/operator/nn/mkldnn/mkldnn_sum.cc

Lines changed: 86 additions & 18 deletions
```diff
@@ -24,6 +24,7 @@
  */
 #include <iostream>
 
+#include "../../operator_common.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
 
@@ -58,37 +59,104 @@ void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
   MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
 }
 
+class MKLDNNSumFwd {
+ public:
+  mkldnn::sum::primitive_desc fwd_pd;
+
+  MKLDNNSumFwd(const std::vector<float> &scales,
+               const std::vector<mkldnn::memory::primitive_desc> &data_md)
+      : fwd_pd(scales, data_md) {
+    data_.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);
+
+  const mkldnn::sum &GetFwd() const { return *fwd_; }
+
+ private:
+  std::shared_ptr<mkldnn::sum> fwd_;
+  std::vector<std::shared_ptr<mkldnn::memory>> data_;
+  std::vector<mkldnn::primitive::at> data_mem_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+static MKLDNNSumFwd &GetSumForward(
+    const std::vector<float> &scales, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNSumFwd fwd(scales, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+void MKLDNNSumFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
+                             const mkldnn::memory &output) {
+  auto num_inputs = data_.size();
+  CHECK_EQ(in_data.size(), num_inputs);
+  for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
+    if (this->data_[i] == nullptr) {
+      this->data_[i] = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
+      this->data_mem_.push_back(*this->data_[i]);
+    } else {
+      this->data_[i]->set_data_handle(in_data[i]->get_data_handle());
+    }
+  }
+  if (this->out_ == nullptr)
+    this->out_ = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out_->set_data_handle(output.get_data_handle());
+
+  if (this->fwd_ == nullptr)
+    this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_));
+}
+
 void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                       const std::vector<NDArray> &inputs, const OpReqType &req,
                       const NDArray &out_data) {
-  if (req == kNullOp) {
-    return;
-  }
-
   TmpMemMgr::Get()->Init(ctx.requested[0]);
-  std::vector<mkldnn::primitive::at> in_prims;
-  std::vector<mkldnn::memory::primitive_desc> in_pds(inputs.size());
-  std::vector<float> scales(inputs.size(), 1);
-  in_prims.reserve(inputs.size());
-  std::vector<NDArray> in_bufs(inputs.size());
-  for (size_t i = 0; i < inputs.size(); i++) {
+  auto num_inputs = inputs.size();
+  std::vector<mkldnn::memory::primitive_desc> data_md;
+  std::vector<const mkldnn::memory *> data_mem;
+  std::vector<float> scales(num_inputs, 1);
+  std::vector<NDArray> in_bufs(num_inputs);
+
+  data_md.reserve(num_inputs);
+  data_mem.reserve(num_inputs);
+
+  for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
     const mkldnn::memory *in_mem;
     if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
       in_bufs[i] = inputs[i].Reorder2Default();
       in_mem = in_bufs[i].GetMKLDNNData();
     } else {
       in_mem = inputs[i].GetMKLDNNData();
     }
-    in_prims.push_back(*in_mem);
-    in_pds[i] = in_mem->get_primitive_desc();
+    mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc();
+    data_md.push_back(tmp_pd);
+    data_mem.push_back(in_mem);
   }
 
-  mkldnn::sum::primitive_desc pdesc(scales, in_pds);
-  auto mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), req, &inputs[0]);
-  MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second));
-  CommitOutput(out_data, mem);
-  stream->Submit();
+  MKLDNNSumFwd &fwd = GetSumForward(scales, inputs, data_md);
+  mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data,
+                                                   fwd.fwd_pd.dst_primitive_desc(),
+                                                   req,
+                                                   &inputs[0]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data, out_mem);
+  MKLDNNStream::Get()->Submit();
 }
 
 } // namespace op
```
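The reason a cached MKLDNNSumFwd can be reused across calls with different buffers is that SetNewMem only swaps raw data pointers: the primitive is built once against fixed shapes and formats, and each invocation rebinds it to the current addresses via set_data_handle. A simplified, self-contained sketch of that rebinding idea follows; MemoryRef and CachedSum are illustrative stand-ins, not MXNet or MKL-DNN types.

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for mkldnn::memory: a fixed descriptor (elided)
// plus a swappable pointer to the actual buffer.
struct MemoryRef {
  void *handle = nullptr;
  void set_data_handle(void *p) { handle = p; }
};

struct CachedSum {
  std::vector<MemoryRef> inputs;  // created lazily on the first call
  MemoryRef output;

  // Rebind the cached wrappers to this call's buffers. The expensive
  // primitive (elided here) depends only on shapes and formats, so it
  // never has to be rebuilt when only the addresses change.
  void SetNewMem(const std::vector<void *> &in, void *out) {
    if (inputs.empty()) inputs.resize(in.size());
    for (std::size_t i = 0; i < in.size(); ++i)
      inputs[i].set_data_handle(in[i]);
    output.set_data_handle(out);
  }
};
```

This is also why GetSumForward keys the cache with key.AddSign(in_data): the signature has to capture everything fwd_pd was built from, so that any change in input shape, dtype, or layout produces a different key and a fresh primitive rather than an incorrect reuse.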

src/operator/tensor/elemwise_binary_op_basic.cc

Lines changed: 7 additions & 1 deletion
```diff
@@ -30,6 +30,12 @@
 namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNSum(const NDArray& input) {
+  int ndim = input.shape().ndim();
+  return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) &&
+         input.storage_type() == kDefaultStorage;
+}
+
 static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<NDArray>& inputs,
@@ -38,7 +44,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
 #if MXNET_USE_MKLDNN == 1
-  if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) {
+  if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) {
     MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]);
     return;
   } else if (inputs[0].storage_type() == kDefaultStorage
```
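The new SupportMKLDNNSum guard narrows the fast path to inputs the MKL-DNN sum primitive handles: dense (kDefaultStorage) float32 tensors of rank 1 through 4; everything else falls through to the generic implementation. A standalone rendering of the same check, with ArrayDesc as a hypothetical stand-in for the pieces of NDArray it inspects and the constant values assumed for illustration only:

```cpp
// Hypothetical stand-in for the NDArray properties the check reads.
struct ArrayDesc {
  int dtype;         // element type tag, e.g. kFloat32
  int ndim;          // tensor rank
  int storage_type;  // dense vs. sparse, e.g. kDefaultStorage
};

constexpr int kFloat32 = 0;         // assumed tag values, for illustration
constexpr int kDefaultStorage = 0;

// Mirrors SupportMKLDNNSum: only dense fp32 tensors of rank 1..4 take
// the MKL-DNN path.
bool SupportMKLDNNSum(const ArrayDesc &a) {
  return a.dtype == kFloat32 && (a.ndim >= 1 && a.ndim <= 4) &&
         a.storage_type == kDefaultStorage;
}
```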
