Skip to content

Commit d293285

Browse files
authored
[QNN EP] Fix pool with reshape name conflicts (#25332)
Naming conflicts when expand-pool2d-squeeze (implemented as reshape) logic is invoked during ONNX -> QNN op lowering. Model with multiple pool 1D ops would hit this issue.
1 parent 14e0ad7 commit d293285

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
235235
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info));
236236

237237
bool needs_reshape = false;
238-
const std::string reshape4d = input_names[0] + "_pre_reshape";
238+
const std::string reshape_prior_out = input_names[0] + "_prior_reshape";
239239
if (input_shape.size() == 3) {
240240
needs_reshape = true;
241241
// build new_shape = {N, 1, C, L}
@@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
245245
input_shape[1],
246246
input_shape[2]};
247247

248-
const std::string reshape_node_name = "pre_reshape";
249-
QnnTensorWrapper rw(
250-
reshape4d,
248+
QnnTensorWrapper reshape_prior_tensor(
249+
reshape_prior_out,
251250
QNN_TENSOR_TYPE_NATIVE,
252251
reshape_input_info.qnn_data_type,
253252
reshape_input_info.quant_param.Copy(),
254253
std::move(new_shape));
255-
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)),
256-
"Failed to add reshape-4d tensor.");
254+
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)),
255+
"Failed to add reshape prior tensor.");
257256
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
258-
reshape_node_name,
257+
utils::GetNodeName(node_unit) + "_reshape_prior",
259258
QNN_OP_PACKAGE_NAME_QTI_AISW,
260-
"Reshape",
259+
QNN_OP_RESHAPE,
261260
{input_names[0]},
262-
{reshape4d},
261+
{reshape_prior_out},
263262
{},
264263
do_op_validation),
265-
"Failed to create reshape-4d node.");
266-
input_names[0] = reshape4d;
264+
"Failed to create reshape prior node for pool op.");
265+
input_names[0] = reshape_prior_out;
267266
input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]};
268267
}
269268

@@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
446445
}
447446
const auto& outputs = node_unit.Outputs();
448447
const std::string real_out = outputs[0].node_arg.Name();
449-
const std::string pool_name = "poolmax2d";
450-
const std::string pool_out = real_out + "_post_reshape";
451-
const std::string post_reshape_node_name = "post_reshape";
448+
const std::string pool_out = real_out + "_reshape_after";
452449
const std::string qnn_op = GetQnnOpType(op_type);
453450
TensorInfo output_info{};
454451
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
@@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
466463
"Failed to add tensor for pool_out");
467464

468465
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
469-
pool_name,
466+
utils::GetNodeName(node_unit) + "_pool2d",
470467
QNN_OP_PACKAGE_NAME_QTI_AISW,
471468
qnn_op,
472-
{reshape4d},
469+
{reshape_prior_out},
473470
{pool_out},
474471
std::move(param_tensor_names),
475472
do_op_validation),
476-
"Failed to create QNN Pool node for rank-3 input.");
473+
"Failed to create pool node for rank-3 input.");
477474

478475
std::vector<uint32_t> final_shape3d = output_info.shape;
479-
QnnTensorWrapper reshape_back_tensor(
476+
QnnTensorWrapper reshape_after_tensor(
480477
real_out,
481478
tensor_type,
482479
output_info.qnn_data_type,
483480
output_info.quant_param.Copy(),
484481
std::move(final_shape3d));
485482

486-
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor.");
483+
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)),
484+
"Failed to add reshape after tensor.");
487485
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
488-
post_reshape_node_name,
486+
utils::GetNodeName(node_unit) + "_reshape_after",
489487
QNN_OP_PACKAGE_NAME_QTI_AISW,
490-
"Reshape",
488+
QNN_OP_RESHAPE,
491489
{pool_out},
492490
{real_out},
493491
{},
494492
do_op_validation),
495-
"Failed to create reshape-back node.");
493+
"Failed to create reshape after node for pool op.");
496494

497495
return Status::OK();
498496
}

0 commit comments

Comments
 (0)