Skip to content

Commit dcc7c4a

Browse files
fix batch audio return type
1 parent f2de9e3 commit dcc7c4a

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

faster_whisper/transcribe.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def transcribe(
298298
language_detection_segments: int = 1,
299299
) -> Union[
300300
Tuple[Iterable[Segment], TranscriptionInfo],
301-
List[Tuple[List[Segment], TranscriptionInfo]],
301+
Tuple[List[List[Segment]], TranscriptionInfo],
302302
]:
303303
"""Transcribe audio in chunks in batched fashion and return with language info.
304304
@@ -379,9 +379,9 @@ def transcribe(
379379
- a generator over transcribed segments
380380
- an instance of TranscriptionInfo
381381
382-
For multiple audios: A list of tuples, each containing:
383-
- a list of transcribed segments
384-
- an instance of TranscriptionInfo
382+
For multiple audios: A tuple with:
383+
- a list of segment lists (one per audio)
384+
- an instance of TranscriptionInfo (using first audio's duration)
385385
"""
386386

387387
is_batch = isinstance(audio, list)
@@ -595,7 +595,7 @@ def transcribe(
595595
clip_timestamps_provided = clip_timestamps is not None
596596

597597
if is_batch:
598-
grouped_segments = self._batched_segments_generator_grouped(
598+
segments = self._batched_segments_generator_grouped(
599599
all_features,
600600
tokenizer,
601601
all_chunks_metadata,
@@ -605,20 +605,17 @@ def transcribe(
605605
log_progress,
606606
)
607607

608-
results = []
609-
for i, audio_segments in enumerate(grouped_segments):
610-
info = TranscriptionInfo(
611-
language=language,
612-
language_probability=language_probability,
613-
duration=audio_infos[i]["duration"],
614-
duration_after_vad=audio_infos[i]["duration_after_vad"],
615-
transcription_options=options,
616-
vad_options=_vad_parameters,
617-
all_language_probs=all_language_probs,
618-
)
619-
results.append((audio_segments, info))
608+
info = TranscriptionInfo(
609+
language=language,
610+
language_probability=language_probability,
611+
duration=audio_infos[0]["duration"],
612+
duration_after_vad=audio_infos[0]["duration_after_vad"],
613+
transcription_options=options,
614+
vad_options=_vad_parameters,
615+
all_language_probs=all_language_probs,
616+
)
620617

621-
return results
618+
return segments, info
622619
else:
623620
info = TranscriptionInfo(
624621
language=language,

tests/test_transcribe.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -308,19 +308,18 @@ def test_transcribe_multiple_audios(jfk_path):
308308
model = WhisperModel("tiny")
309309
batched_model = BatchedInferencePipeline(model=model)
310310

311-
results = batched_model.transcribe(
311+
all_segments, info = batched_model.transcribe(
312312
[jfk_path, jfk_path, jfk_path],
313313
batch_size=8,
314314
)
315315

316-
assert isinstance(results, list)
317-
assert len(results) == 3
318-
319-
for segments, info in results:
320-
assert info.language == "en"
321-
assert info.language_probability > 0.7
322-
assert info.duration == 11
316+
assert isinstance(all_segments, list)
317+
assert len(all_segments) == 3
318+
assert info.language == "en"
319+
assert info.language_probability > 0.7
320+
assert info.duration == 11
323321

322+
for segments in all_segments:
324323
assert isinstance(segments, list)
325324
assert len(segments) >= 1
326325

@@ -339,16 +338,17 @@ def test_transcribe_multiple_audios_with_word_timestamps(jfk_path):
339338
model = WhisperModel("tiny")
340339
batched_model = BatchedInferencePipeline(model=model)
341340

342-
results = batched_model.transcribe(
341+
all_segments, info = batched_model.transcribe(
343342
[jfk_path, jfk_path],
344343
batch_size=8,
345344
word_timestamps=True,
346345
without_timestamps=False,
347346
)
348347

349-
assert len(results) == 2
348+
assert len(all_segments) == 2
349+
assert info.language == "en"
350350

351-
for segments, info in results:
351+
for segments in all_segments:
352352
assert isinstance(segments, list)
353353
for segment in segments:
354354
assert segment.words is not None

0 commit comments

Comments
 (0)