@@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const {
146146 const auto * updates = context.Input <Tensor>(2 );
147147 const auto & input_shape = input->Shape ();
148148 const auto & indices_shape = indices->Shape ();
149- auto indices_rank = indices_shape.NumDimensions ();
150- auto last_index_dimension = static_cast <uint32_t >(indices_shape[indices_rank - 1 ]);
151- auto num_updates_elements = static_cast <uint32_t >(input_shape.SizeFromDimension (last_index_dimension));
152- // TODO: support bool with components 4.
153- const size_t components = 1 ;
154- auto output_size = static_cast <uint32_t >((indices_shape.SizeToDimension (indices_rank - 1 ) + components - 1 ) / components);
155149 auto * output = context.Output (0 , input_shape);
156- if (output_size == 0 ) {
157- // If the output tensor is empty, we can return early.
158- return Status::OK ();
159- }
160- MLDataType data_type = input->DataType ();
161150 const void * source = input->DataRaw ();
162151 void * target = output->MutableDataRaw ();
163152 // If source and target pointers are not equal (non-inplace operation), we need to copy the data.
164153 if (target != source) {
165154 ORT_RETURN_IF_ERROR (Info ().GetDataTransferManager ().CopyTensor (*input, *output));
166155 }
156+ if (indices_shape.Size () == 0 ) {
157+ // If the indices are empty, we can return early.
158+ return Status::OK ();
159+ }
160+ auto indices_rank = indices_shape.NumDimensions ();
161+ auto last_index_dimension = static_cast <uint32_t >(indices_shape[indices_rank - 1 ]);
162+ auto num_updates_elements = static_cast <uint32_t >(input_shape.SizeFromDimension (last_index_dimension));
163+ // TODO: support bool with components 4.
164+ const size_t components = 1 ;
165+ auto output_size = static_cast <uint32_t >((indices_shape.SizeToDimension (indices_rank - 1 ) + components - 1 ) / components);
166+ MLDataType data_type = input->DataType ();
167167 ScatterNDProgram program (reduction_, data_type);
168168 program
169169 .CacheHint (static_cast <uint32_t >(reduction_))
0 commit comments