Skip to content

Commit ff81567

Browse files
authored
Added creation of QDQ for TopK node (#25309)
Added creation of QDQ nodes for the TopK operator (#25309). Changes: added TopK to registry.py so that QDQ nodes are created for the op; ensured that the input and output quantization params are equal; added a unit test verifying the creation of QDQ nodes for TopK. Description: Added support for creating QDQ nodes for TopK when a model is quantized with the ORT static quantization tool. Motivation and Context: a node unit is already formed for the TopK operator when QDQ nodes are present and the input and output quantization params are equal, but the ORT static quantization tool previously had no support for creating QDQ nodes for TopK.
1 parent d293285 commit ff81567

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

onnxruntime/python/tools/quantization/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
"InstanceNormalization": QDQNormalization,
8787
"LayerNormalization": QDQNormalization,
8888
"BatchNormalization": QDQNormalization,
89+
"TopK": QDQDirect8BitOp,
8990
}
9091

9192

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
import unittest
8+
9+
import numpy as np
10+
from onnx import TensorProto, helper, save
11+
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
12+
13+
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
14+
15+
16+
class TestTopKModel(unittest.TestCase):
    """Verifies that the ORT static quantization tool creates QDQ nodes for TopK."""

    @staticmethod
    def construct_model(model_path, input_shape, axis_attr, k):
        """Build and save a float32 ONNX model containing a single TopK node.

        Args:
            model_path: destination path for the saved ONNX model.
            input_shape: shape (list of ints) of the float "input" tensor.
            axis_attr: axis along which TopK selects elements.
            k: number of top elements to select (stored as an int64 initializer).
        """
        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape)
        k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k])
        # TopK's output shape equals the input shape with the selected axis reduced to k.
        output_shape = input_shape[:]
        output_shape[axis_attr] = k
        # Fix: use the computed output_shape (previously computed but ignored in
        # favor of a hard-coded [1, k]), so the model is valid for any rank/axis.
        output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, output_shape)
        output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, output_shape)

        node = helper.make_node(
            "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr
        )

        graph = helper.make_graph(
            [node],
            "quant_topk_op_test",
            [input_tensor],
            [output_values, output_indices],
            initializer=[k_tensor],
        )

        model = helper.make_model(
            graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)]
        )
        save(model, model_path)

    def quantize_topk_test(self, activation_type, weight_type, extra_options=None):
        """Statically quantize the TopK model in QDQ mode and check the result.

        Checks the expected Q/DQ node counts, the quantized tensor types, and
        numerical correctness of the quantized model against the float model.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static; the
                "ForceQuantizeNoInputCheck" entry selects the expected outcome.
        """
        # Avoid the mutable-default-argument pitfall (was `extra_options={}  # noqa: B006`).
        extra_options = extra_options or {}
        # `.get` tolerates a missing key; the previous direct index raised KeyError.
        force_quantize = bool(extra_options.get("ForceQuantizeNoInputCheck", False))

        model_fp32_path = "topk_fp32.onnx"
        input_shape = [1, 10]
        axis = 1
        k = 3
        self.construct_model(model_fp32_path, input_shape, axis, k)

        input_data_list = [
            {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)}
        ]
        data_reader = TestDataFeeds(input_data_list)

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
        weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
        model_qdq_path = (
            f"topk_{activation_type_str}{weight_type_str}_"
            f"{'QNoInCk' if force_quantize else 'NoQNoInCk'}_qdq.onnx"
        )

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(
            model_fp32_path,
            model_qdq_path,
            data_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        # Without ForceQuantizeNoInputCheck the quantizer is expected to leave
        # TopK unquantized, so no Q/DQ nodes appear in that configuration.
        qdqnode_counts = (
            {
                "TopK": 1,
                "QuantizeLinear": 2,
                "DequantizeLinear": 2,
            }
            if force_quantize
            else {
                "TopK": 1,
                "QuantizeLinear": 0,
                "DequantizeLinear": 0,
            }
        )
        check_op_type_count(self, model_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next())

    def test_quantize_topk_u8u8(self):
        self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True})

    def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self):
        self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False})
101+
102+
# Allow running this test file directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)