
Commit 8d6ac4a

TaoLv authored and pengzhao-intel committed

Support 3D input for MKL-DNN softmax operator (#14818)

* add 3d softmax
* fix
* handle req type
* clean code
* remove check
* check axis
* retrigger ci

1 parent: d87bd2a

File tree

3 files changed: 44 additions, 21 deletions

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 1 addition & 1 deletion

@@ -180,7 +180,7 @@ bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
-bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
+bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
 }  // namespace op
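The widened signature lets the support check inspect the actual tensors, not just the operator parameters. As a rough illustration of why that matters (stand-in types only, not the real MXNet declarations; the actual implementation appears in mkldnn_softmax.cc below), the predicate can now gate on dtype and rank:

```cpp
// Illustrative stand-ins, not MXNet types: shows the kind of checks the
// extra NDArray arguments make possible before the MKL-DNN path is taken.
struct ParamStub { bool has_temperature; };
struct ArrayStub { int ndim; int dtype; };

constexpr int kFloat32 = 0;  // stand-in for mshadow::kFloat32

bool SupportStub(const ParamStub &param, const ArrayStub &in, const ArrayStub &out) {
  if (param.has_temperature) return false;  // primitive has no temperature arg
  if (in.dtype != kFloat32) return false;   // fp32 inputs only
  if (in.dtype != out.dtype) return false;  // no implicit dtype conversion
  return in.ndim >= 1 && in.ndim <= 4;      // 1D-4D, now including 3D
}
```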

src/operator/nn/mkldnn/mkldnn_softmax.cc

Lines changed: 42 additions & 19 deletions

@@ -26,43 +26,66 @@
 #include "../softmax-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
-#include "../../tensor/broadcast_reduce_op.h"
 
 #if MXNET_USE_MKLDNN == 1
 namespace mxnet {
 namespace op {
 
-bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
+bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
+                          const NDArray &data,
+                          const NDArray &output) {
+  const int ndim = data.shape().ndim();
+  const int in_dtype = data.dtype();
+  const int out_dtype = output.dtype();
+
+  const int axis = CheckAxis(param.axis, ndim);
   // MKLDNN does not support temperature argument in their softmax function
   // now. Need update this once they start to support it.
-  if (param.temperature.has_value()) {
+  // Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
+  if (param.temperature.has_value() ||
+      in_dtype != mshadow::kFloat32 ||
+      in_dtype != out_dtype ||
+      axis != (ndim - 1)) {
     return false;
   }
-  return true;
+  // only supports ndim = 1, 2, 3, 4 for now
+  return (ndim >= 1 && ndim <= 4);
+}
+
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
+                                                               const bool is_train,
+                                                               const mkldnn::memory &input) {
+  auto data_md = input.get_primitive_desc().desc();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
+  auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine());
+  return pd;
 }
 
-void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                          const NDArray &in_data, const OpReqType &req,
+void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &in_data,
+                          const OpReqType &req,
                           const NDArray &out_data) {
+  if (req == kNullOp) return;
+  // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
+  CHECK_NE(req, kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  auto input_mem = in_data.GetMKLDNNData();
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  int axis = CheckAxis(param.axis, in_data.shape().ndim());
+  const int axis = CheckAxis(param.axis, in_data.shape().ndim());
 
-  auto cpu_engine = data_mpd.get_engine();
-  auto prop = ctx.is_train
-    ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
-  mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
-      data_md, axis);
-  mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);
+  NDArray data = in_data;
+  if (in_data.IsView() && in_data.IsMKLDNNData()) {
+    data = in_data.Reorder2Default();
+  }
 
-  auto output_memory = out_data.GetMKLDNNData();
+  auto data_mem = data.GetMKLDNNData();
+  auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
+  auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
   MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::softmax_forward(pdesc, *input_mem, *output_memory));
+  stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
+  CommitOutput(out_data, out_mem);
   stream->Submit();
 }
-
 }  // namespace op
 }  // namespace mxnet
 #endif
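The new `axis != (ndim - 1)` gate means the MKL-DNN path only handles softmax over the last dimension, where the library performs well. Because `CheckAxis` normalizes negative axes, the default `axis = -1` still qualifies. A minimal, runnable sketch of that normalization (`NormalizeAxis` is a hypothetical stand-in for MXNet's `CheckAxis`, not the real helper):

```cpp
#include <cassert>

// Hypothetical stand-in for MXNet's CheckAxis: a negative axis counts back
// from the last dimension, so for the newly supported 3D case, axis == -1
// and axis == 2 both pass the `axis == ndim - 1` gate above.
int NormalizeAxis(int axis, int ndim) {
  assert(axis >= -ndim && axis < ndim);  // reject out-of-range axes
  return axis < 0 ? axis + ndim : axis;
}

int main() {
  const int ndim = 3;                           // 3D input, per this commit
  assert(NormalizeAxis(-1, ndim) == ndim - 1);  // default axis: last dim, OK
  assert(NormalizeAxis(1, ndim) != ndim - 1);   // inner axis: takes fallback
  return 0;
}
```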

src/operator/nn/softmax.cc

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& outputs) {
   // It seems MKLDNN softmax doesn't support training.
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) {
+  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;