Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4d9a53e

Browse files
ElaineBao and pengzhao-intel
authored and committed
[mkldnn-1.0] upgrade int8 concat to MKLDNN1.0 (#16466)
* [mkldnn-1.0] upgrade int8 concat to MKLDNN1.0
* fix lint
* use mkldnn_args_map_t
* update dict usage style
* retrigger CI
* retrigger CI again
* retrigger CI again 2
1 parent 43e35a9 commit 4d9a53e

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
* \brief
2424
*/
2525

26-
#if MXNET_USE_MKLDNN == 1
26+
#if MXNET_USE_MKLDNN == 100
2727
#include "../../nn/mkldnn/mkldnn_concat-inl.h"
2828
#include "../quantization_utils.h"
2929

@@ -60,7 +60,7 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
6060
out_data[quantized_concat_enum::kMin].data().dptr<float>()[0] = output_neg_min;
6161
out_data[quantized_concat_enum::kMax].data().dptr<float>()[0] = output_pos_max;
6262
auto out_scale = GetScale(out_data[quantized_concat_enum::kOut], output_neg_min, output_pos_max);
63-
std::vector<mkldnn::memory::primitive_desc> data_md;
63+
std::vector<mkldnn::memory::desc> data_md;
6464
std::vector<const mkldnn::memory*> data_mem;
6565
// new_data_mem is for auto-free new created mkldnn memory
6666
std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
@@ -71,36 +71,37 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
7171
CHECK(in_data[i].dtype() == out_dtype);
7272
auto mem = in_data[i].GetMKLDNNData();
7373
data_mem.push_back(mem);
74-
data_md.push_back(mem->get_primitive_desc());
74+
data_md.push_back(mem->get_desc());
7575
} else {
7676
auto mem = in_data[i].GetMKLDNNData();
77-
auto pd = mem->get_primitive_desc();
77+
auto mem_desc = mem->get_desc();
7878
if (in_data[i].dtype() != out_dtype) {
79-
auto mem_desc = pd.desc();
80-
mkldnn::memory::desc new_md(
81-
mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
82-
get_mkldnn_type(out_dtype), static_cast<mkldnn::memory::format>(mem_desc.data.format));
83-
pd = mkldnn::memory::primitive_desc(new_md, CpuEngine::Get()->get_engine());
79+
mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(get_mkldnn_type(out_dtype));
8480
}
85-
const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
81+
const auto rescaled_mem =
82+
std::make_shared<mkldnn::memory>(mem_desc, CpuEngine::Get()->get_engine());
8683
new_data_mem.push_back(rescaled_mem);
8784
std::vector<float> reorder_scale = {out_scale / i_scale};
88-
primitive_attr reorder_attr;
89-
reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
85+
mkldnn::primitive_attr reorder_attr;
9086
reorder_attr.set_output_scales(0, reorder_scale);
91-
const auto reorder_pd =
92-
mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, reorder_attr);
93-
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem));
87+
const auto reorder_pd = mkldnn::reorder::primitive_desc(*mem, *rescaled_mem, reorder_attr);
88+
mkldnn_args_map_t reorder_args;
89+
reorder_args[MKLDNN_ARG_SRC] = *mem;
90+
reorder_args[MKLDNN_ARG_DST] = *rescaled_mem;
91+
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), reorder_args);
9492
data_mem.push_back(rescaled_mem.get());
95-
data_md.push_back(pd);
93+
data_md.push_back(mem_desc);
9694
}
9795
}
9896
MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md);
99-
mxnet::mkldnn_output_t out_mem =
100-
CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_primitive_desc(),
101-
req[concat_enum::kOut]);
102-
fwd.SetNewMem(data_mem, *out_mem.second);
103-
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
97+
mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[quantized_concat_enum::kOut],
98+
fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
99+
mkldnn_args_map_t net_args;
100+
net_args[MKLDNN_ARG_DST] = *out_mem.second;
101+
for (int i = 0; i < param_.num_args; i++) {
102+
net_args[MKLDNN_ARG_MULTIPLE_SRC + i] = *data_mem[i];
103+
}
104+
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
104105
CommitOutput(out_data[concat_enum::kOut], out_mem);
105106
MKLDNNStream::Get()->Submit();
106107
}
@@ -126,4 +127,4 @@ NNVM_REGISTER_OP(_contrib_quantized_concat)
126127
} // namespace op
127128
} // namespace mxnet
128129

129-
#endif // MXNET_USE_MKLDNN == 1
130+
#endif // MXNET_USE_MKLDNN == 100

0 commit comments

Comments
 (0)