2323 * \author Tao Lv
2424*/
2525
26- #if MXNET_USE_MKLDNN == 1
26+ #if MXNET_USE_MKLDNN == 100
2727
2828#include " ./mkldnn_pooling-inl.h"
2929
@@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
3434 const int kernel_h, const int kernel_w,
3535 const int stride_h, const int stride_w,
3636 const int padding_t , const int padding_b,
37- const int padding_l, const int padding_r) {
38- // mkldnn::memory::desc
39- auto src_md = input.GetMKLDNNData ()->get_primitive_desc (). desc ();
37+ const int padding_l, const int padding_r,
38+ const bool is_train, const mkldnn::algorithm alg_kind) {
39+ auto src_md = input.GetMKLDNNData ()->get_desc ();
4040 mkldnn::memory::dims dims = {src_md.data .dims [0 ],
4141 src_md.data .dims [1 ],
4242 static_cast <int >(output.shape ()[2 ]),
4343 static_cast <int >(output.shape ()[3 ])};
4444 auto dst_md = mkldnn::memory::desc ({dims},
4545 static_cast <mkldnn::memory::data_type>(src_md.data .data_type ),
46- static_cast < mkldnn::memory::format>(src_md. data . format ) );
46+ mkldnn::memory::format_tag::any );
4747 const mkldnn::engine engine = CpuEngine::Get ()->get_engine ();
48- const mkldnn::algorithm alg_kind = this ->alg_kind_ ;
4948 if (alg_kind != mkldnn::algorithm::pooling_max &&
5049 alg_kind != mkldnn::algorithm::pooling_avg &&
5150 alg_kind != mkldnn::algorithm::pooling_avg_include_padding &&
@@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
5453 }
5554
5655 mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring;
57- if (this -> is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) {
56+ if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) {
5857 prop = mkldnn::prop_kind::forward_training;
5958 }
60- if (this -> is_train_ && prop == mkldnn::prop_kind::forward_scoring) {
59+ if (is_train && prop == mkldnn::prop_kind::forward_scoring) {
6160 LOG (INFO) << " MKLDNN Pooling: training with prop_kind is forward_scoring" ;
6261 }
6362
@@ -67,49 +66,38 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
6766 const mkldnn::memory::dims kernel = {kernel_h, kernel_w };
6867 // mkldnn::pooling_forward::desc
6968 const auto fwd_desc = mkldnn::pooling_forward::desc (prop, alg_kind, src_md, dst_md,
70- strides, kernel, pad_l, pad_r,
71- mkldnn::padding_kind::zero);
69+ strides, kernel, pad_l, pad_r);
7270 this ->fwd_pd_ .reset (new mkldnn::pooling_forward::primitive_desc (fwd_desc, engine));
73- this ->data_ .reset (new mkldnn::memory (input.GetMKLDNNData ()->get_primitive_desc ()));
74- this ->out_ .reset (new mkldnn::memory (this ->fwd_pd_ ->dst_primitive_desc ()));
75- if (this ->with_workspace_ ) {
76- this ->workspace_ .reset (new mkldnn::memory (this ->fwd_pd_ ->workspace_primitive_desc ()));
77- this ->fwd_ .reset (new mkldnn::pooling_forward (*(this ->fwd_pd_ ),
78- mkldnn::primitive::at (*(this ->data_ )),
79- *(this ->out_ ),
80- *(this ->workspace_ )));
81- } else {
82- this ->fwd_ .reset (new mkldnn::pooling_forward (*(this ->fwd_pd_ ),
83- mkldnn::primitive::at (*(this ->data_ )),
84- *(this ->out_ )));
85- }
71+ this ->fwd_ .reset (new mkldnn::pooling_forward (*(this ->fwd_pd_ )));
72+
8673 return ;
8774}
8875
89- void MKLDNNPoolingFwd::SetNewMem (const NDArray& in_data,
90- const NDArray& out_data,
91- const OpReqType& req,
92- const mxnet::NDArray *workspace) {
93- auto input_mem = in_data.GetMKLDNNData ();
94- output_mem_t_ = CreateMKLDNNMem (out_data, fwd_pd_->dst_primitive_desc (), req);
95- // mkldnn::memory
96- this ->data_ ->set_data_handle (input_mem->get_data_handle ());
97- this ->out_ ->set_data_handle (output_mem_t_.second ->get_data_handle ());
98- if (this ->with_workspace_ && workspace == nullptr ) {
99- LOG (FATAL) << " MKLDNN Pooling: incorrect workspace input" ;
100- }
76+ void MKLDNNPoolingFwd::Execute (const NDArray &in_data,
77+ const OpReqType req,
78+ const NDArray& out_data,
79+ const NDArray *workspace) {
80+ NDArray in_buffer = in_data;
81+ if (in_data.IsView () && in_data.IsMKLDNNData ())
82+ in_buffer = in_data.Reorder2Default ();
83+
84+ auto input_mem = in_buffer.GetMKLDNNData ();
85+ auto output_mem_t_ = CreateMKLDNNMem (out_data, this ->fwd_pd_ ->dst_desc (), req);
86+
87+ mkldnn_args_map_t args = {
88+ {MKLDNN_ARG_SRC, *input_mem },
89+ {MKLDNN_ARG_DST, *(output_mem_t_.second ) },
90+ };
10191
10292 if (this ->with_workspace_ ) {
103- // mkldnn::memory
104- auto ws_mem = workspace->GetMKLDNNData ();
105- this ->workspace_ ->set_data_handle (ws_mem->get_data_handle ());
93+ auto engine = CpuEngine::Get ()->get_engine ();
94+ auto ws = std::make_shared<mkldnn::memory>((*(this ->fwd_pd_ )).workspace_desc (),
95+ engine, workspace->GetMKLDNNData ()->get_data_handle ());
96+ args[MKLDNN_ARG_WORKSPACE] = *ws;
10697 }
107- }
108-
109- void MKLDNNPoolingFwd::Execute (const NDArray& out_data) {
11098 if (this ->fwd_ ) {
111- MKLDNNStream::Get ()->RegisterPrim (*(this ->fwd_ ));
112- CommitOutput (out_data, this -> output_mem_t_ );
99+ MKLDNNStream::Get ()->RegisterPrimArgs (*(this ->fwd_ ), args );
100+ CommitOutput (out_data, output_mem_t_);
113101 MKLDNNStream::Get ()->Submit ();
114102 } else {
115103 LOG (FATAL) << " MKLDNN Pooling: forward primitive is nullptr" ;
@@ -143,8 +131,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
143131}
144132
145133mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc (
146- const PoolingParam ¶m, const bool is_train, const memory::desc &data_md,
147- const memory::desc &out_md) {
134+ const PoolingParam ¶m, const bool is_train, const mkldnn:: memory::desc &data_md,
135+ const mkldnn:: memory::desc &out_md) {
148136 CHECK_EQ (param.kernel .ndim (), 2 ) << " Not Implemented" ;
149137 int kernel_h_, kernel_w_;
150138 if (param.global_pool ) {
@@ -183,19 +171,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
183171
184172 const mkldnn::algorithm alg = GetMKLDNNPoolAlgo (param);
185173 mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
186- if (is_train && alg != algorithm::pooling_avg) {
174+ if (is_train && alg != mkldnn:: algorithm::pooling_avg) {
187175 kind = mkldnn::prop_kind::forward_training;
188176 }
189177
190- const pooling_forward::desc poolingFwd_desc (kind, alg, data_md, out_md,
178+ const mkldnn:: pooling_forward::desc poolingFwd_desc (kind, alg, data_md, out_md,
191179 {static_cast <int >(stride_h_),
192180 static_cast <int >(stride_w_)},
193181 {kernel_h_, kernel_w_},
194182 {static_cast <int >(pad_t_),
195183 static_cast <int >(pad_l_)},
196184 {static_cast <int >(pad_b_),
197- static_cast <int >(pad_r_)},
198- padding_kind::zero);
185+ static_cast <int >(pad_r_)});
199186 return mkldnn::pooling_forward::primitive_desc (poolingFwd_desc, engine);
200187}
201188
@@ -223,7 +210,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m,
223210 auto it = pooling_fwds.find (key);
224211 if (it == pooling_fwds.end ()) {
225212 CHECK_EQ (param.kernel .ndim (), 2 ) << " Not Implemented" ;
226- auto data_md = data.GetMKLDNNData ()->get_primitive_desc (). desc ();
213+ auto data_md = data.GetMKLDNNData ()->get_desc ();
227214 int kernel_h_, kernel_w_;
228215 if (param.global_pool ) {
229216 kernel_h_ = data_md.data .dims [2 ];
@@ -270,42 +257,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m,
270257 const NDArray &in_data, const OpReqType req,
271258 const NDArray &out_data, const NDArray *workspace) {
272259 auto &fwd = GetPoolingFwd (param, ctx.is_train , in_data, out_data);
273- fwd.SetNewMem (in_data, out_data, req, workspace);
274- fwd.Execute (out_data);
260+ fwd.Execute (in_data, req, out_data, workspace);
275261}
276262
277263MKLDNNPoolingBwd::MKLDNNPoolingBwd (
278- const pooling_backward::primitive_desc &pdesc, bool with_ws)
279- : with_workspace(with_ws), pd(pdesc) {}
280-
281- void MKLDNNPoolingBwd::SetNewMem (const mxnet::NDArray *workspace,
282- const mxnet::NDArray &out_grad,
283- const mkldnn::memory *diff_src_mem) {
284- if (bwd == nullptr ) {
285- diff_dst.reset (
286- new mkldnn::memory (out_grad.GetMKLDNNData ()->get_primitive_desc (),
287- out_grad.GetMKLDNNData ()->get_data_handle ()));
288- diff_src.reset (new mkldnn::memory (pd.diff_src_primitive_desc (),
289- diff_src_mem->get_data_handle ()));
290- if (with_workspace) {
291- CHECK (workspace != nullptr );
292- ws.reset (
293- new mkldnn::memory (workspace->GetMKLDNNData ()->get_primitive_desc (),
294- workspace->GetMKLDNNData ()->get_data_handle ()));
295- bwd.reset (
296- new pooling_backward (pd, *diff_dst, primitive::at (*ws), *diff_src));
297- } else {
298- bwd.reset (new pooling_backward (pd, *diff_dst, *diff_src));
299- }
300- } else {
301- diff_dst->set_data_handle (out_grad.GetMKLDNNData ()->get_data_handle ());
302- diff_src->set_data_handle (diff_src_mem->get_data_handle ());
303- if (with_workspace) {
304- CHECK (workspace != nullptr );
305- ws->set_data_handle (workspace->GetMKLDNNData ()->get_data_handle ());
264+ const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws)
265+ : with_workspace(with_ws), pd(pdesc) {
266+ bwd = std::make_shared<mkldnn::pooling_backward>(pd);
306267 }
307- }
308- }
309268
310269const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd () {
311270 return *this ->bwd ;
@@ -333,27 +292,29 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m,
333292
334293 auto it = pooling_bwds.find (key);
335294 if (it == pooling_bwds.end ()) {
336- auto diff_dst_mem = out_grad.GetMKLDNNData ();
295+ NDArray diff_dst_buff = out_grad;
296+ if (in_data.IsMKLDNNData () == false && diff_dst_buff.IsMKLDNNData () == true ) {
297+ diff_dst_buff = out_grad.Reorder2Default ();
298+ }
299+ auto diff_dst_mem = diff_dst_buff.GetMKLDNNData ();
337300 auto input_mem = in_data.GetMKLDNNData ();
338- mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc ();
339- const mkldnn::memory::desc data_md = data_mpd.desc ();
340- const memory::dims dims = {data_md.data .dims [0 ], data_md.data .dims [1 ],
301+ const mkldnn::memory::desc data_md = input_mem->get_desc ();
302+ const mkldnn::memory::dims dims = {data_md.data .dims [0 ], data_md.data .dims [1 ],
341303 static_cast <int >(out_grad.shape ()[2 ]),
342304 static_cast <int >(out_grad.shape ()[3 ])};
343- const memory::desc out_md (
344- {dims}, static_cast <memory::data_type>(data_md.data .data_type ),
345- static_cast < memory::format>(data_md. data . format ) );
305+ const mkldnn:: memory::desc out_md (
306+ {dims}, static_cast <mkldnn:: memory::data_type>(data_md.data .data_type ),
307+ mkldnn:: memory::format_tag::any );
346308 auto fwd_pd = GetPoolingFwdPdesc (param, true , data_md, out_md);
347-
348309 const mkldnn::memory::desc diff_md =
349- diff_dst_mem->get_primitive_desc (). desc ();
350- const memory::dims dims1 = {diff_md.data .dims [0 ], diff_md.data .dims [1 ],
310+ diff_dst_mem->get_desc ();
311+ const mkldnn:: memory::dims dims1 = {diff_md.data .dims [0 ], diff_md.data .dims [1 ],
351312 static_cast <int >(in_grad.shape ()[2 ]),
352313 static_cast <int >(in_grad.shape ()[3 ])};
353- const memory::desc diff_in_md (
354- {dims1}, static_cast <memory::data_type>(diff_md.data .data_type ),
355- static_cast < memory::format>(diff_md. data . format ) );
356- const mkldnn::engine cpu_engine = data_mpd. get_engine ();
314+ const mkldnn:: memory::desc diff_in_md (
315+ {dims1}, static_cast <mkldnn:: memory::data_type>(diff_md.data .data_type ),
316+ mkldnn:: memory::format_tag::any );
317+ const mkldnn::engine cpu_engine = CpuEngine::Get ()-> get_engine (); ;
357318 const mkldnn::algorithm alg = GetMKLDNNPoolAlgo (param);
358319
359320 int kernel_h_, kernel_w_;
@@ -379,11 +340,10 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m,
379340 stride_h_ = stride_w_ = 1 ;
380341 }
381342
382- const pooling_backward::desc desc (
343+ const mkldnn:: pooling_backward::desc desc (
383344 alg, diff_in_md, diff_md, {stride_h_, stride_w_},
384- {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
385- mkldnn::padding_kind::zero);
386- const auto pdesc = pooling_backward::primitive_desc (desc, cpu_engine, fwd_pd);
345+ {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
346+ const auto pdesc = mkldnn::pooling_backward::primitive_desc (desc, cpu_engine, fwd_pd);
387347 MKLDNNPoolingBwd bwd (pdesc, with_workspace);
388348 it = AddToCache (&pooling_bwds, key, bwd);
389349 }
@@ -401,14 +361,21 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m,
401361
402362 auto &bwd = GetPoolingBwd (param, in_data, in_grad, out_grad);
403363 auto diff_src_mem =
404- CreateMKLDNNMem (in_grad, bwd.pd .diff_src_primitive_desc (), req);
364+ CreateMKLDNNMem (in_grad, bwd.pd .diff_src_desc (), req);
365+
366+ mkldnn_args_map_t args = {
367+ {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData ())},
368+ {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
369+ };
370+ if (MKLDNNRequireWorkspace (param) && workspace != nullptr ) {
371+ args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData ());
372+ }
405373
406- bwd.SetNewMem (workspace, out_grad, diff_src_mem.second );
407- MKLDNNStream::Get ()->RegisterPrim (bwd.GetBwd ());
374+ MKLDNNStream::Get ()->RegisterPrimArgs (bwd.GetBwd (), args);
408375 CommitOutput (in_grad, diff_src_mem);
409376 MKLDNNStream::Get ()->Submit ();
410377}
411378
412379} // namespace op
413380} // namespace mxnet
414- #endif // MXNET_USE_MKLDNN == 1
381+ #endif // MXNET_USE_MKLDNN == 100
0 commit comments