Skip to content

Commit 0fbb209

Browse files
petermcaughanPeter McAughan
andcommitted
Whisper Crash Fix (#19345)
### Description There is a current bug in the BeamSearch implementation of T5, GPT, and Whisper due to an interaction between two PRs merged in the past 7 months. First PR/code change is the addition of BeamSearchScorer GPU implementation. This PR accelerates some operations by executing them in the GPU and not the CPU. The approach for this code change didn't utilize a cudaStream when copying one particular variable from GPU to CPU (see nullptr value here: [[link](https://github.com/microsoft/onnxruntime/blob/b65d3d0a5374daa3bc9272c2c02763a8428660db/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h#L213)]). The second PR/code change was the alteration to utilize a cudaStream to initialize various memory buffers in BeamSearch (see `stream` included as the last argument in these allocations [[link](https://github.com/microsoft/onnxruntime/blob/d1431e1b78fb81bf90fdc58c9118cb011171f387/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h#L25)]). During the in-between period of these two PRs, I believe neither allocation utilized a stream and were thus synchronized. Once the latter PR was merged, the copy became desynchronized with the initialization due to different streams. The fix for this is to reintroduce the same stream into the copy operation added in the first PR. ### Motivation and Context This does not happen reliably on every hardware with every script due to the race condition nature, but the bug completely breaks ORT execution with a BeamSearch model. --------- Co-authored-by: Peter McAughan <petermca@microsoft.com>
1 parent 762703e commit 0fbb209

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
258258
cpu_state.sequences.InitDevice(beam_state.sequences_device);
259259
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
260260
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
261-
nullptr,
261+
this->ort_stream_,
262262
DeviceCopyDirection::hostToDevice));
263263
}
264264

onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
214214
cpu_state.sequences.InitDevice(beam_state.sequences_device);
215215
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
216216
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
217-
nullptr,
217+
this->ort_stream_,
218218
DeviceCopyDirection::hostToDevice));
219219
}
220220

onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
226226
cpu_state.sequences.InitDevice(beam_state.sequences_device);
227227
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
228228
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
229-
nullptr,
229+
this->ort_stream_,
230230
DeviceCopyDirection::hostToDevice));
231231
}
232232

0 commit comments

Comments
 (0)