Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 458bb73

Browse files
rongzha1 and TaoLv
authored and committed
[mkldnn-v1.0] Add MKL-DNN Pooling (#16272)
* add mkldnn pooling * add workaround for mkldnn v1.0 pooling fwd && bwd workspace mismatch * code clean * fix lint error * trigger CI * trigger CI * add extra work_space check and fix some typo * trigger CI
1 parent 3706ece commit 458bb73

File tree

3 files changed

+100
-145
lines changed

3 files changed

+100
-145
lines changed

src/operator/nn/mkldnn/mkldnn_pooling-inl.h

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
2525
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
2626

27-
#if MXNET_USE_MKLDNN == 1
27+
#if MXNET_USE_MKLDNN == 100
2828

2929
#include <utility>
3030
#include <mkldnn.hpp>
@@ -43,60 +43,48 @@ class MKLDNNPoolingFwd {
4343
const int padding_t, const int padding_b,
4444
const int padding_l, const int padding_r,
4545
const mkldnn::algorithm alg_kind,
46-
const bool with_workspace, const bool is_train) :
47-
is_train_(is_train),
46+
const bool with_workspace, const bool is_train):
4847
with_workspace_(with_workspace),
49-
alg_kind_(alg_kind),
50-
fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) {
48+
fwd_(nullptr) {
5149
Init(input, output,
5250
kernel_h, kernel_w, stride_h, stride_w,
53-
padding_t, padding_b, padding_l, padding_r);
51+
padding_t, padding_b, padding_l, padding_r,
52+
is_train, alg_kind);
5453
}
5554

5655
~MKLDNNPoolingFwd() {}
57-
void SetNewMem(const NDArray& in_data,
58-
const NDArray& out_data,
59-
const OpReqType& req,
60-
const mxnet::NDArray *workspace = nullptr);
61-
void Execute(const NDArray& out_data);
56+
void Execute(const NDArray &in_data,
57+
const OpReqType req,
58+
const NDArray& out_data,
59+
const NDArray *workspace);
6260

6361
private:
64-
bool is_train_;
6562
bool with_workspace_;
66-
mkldnn::algorithm alg_kind_;
63+
6764
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd_;
6865
std::shared_ptr<mkldnn::pooling_forward> fwd_;
69-
std::shared_ptr<mkldnn::memory> data_;
70-
std::shared_ptr<mkldnn::memory> out_;
71-
std::shared_ptr<mkldnn::memory> workspace_;
72-
mkldnn_output_t output_mem_t_;
7366

7467
private:
7568
void Init(const mxnet::NDArray &input,
7669
const mxnet::NDArray &output,
7770
const int kernel_h, const int kernel_w,
7871
const int stride_h, const int stride_w,
7972
const int padding_t, const int padding_b,
80-
const int padding_l, const int padding_r);
73+
const int padding_l, const int padding_r,
74+
const bool is_train, const mkldnn::algorithm alg_kind);
8175
};
8276

8377
class MKLDNNPoolingBwd {
8478
std::shared_ptr<const mkldnn::pooling_backward> bwd;
85-
std::shared_ptr<mkldnn::memory> diff_dst;
86-
std::shared_ptr<mkldnn::memory> diff_src;
87-
std::shared_ptr<mkldnn::memory> ws;
8879
bool with_workspace;
8980

9081
public:
9182
const mkldnn::pooling_backward::primitive_desc pd;
9283

93-
MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc,
84+
MKLDNNPoolingBwd(const mkldnn::pooling_backward::primitive_desc &pdesc,
9485
bool with_ws);
9586

9687
~MKLDNNPoolingBwd() {}
97-
void SetNewMem(const mxnet::NDArray *workspace,
98-
const mxnet::NDArray &out_grad,
99-
const mkldnn::memory *diff_src_mem);
10088
const mkldnn::pooling_backward &GetBwd();
10189
const mkldnn::pooling_backward::primitive_desc &GetPd();
10290
};
@@ -141,5 +129,5 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
141129
const NDArray &output);
142130
} // namespace op
143131
} // namespace mxnet
144-
#endif // MXNET_USE_MKLDNN == 1
132+
#endif // MXNET_USE_MKLDNN == 100
145133
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_

