This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f8eeab7

ciyongch authored and nswamy committed
MKLDNN based Quantized FullyConnected Operator and its fusion (#14128)
* add MKL-DNN quantized innerproduct
* initial qfc with mkldnn
* Add MKL-DNN quantized_fully_connected
* refactor params order for fullyconnected
* update quantized_fully_connected unittest, force data to uint8 type temporarily
* change mkl based quantized fully_connected to FCompute
* add data type check for mkldnn quantized_fc
* add fused requantize and dequantize for mkldnn quantized fullyconnected
* add env setting to enable/disable fused requantize/dequantize for quantized fullyconnected
* fix requantize scaling error
* add fallback when input data is int8
* fix mkl quantized fullyconnected index error
* update quantized fc test cases
* add subgraph node for mkldnn fullyconnected
* fix compiling and lint errors
* clean and refactor code
* enable quantized_fc for imagenet
* cleanup code
* fix StorageType error for non-mkldnn path
* fix pylint
* reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check
* rename variables and refactor code according to comments
* add subgraph qfc tests and fix shape error
* remove fuse_requantize and change fuse_dequantize to enable_float_output
* change to use mxnet::Tuple and update tests
* update description in file header
* update input0 type check for quantized FullyConnected
* fix conflict of mkl/test_subgraph.py
* retrigger CI
* retrigger CI due to hang
1 parent 1bb78eb commit f8eeab7

File tree: 12 files changed, +1679 −224 lines

example/quantization/imagenet_gen_qsym_mkldnn.py
Lines changed: 4 additions & 5 deletions

@@ -180,6 +180,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
     sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
 
     sym = sym.get_backend_symbol('MKLDNN')
+    sym = sym.get_backend_symbol('MKLDNN_FC')
 
     # get batch size
     batch_size = args.batch_size
@@ -207,19 +208,18 @@ def save_params(fname, arg_params, aux_params, logger=None):
     if args.model == 'imagenet1k-resnet-152':
         rgb_mean = '0,0,0'
         rgb_std = '1,1,1'
-        excluded_sym_names += ['flatten0', 'fc1']
+        excluded_sym_names += ['flatten0']
         if exclude_first_conv:
             excluded_sym_names += ['conv0']
     elif args.model == 'imagenet1k-inception-bn':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '1,1,1'
-        excluded_sym_names += ['flatten', 'fc1']
+        excluded_sym_names += ['flatten']
         if exclude_first_conv:
             excluded_sym_names += ['conv_1']
     elif args.model in ['resnet50_v1', 'resnet101_v1']:
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
-        excluded_sym_names += ['resnetv10_dense0_fwd']
         if exclude_first_conv:
             excluded_sym_names += ['resnetv10_conv0_fwd']
     elif args.model == 'squeezenet1.0':
@@ -232,14 +232,12 @@ def save_params(fname, arg_params, aux_params, logger=None):
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
         excluded_sym_names += ['mobilenet0_flatten0_flatten0',
-                               'mobilenet0_dense0_fwd',
                                'mobilenet0_pool0_fwd']
         if exclude_first_conv:
             excluded_sym_names += ['mobilenet0_conv0_fwd']
     elif args.model == 'inceptionv3':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
-        excluded_sym_names += ['inception30_dense0_fwd']
         if exclude_first_conv:
             excluded_sym_names += ['inception30_conv0_fwd']
     elif args.model == 'custom':
@@ -305,6 +303,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
             % calib_mode)
     sym_name = '%s-symbol.json' % (prefix + suffix)
     qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
+    qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
     save_symbol(sym_name, qsym, logger)
     param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
     save_params(param_name, qarg_params, aux_params, logger)
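Taken together, the two new passes bracket the quantization flow: 'MKLDNN_FC' fuses FullyConnected subgraphs before quantization, and 'MKLDNN_POST_FC_QUANTIZE' folds the trailing requantize/dequantize nodes into the fused FC afterwards. A minimal sketch of that flow, assuming an MKL-DNN build of MXNet (the checkpoint prefix below is hypothetical):

    import mxnet as mx

    # Sketch of the fusion flow this diff enables (checkpoint name is hypothetical).
    sym, arg_params, aux_params = mx.model.load_checkpoint('imagenet1k-resnet-152', 0)
    sym = sym.get_backend_symbol('MKLDNN')     # existing conv-oriented fusion pass
    sym = sym.get_backend_symbol('MKLDNN_FC')  # new: fuse FullyConnected subgraphs
    # ... quantize and calibrate the symbol here, as the surrounding script does ...
    # After quantization, fold requantize/dequantize into the fused FC nodes:
    # qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
    # qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')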

python/mxnet/initializer.py
Lines changed: 13 additions & 0 deletions

@@ -159,6 +159,12 @@ def __call__(self, desc, arr):
        elif desc.endswith('max'):
            self._init_one(desc, arr)
            self._verbose_print(desc, 'max', arr)
+       elif desc.endswith('weight_quantize'):
+           self._init_quantized_weight(desc, arr)
+           self._verbose_print(desc, 'weight_quantize', arr)
+       elif desc.endswith('bias_quantize'):
+           self._init_quantized_bias(desc, arr)
+           self._verbose_print(desc, 'bias_quantize', arr)
        else:
            self._init_default(desc, arr)
 
@@ -235,6 +241,9 @@ def _init_one(self, _, arr):
    def _init_bias(self, _, arr):
        arr[:] = 0.0
 
+   def _init_quantized_bias(self, _, arr):
+       arr[:] = 0
+
    def _init_gamma(self, _, arr):
        arr[:] = 1.0
 
@@ -245,6 +254,10 @@ def _init_weight(self, name, arr):
        """Abstract method to Initialize weight."""
        raise NotImplementedError("Must override it")
 
+   def _init_quantized_weight(self, _, arr):
+       _arr = random.randint(-127, 127, dtype='int32').asnumpy()
+       arr[:] = np.int8(_arr)
+
    def _init_default(self, name, _):
        raise ValueError(
            'Unknown initialization pattern for %s. ' \
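The new branches are keyed purely on the parameter-name suffix: any parameter ending in 'weight_quantize' or 'bias_quantize' gets dummy int8/zero contents (the real values come from calibration, so only a valid dtype matters). A small illustrative sketch, with a made-up parameter name:

    import mxnet as mx

    # Parameters named '*_weight_quantize' are routed to _init_quantized_weight,
    # which fills the array with random int8 data in [-127, 127);
    # '*_bias_quantize' is zero-filled by _init_quantized_bias.
    arr = mx.nd.zeros((4, 8), dtype='int8')
    desc = mx.init.InitDesc('fc0_weight_quantize')  # hypothetical parameter name
    mx.init.Uniform()(desc, arr)  # any Initializer subclass dispatches by suffix
    print(arr.asnumpy())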
src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h (new file)
Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fully_connected-inl.h
+ * \brief Common functions used by MKLDNN (Quantized) FullyConnected operator
+ * \author Ciyong Chen
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <vector>
+#include <string>
+#include "../fully_connected-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
+  bool quantized;
+  bool enable_float_output;
+  bool with_relu;
+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+
+  DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
+    DMLC_DECLARE_FIELD(quantized).set_default(false)
+    .describe("Whether it's a quantized FullyConnected operator");
+    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+    .describe("Whether to enable float32 output");
+    DMLC_DECLARE_FIELD(with_relu).set_default(false)
+    .describe("Whether there's a post relu after FullyConnected operator");
+    DMLC_DECLARE_FIELD(min_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The minimum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized fullyconnected op to calculate primitive scale");
+    DMLC_DECLARE_FIELD(max_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The maximum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized fullyconnected op to calculate primitive scale");
+  }
+};
+
+struct MKLDNNFCFullParam {
+  FullyConnectedParam default_param;
+  MKLDNNFCParam mkldnn_param;
+  std::vector<float> output_scales = {0.0};
+  std::vector<float> requantize_scales = {0.0};
+};
+
+mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
+    const MKLDNNFCFullParam &full_param, const bool is_train,
+    const NDArray &data, const NDArray &weight, const NDArray *bias,
+    const mkldnn::memory::desc &out_md);
+
+class MKLDNNFullyConnectedForward {
+ public:
+  mkldnn::inner_product_forward::primitive_desc fwd_pd;
+
+  MKLDNNFullyConnectedForward(const MKLDNNFCFullParam &full_param, const bool is_train,
+                              const NDArray &data, const NDArray &weight,
+                              const NDArray *bias,
+                              const mkldnn::memory::desc &out_md)
+      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output);
+
+  const mkldnn::inner_product_forward &GetFwd() const {
+    return *fwd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::inner_product_forward> fwd_;
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> weight_;
+  std::shared_ptr<mkldnn::memory> bias_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+MKLDNNFullyConnectedForward &GetFCFwd(
+    const FullyConnectedParam &param, const bool is_train,
+    const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &out_md);
+
+void MKLDNNFCFlattenData(const FullyConnectedParam &param,
+                         const NDArray &out_data,
+                         NDArray *in_data,
+                         mkldnn::memory::desc *out_md);
+
+void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                     const std::vector<NDArray> &in_data,
+                     const std::vector<OpReqType> &req,
+                     const std::vector<NDArray> &out_data);
+
+void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &param,
+                                const OpContext &ctx,
+                                MKLDNNFullyConnectedForward *fwd,
+                                const std::vector<NDArray> &in_data,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &out_data);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
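The min_calib_range/max_calib_range fields let the operator build its requantization scale ahead of time instead of tracking output min/max at runtime. Roughly, and as an illustrative sketch only (the actual computation lives in the C++ operator; the uint8/int8 range constants here are assumptions):

    def fc_requantize_scale(min_data, max_data, min_weight, max_weight,
                            min_calib_range, max_calib_range):
        # Assumed quantization scheme: uint8 activations, int8 weights.
        data_scale = 255.0 / (max_data - min_data)
        weight_scale = 127.0 / max(abs(min_weight), abs(max_weight))
        # The int32 accumulator carries data_scale * weight_scale; requantize
        # it into the calibrated int8 output range.
        out_scale = 127.0 / max(abs(min_calib_range), abs(max_calib_range))
        return out_scale / (data_scale * weight_scale)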
