2323 * \brief
2424 */
2525
26- #if MXNET_USE_MKLDNN == 1
26+ #if MXNET_USE_MKLDNN == 100
2727#include " ../../nn/mkldnn/mkldnn_concat-inl.h"
2828#include " ../quantization_utils.h"
2929
@@ -60,7 +60,7 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
6060 out_data[quantized_concat_enum::kMin ].data ().dptr <float >()[0 ] = output_neg_min;
6161 out_data[quantized_concat_enum::kMax ].data ().dptr <float >()[0 ] = output_pos_max;
6262 auto out_scale = GetScale (out_data[quantized_concat_enum::kOut ], output_neg_min, output_pos_max);
63- std::vector<mkldnn::memory::primitive_desc > data_md;
63+ std::vector<mkldnn::memory::desc > data_md;
6464 std::vector<const mkldnn::memory*> data_mem;
6565 // new_data_mem owns the newly created mkldnn memory so that it is freed automatically
6666 std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
@@ -71,36 +71,37 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
7171 CHECK (in_data[i].dtype () == out_dtype);
7272 auto mem = in_data[i].GetMKLDNNData ();
7373 data_mem.push_back (mem);
74- data_md.push_back (mem->get_primitive_desc ());
74+ data_md.push_back (mem->get_desc ());
7575 } else {
7676 auto mem = in_data[i].GetMKLDNNData ();
77- auto pd = mem->get_primitive_desc ();
77+ auto mem_desc = mem->get_desc ();
7878 if (in_data[i].dtype () != out_dtype) {
79- auto mem_desc = pd.desc ();
80- mkldnn::memory::desc new_md (
81- mkldnn::memory::dims (mem_desc.data .dims , mem_desc.data .dims + mem_desc.data .ndims ),
82- get_mkldnn_type (out_dtype), static_cast <mkldnn::memory::format>(mem_desc.data .format ));
83- pd = mkldnn::memory::primitive_desc (new_md, CpuEngine::Get ()->get_engine ());
79+ mem_desc.data .data_type = static_cast <mkldnn_data_type_t >(get_mkldnn_type (out_dtype));
8480 }
85- const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
81+ const auto rescaled_mem =
82+ std::make_shared<mkldnn::memory>(mem_desc, CpuEngine::Get ()->get_engine ());
8683 new_data_mem.push_back (rescaled_mem);
8784 std::vector<float > reorder_scale = {out_scale / i_scale};
88- primitive_attr reorder_attr;
89- reorder_attr.set_int_output_round_mode (round_mode::round_nearest);
85+ mkldnn::primitive_attr reorder_attr;
9086 reorder_attr.set_output_scales (0 , reorder_scale);
91- const auto reorder_pd =
92- mkldnn::reorder::primitive_desc (mem->get_primitive_desc (), pd, reorder_attr);
93- MKLDNNStream::Get ()->RegisterPrim (mkldnn::reorder (reorder_pd, *mem, *rescaled_mem));
87+ const auto reorder_pd = mkldnn::reorder::primitive_desc (*mem, *rescaled_mem, reorder_attr);
88+ mkldnn_args_map_t reorder_args;
89+ reorder_args[MKLDNN_ARG_SRC] = *mem;
90+ reorder_args[MKLDNN_ARG_DST] = *rescaled_mem;
91+ MKLDNNStream::Get ()->RegisterPrimArgs (mkldnn::reorder (reorder_pd), reorder_args);
9492 data_mem.push_back (rescaled_mem.get ());
95- data_md.push_back (pd );
93+ data_md.push_back (mem_desc );
9694 }
9795 }
9896 MKLDNNConcatFwd& fwd = GetConcatForward (param_.dim , in_data, data_md);
99- mxnet::mkldnn_output_t out_mem =
100- CreateMKLDNNMem (out_data[quantized_concat_enum::kOut ], fwd.fwd_pd .dst_primitive_desc (),
101- req[concat_enum::kOut ]);
102- fwd.SetNewMem (data_mem, *out_mem.second );
103- MKLDNNStream::Get ()->RegisterPrim (fwd.GetFwd ());
97+ mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem (out_data[quantized_concat_enum::kOut ],
98+ fwd.fwd_pd .dst_desc (), req[concat_enum::kOut ]);
99+ mkldnn_args_map_t net_args;
100+ net_args[MKLDNN_ARG_DST] = *out_mem.second ;
101+ for (int i = 0 ; i < param_.num_args ; i++) {
102+ net_args[MKLDNN_ARG_MULTIPLE_SRC + i] = *data_mem[i];
103+ }
104+ MKLDNNStream::Get ()->RegisterPrimArgs (fwd.GetFwd (), net_args);
104105 CommitOutput (out_data[concat_enum::kOut ], out_mem);
105106 MKLDNNStream::Get ()->Submit ();
106107}
@@ -126,4 +127,4 @@ NNVM_REGISTER_OP(_contrib_quantized_concat)
126127} // namespace op
127128} // namespace mxnet
128129
129- #endif // MXNET_USE_MKLDNN == 1
130+ #endif // MXNET_USE_MKLDNN == 100
0 commit comments