src/operator/nn/mkldnn/mkldnn_pooling.cc

Lines changed: 71 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
* \author Tao Lv
2424
*/
2525

26-
#if MXNET_USE_MKLDNN == 1
26+
#if MXNET_USE_MKLDNN == 100
2727

2828
#include "./mkldnn_pooling-inl.h"
2929

@@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
3434
const int kernel_h, const int kernel_w,
3535
const int stride_h, const int stride_w,
3636
const int padding_t, const int padding_b,
37-
const int padding_l, const int padding_r) {
38-
// mkldnn::memory::desc
39-
auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc();
37+
const int padding_l, const int padding_r,
38+
const bool is_train, const mkldnn::algorithm alg_kind) {
39+
auto src_md = input.GetMKLDNNData()->get_desc();
4040
mkldnn::memory::dims dims = {src_md.data.dims[0],
4141
src_md.data.dims[1],
4242
static_cast<int>(output.shape()[2]),
4343
static_cast<int>(output.shape()[3])};
4444
auto dst_md = mkldnn::memory::desc({dims},
4545
static_cast<mkldnn::memory::data_type>(src_md.data.data_type),
46-
static_cast<mkldnn::memory::format>(src_md.data.format));
46+
mkldnn::memory::format_tag::any);
4747
const mkldnn::engine engine = CpuEngine::Get()->get_engine();
48-
const mkldnn::algorithm alg_kind = this->alg_kind_;
4948
if (alg_kind != mkldnn::algorithm::pooling_max &&
5049
alg_kind != mkldnn::algorithm::pooling_avg &&
5150
alg_kind != mkldnn::algorithm::pooling_avg_include_padding &&
@@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
5453
}
5554

5655
mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring;
57-
if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) {
56+
if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) {
5857
prop = mkldnn::prop_kind::forward_training;
5958
}
60-
if (this->is_train_ && prop == mkldnn::prop_kind::forward_scoring) {
59+
if (is_train && prop == mkldnn::prop_kind::forward_scoring) {
6160
LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring";
6261
}
6362

@@ -67,49 +66,38 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
6766
const mkldnn::memory::dims kernel = {kernel_h, kernel_w };
6867
// mkldnn::pooling_forward::desc
6968
const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md,
70-
strides, kernel, pad_l, pad_r,
71-
mkldnn::padding_kind::zero);
69+
strides, kernel, pad_l, pad_r);
7270
this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine));
73-
this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc()));
74-
this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc()));
75-
if (this->with_workspace_) {
76-
this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc()));
77-
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_),
78-
mkldnn::primitive::at(*(this->data_)),
79-
*(this->out_),
80-
*(this->workspace_)));
81-
} else {
82-
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_),
83-
mkldnn::primitive::at(*(this->data_)),
84-
*(this->out_)));
85-
}
71+
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_)));
72+
8673
return;
8774
}
8875

89-
void MKLDNNPoolingFwd::SetNewMem(const NDArray& in_data,
90-
const NDArray& out_data,
91-
const OpReqType& req,
92-
const mxnet::NDArray *workspace) {
93-
auto input_mem = in_data.GetMKLDNNData();
94-
output_mem_t_ = CreateMKLDNNMem(out_data, fwd_pd_->dst_primitive_desc(), req);
95-
// mkldnn::memory
96-
this->data_->set_data_handle(input_mem->get_data_handle());
97-
this->out_->set_data_handle(output_mem_t_.second->get_data_handle());
98-
if (this->with_workspace_ && workspace == nullptr) {
99-
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
100-
}
76+
void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
77+
const OpReqType req,
78+
const NDArray& out_data,
79+
const NDArray *workspace) {
80+
NDArray in_buffer = in_data;
81+
if (in_data.IsView() && in_data.IsMKLDNNData())
82+
in_buffer = in_data.Reorder2Default();
83+
84+
auto input_mem = in_buffer.GetMKLDNNData();
85+
auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req);
86+
87+
mkldnn_args_map_t args = {
88+
{MKLDNN_ARG_SRC, *input_mem },
89+
{MKLDNN_ARG_DST, *(output_mem_t_.second) },
90+
};
10191

