Skip to content

Commit bc0256f

Browse files
authored
[webgpu] Move the early return after copying for ScatterND (#25345)
### Description For ScatterND, if the indices are empty (nothing to update), it becomes a copy operation. So we should move the early return after copying.
1 parent b9b7530 commit bc0256f

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const {
146146
const auto* updates = context.Input<Tensor>(2);
147147
const auto& input_shape = input->Shape();
148148
const auto& indices_shape = indices->Shape();
149-
auto indices_rank = indices_shape.NumDimensions();
150-
auto last_index_dimension = static_cast<uint32_t>(indices_shape[indices_rank - 1]);
151-
auto num_updates_elements = static_cast<uint32_t>(input_shape.SizeFromDimension(last_index_dimension));
152-
// TODO: support bool with components 4.
153-
const size_t components = 1;
154-
auto output_size = static_cast<uint32_t>((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components);
155149
auto* output = context.Output(0, input_shape);
156-
if (output_size == 0) {
157-
// If the output tensor is empty, we can return early.
158-
return Status::OK();
159-
}
160-
MLDataType data_type = input->DataType();
161150
const void* source = input->DataRaw();
162151
void* target = output->MutableDataRaw();
163152
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
164153
if (target != source) {
165154
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output));
166155
}
156+
if (indices_shape.Size() == 0) {
157+
// If the indices are empty, we can return early.
158+
return Status::OK();
159+
}
160+
auto indices_rank = indices_shape.NumDimensions();
161+
auto last_index_dimension = static_cast<uint32_t>(indices_shape[indices_rank - 1]);
162+
auto num_updates_elements = static_cast<uint32_t>(input_shape.SizeFromDimension(last_index_dimension));
163+
// TODO: support bool with components 4.
164+
const size_t components = 1;
165+
auto output_size = static_cast<uint32_t>((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components);
166+
MLDataType data_type = input->DataType();
167167
ScatterNDProgram program(reduction_, data_type);
168168
program
169169
.CacheHint(static_cast<uint32_t>(reduction_))

onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,5 +235,16 @@ TEST(ScatterNDOpTest, ScatterND_18_max) {
235235
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
236236
}
237237

238+
// Test for ScatterND with empty indices - output should be same as input
239+
TEST(ScatterNDOpTest, ScatterND_empty_indices) {
240+
// Test with float data type and minimal empty case
241+
OpTester test1("ScatterND", 11);
242+
test1.AddInput<float>("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
243+
test1.AddInput<int64_t>("indices", {0, 1}, {}); // Empty indices tensor - no indices to process
244+
test1.AddInput<float>("updates", {0, 3}, {}); // Empty updates tensor
245+
test1.AddOutput<float>("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input
246+
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
247+
}
248+
238249
} // namespace test
239250
} // namespace onnxruntime

0 commit comments

Comments
 (0)