@@ -26,43 +26,66 @@
 #include "../softmax-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
-#include "../../tensor/broadcast_reduce_op.h"
 
 #if MXNET_USE_MKLDNN == 1
 namespace mxnet {
 namespace op {
 
-bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
+bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
+                          const NDArray &data,
+                          const NDArray &output) {
+  const int ndim = data.shape().ndim();
+  const int in_dtype = data.dtype();
+  const int out_dtype = output.dtype();
+
+  const int axis = CheckAxis(param.axis, ndim);
   // MKLDNN does not support the temperature argument in its softmax function
   // yet; update this once it is supported.
-  if (param.temperature.has_value()) {
+  // Currently, MKLDNN shows poor performance when softmax is not performed on the last dimension.
+  if (param.temperature.has_value() ||
+      in_dtype != mshadow::kFloat32 ||
+      in_dtype != out_dtype ||
+      axis != (ndim - 1)) {
     return false;
   }
-  return true;
+  // only ndim = 1, 2, 3, 4 is supported for now
+  return (ndim >= 1 && ndim <= 4);
+}
+
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
+                                                               const bool is_train,
+                                                               const mkldnn::memory &input) {
+  auto data_md = input.get_primitive_desc().desc();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
+  auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine());
+  return pd;
 }
 
-void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                          const NDArray &in_data, const OpReqType &req,
+void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &in_data,
+                          const OpReqType &req,
                           const NDArray &out_data) {
+  if (req == kNullOp) return;
+  // Same as the FCompute path: softmax only supports kWriteTo and kWriteInplace for now.
+  CHECK_NE(req, kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  auto input_mem = in_data.GetMKLDNNData();
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  int axis = CheckAxis(param.axis, in_data.shape().ndim());
+  const int axis = CheckAxis(param.axis, in_data.shape().ndim());
 
-  auto cpu_engine = data_mpd.get_engine();
-  auto prop = ctx.is_train
-    ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
-  mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
-      data_md, axis);
-  mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);
+  NDArray data = in_data;
+  if (in_data.IsView() && in_data.IsMKLDNNData()) {
+    data = in_data.Reorder2Default();
+  }
 
-  auto output_memory = out_data.GetMKLDNNData();
+  auto data_mem = data.GetMKLDNNData();
+  auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
+  auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
   MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::softmax_forward(pdesc, *input_mem, *output_memory));
+  stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
+  CommitOutput(out_data, out_mem);
   stream->Submit();
 }
-
 } // namespace op
 } // namespace mxnet
 #endif
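
For context, the sketch below shows how the new three-argument SupportMKLDNNSoftmax might be consumed at dispatch time. It is not part of this diff: the wrapper name SoftmaxComputeExCPU, and the pairing of FallBackCompute with the plain SoftmaxCompute CPU kernel, are assumptions modeled on how other MKLDNN operators in the codebase are wired up.

// Hypothetical dispatch wrapper (not part of this diff). If the MKLDNN path
// supports this parameter/dtype/axis combination, run the MKLDNN primitive;
// otherwise fall back to the default CPU kernel.
static void SoftmaxComputeExCPU(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
  const SoftmaxParam &param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
    MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
    return;
  }
  // FallBackCompute reorders MKLDNN-layout arrays to the default layout and
  // invokes the regular FCompute implementation.
  FallBackCompute(SoftmaxCompute<cpu, mxnet_op::softmax_fwd>, attrs, ctx,
                  inputs, req, outputs);
}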