10292
if (this->with_workspace_) {
103-
// mkldnn::memory
104-
auto ws_mem = workspace->GetMKLDNNData();
105-
this->workspace_->set_data_handle(ws_mem->get_data_handle());
93+
auto engine = CpuEngine::Get()->get_engine();
94+
auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
95+
engine, workspace->GetMKLDNNData()->get_data_handle());
96+
args[MKLDNN_ARG_WORKSPACE] = *ws;
10697
}
107-
}
108-
109-
void MKLDNNPoolingFwd::Execute(const NDArray& out_data) {
11098
if (this->fwd_) {
111-
MKLDNNStream::Get()->RegisterPrim(*(this->fwd_));
112-
CommitOutput(out_data, this->output_mem_t_);
99+
MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd_), args);
100+
CommitOutput(out_data, output_mem_t_);
113101
MKLDNNStream::Get()->Submit();
114102
} else {
115103
LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr";
@@ -143,8 +131,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
143131
}
144132

145133
mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
146-
const PoolingParam &param, const bool is_train, const memory::desc &data_md,
147-
const memory::desc &out_md) {
134+
const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
135+
const mkldnn::memory::desc &out_md) {
148136
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
149137
int kernel_h_, kernel_w_;
150138
if (param.global_pool) {
@@ -183,19 +171,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
183171

184172
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
185173
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
186-
if (is_train && alg != algorithm::pooling_avg) {
174+
if (is_train && alg != mkldnn::algorithm::pooling_avg) {
187175
kind = mkldnn::prop_kind::forward_training;
188176
}
189177

190-
const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
178+
const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
191179
{static_cast<int>(stride_h_),
192180
static_cast<int>(stride_w_)},
193181
{kernel_h_, kernel_w_},
194182
{static_cast<int>(pad_t_),
195183
static_cast<int>(pad_l_)},
196184
{static_cast<int>(pad_b_),
197-
static_cast<int>(pad_r_)},
198-
padding_kind::zero);
185+
static_cast<int>(pad_r_)});
199186
return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine);
200187
}
201188

@@ -223,7 +210,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
223210
auto it = pooling_fwds.find(key);
224211
if (it == pooling_fwds.end()) {
225212
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
226-
auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc();
213+
auto data_md = data.GetMKLDNNData()->get_desc();
227214
int kernel_h_, kernel_w_;
228215
if (param.global_pool) {
229216
kernel_h_ = data_md.data.dims[2];
@@ -270,42 +257,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
270257
const NDArray &in_data, const OpReqType req,
271258
const NDArray &out_data, const NDArray *workspace) {
272259
auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
273-
fwd.SetNewMem(in_data, out_data, req, workspace);
274-
fwd.Execute(out_data);
260+
fwd.Execute(in_data, req, out_data, workspace);
275261
}
276262

277263
MKLDNNPoolingBwd::MKLDNNPoolingBwd(
278-
const pooling_backward::primitive_desc &pdesc, bool with_ws)
279-
: with_workspace(with_ws), pd(pdesc) {}
280-
281-
void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace,
282-
const mxnet::NDArray &out_grad,
283-
const mkldnn::memory *diff_src_mem) {
284-
if (bwd == nullptr) {
285-
diff_dst.reset(
286-
new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(),
287-
out_grad.GetMKLDNNData()->get_data_handle()));
288-
diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(),
289-
diff_src_mem->get_data_handle()));
290-
if (with_workspace) {
291-
CHECK(workspace != nullptr);
292-
ws.reset(
293-
new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(),
294-
workspace->GetMKLDNNData()->get_data_handle()));
295-
bwd.reset(
296-
new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));
297-
} else {
298-
bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src));
299-
}
300-
} else {
301-
diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle());
302-
diff_src->set_data_handle(diff_src_mem->get_data_handle());
303-
if (with_workspace) {
304-
CHECK(workspace != nullptr);
305-
ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle());
264+
const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws)
265+
: with_workspace(with_ws), pd(pdesc) {
266+
bwd = std::make_shared<mkldnn::pooling_backward>(pd);
306267
}
307-
}
308-
}
309268

