Commit 648ff93

Adding prefix option for transcribe to allow multiple parallel runs
Signed-off-by: Nune <[email protected]>
Parent commit: 87dcb8d

2 files changed: 6 additions, 6 deletions

examples/asr/run_write_transcribed_files.py

Lines changed: 4 additions & 4 deletions
@@ -44,7 +44,7 @@ def create_transcribed_shard_manifests(prediction_filepaths: List[str], prefix)
     for prediction_filepath in prediction_filepaths:
         max_shard_id = 0
         shard_data = {}
-        full_path = os.path.join(prediction_filepath, "predictions_all.json")
+        full_path = os.path.join(prediction_filepath, f"{prefix}_predictions_all.json")
         with open(full_path, 'r') as f:
             for line in f.readlines():
                 data_entry = json.loads(line)
@@ -90,7 +90,7 @@ def create_transcribed_manifests(prediction_filepaths: List[str], prefix) -> Lis
     """
     all_manifest_filepaths = []
     for prediction_filepath in prediction_filepaths:
-        prediction_name = os.path.join(prediction_filepath, "predictions_all.json")
+        prediction_name = os.path.join(prediction_filepath, f"{prefix}_predictions_all.json")
         transcripted_name = os.path.join(prediction_filepath, f"{prefix}_transcribed_manifest.json")

         # Open and read the original predictions_all.json file
@@ -129,7 +129,7 @@ def write_sampled_shard_transcriptions(manifest_filepaths: List[str], prefix) ->
     for prediction_filepath in manifest_filepaths:
         predicted_shard_data = {}
         # Collect entries from prediction files based on shard id
-        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
+        prediction_path = os.path.join(prediction_filepath, f"{prefix}_predictions_all.json")
         with open(prediction_path, 'r') as f:
             for line in f:
                 data_entry = json.loads(line)
@@ -190,7 +190,7 @@ def write_sampled_transcriptions(manifest_filepaths: List[str], prefix) -> List[
     all_manifest_filepaths = []
     for prediction_filepath in manifest_filepaths:
         predicted_data = {}
-        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
+        prediction_path = os.path.join(prediction_filepath, f"{prefix}_predictions_all.json")
         with open(prediction_path, 'r') as f:
             for line in f:
                 data_entry = json.loads(line)
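
With these hunks, every aggregation helper in this file reads a run-specific predictions file rather than a shared one. A minimal sketch of that read path, assuming a file already written with the same prefix (the helper name read_prefixed_predictions is illustrative, not part of the commit):

import json
import os
from typing import Dict, List


def read_prefixed_predictions(prediction_dir: str, prefix: str) -> List[Dict]:
    # Mirrors the changed lines above: each run reads "<prefix>_predictions_all.json"
    # instead of a shared "predictions_all.json", so concurrent runs do not clash.
    full_path = os.path.join(prediction_dir, f"{prefix}_predictions_all.json")
    with open(full_path, 'r') as f:
        return [json.loads(line) for line in f]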

examples/asr/transcribe_speech_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -191,7 +191,7 @@ def main(cfg: ParallelTranscriptionConfig):
     os.makedirs(cfg.output_path, exist_ok=True)
     # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
     global_rank = trainer.node_rank * trainer.num_devices + int(os.environ.get("LOCAL_RANK", 0))
-    output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
+    output_file = os.path.join(cfg.output_path, f"{cfg.predict_ds.prefix}_predictions_{global_rank}.json")
     predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file)
     trainer.callbacks.extend([predictor_writer])

@@ -211,7 +211,7 @@ def main(cfg: ParallelTranscriptionConfig):
     pred_text_list = []
     text_list = []
     if is_global_rank_zero():
-        output_file = os.path.join(cfg.output_path, f"predictions_all.json")
+        output_file = os.path.join(cfg.output_path, f"{cfg.predict_ds.prefix}_predictions_all.json")
         logging.info(f"Prediction files are being aggregated in {output_file}.")
         with open(output_file, 'w') as outf:
             for rank in range(trainer.world_size):
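
A minimal sketch of the resulting file layout when two transcription runs share one output directory (the prediction_files helper, the directory, and the prefixes below are illustrative, not part of the script):

import os
from typing import List


def prediction_files(output_path: str, prefix: str, world_size: int) -> List[str]:
    # Per-rank prediction files plus the rank-zero aggregate, following the
    # "<prefix>_predictions_*.json" naming introduced by this commit.
    per_rank = [os.path.join(output_path, f"{prefix}_predictions_{rank}.json") for rank in range(world_size)]
    return per_rank + [os.path.join(output_path, f"{prefix}_predictions_all.json")]


# Two parallel runs no longer overwrite each other's outputs:
print(prediction_files("/results", "run_a", 2))
print(prediction_files("/results", "run_b", 2))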