310269
const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() {
311270
return *this->bwd;
@@ -333,27 +292,29 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
333292

334293
auto it = pooling_bwds.find(key);
335294
if (it == pooling_bwds.end()) {
336-
auto diff_dst_mem = out_grad.GetMKLDNNData();
295+
NDArray diff_dst_buff = out_grad;
296+
if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) {
297+
diff_dst_buff = out_grad.Reorder2Default();
298+
}
299+
auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
337300
auto input_mem = in_data.GetMKLDNNData();
338-
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
339-
const mkldnn::memory::desc data_md = data_mpd.desc();
340-
const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
301+
const mkldnn::memory::desc data_md = input_mem->get_desc();
302+
const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
341303
static_cast<int>(out_grad.shape()[2]),
342304
static_cast<int>(out_grad.shape()[3])};
343-
const memory::desc out_md(
344-
{dims}, static_cast<memory::data_type>(data_md.data.data_type),
345-
static_cast<memory::format>(data_md.data.format));
305+
const mkldnn::memory::desc out_md(
306+
{dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
307+
mkldnn::memory::format_tag::any);
346308
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
347-
348309
const mkldnn::memory::desc diff_md =
349-
diff_dst_mem->get_primitive_desc().desc();
350-
const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
310+
diff_dst_mem->get_desc();
311+
const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
351312
static_cast<int>(in_grad.shape()[2]),
352313
static_cast<int>(in_grad.shape()[3])};
353-
const memory::desc diff_in_md(
354-
{dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
355-
static_cast<memory::format>(diff_md.data.format));
356-
const mkldnn::engine cpu_engine = data_mpd.get_engine();
314+
const mkldnn::memory::desc diff_in_md(
315+
{dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
316+
mkldnn::memory::format_tag::any);
317+
const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();;
357318
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
358319

359320
int kernel_h_, kernel_w_;
@@ -379,11 +340,10 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
379340
stride_h_ = stride_w_ = 1;
380341
}
381342

382-
const pooling_backward::desc desc(
343+
const mkldnn::pooling_backward::desc desc(
383344
alg, diff_in_md, diff_md, {stride_h_, stride_w_},
384-
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
385-
mkldnn::padding_kind::zero);
386-
const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
345+
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
346+
const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
387347
MKLDNNPoolingBwd bwd(pdesc, with_workspace);
388348
it = AddToCache(&pooling_bwds, key, bwd);
389349
}
@@ -401,14 +361,21 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
401361

402362
auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
403363
auto diff_src_mem =
404-
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
364+
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
365+
366+
mkldnn_args_map_t args = {
367+
{MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
368+
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
369+
};
370+
if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
371+
args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
372+
}
405373

406-
bwd.SetNewMem(workspace, out_grad, diff_src_mem.second);
407-
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
374+
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args);
408375
CommitOutput(in_grad, diff_src_mem);
409376
MKLDNNStream::Get()->Submit();
410377
}
411378

412379
} // namespace op
413380
} // namespace mxnet
414-
#endif // MXNET_USE_MKLDNN == 1
381+
#endif // MXNET_USE_MKLDNN == 100

0 commit comments

Comments (0)