Skip to content

Commit 566e5b3

Browse files
titu1994 and pzelasko
authored and committed
Add support to change Multi task model prompt (#9542)
* Add support to change Multi task model prompt Signed-off-by: smajumdar <titu1994@gmail.com> * Add support to change Multi task model prompt Signed-off-by: smajumdar <titu1994@gmail.com> * Apply isort and black reformatting Signed-off-by: titu1994 <titu1994@users.noreply.github.com> * Update nemo/collections/common/prompts/formatter.py Co-authored-by: Piotr Żelasko <petezor@gmail.com> Signed-off-by: Somshubra Majumdar <titu1994@gmail.com> * Address comments Signed-off-by: smajumdar <titu1994@gmail.com> * Apply isort and black reformatting Signed-off-by: titu1994 <titu1994@users.noreply.github.com> * Address comments Signed-off-by: smajumdar <titu1994@gmail.com> --------- Signed-off-by: smajumdar <titu1994@gmail.com> Signed-off-by: titu1994 <titu1994@users.noreply.github.com> Signed-off-by: Somshubra Majumdar <titu1994@gmail.com> Co-authored-by: Piotr Żelasko <petezor@gmail.com> Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent b676f68 commit 566e5b3

File tree

4 files changed

+131
-15
lines changed

4 files changed

+131
-15
lines changed

nemo/collections/asr/models/aed_multitask_models.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
import os
1616
import warnings
17+
from collections.abc import Mapping, Sequence
1718
from dataclasses import dataclass, field
1819
from math import ceil
1920
from typing import Any, Dict, List, Optional, Union
2021

2122
import numpy as np
2223
import torch
23-
from omegaconf import DictConfig, OmegaConf, open_dict
24+
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
2425
from pytorch_lightning import Trainer
2526
from torch.utils.data import DataLoader
2627

@@ -387,6 +388,59 @@ def change_vocabulary(
387388

388389
logging.info(f"Changed decoder to output to {vocabulary} vocabulary.")
389390

391+
def change_prompt(
    self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None
):
    """
    Changes the prompt format used during Multi Task decoding process.

    Args:
        prompt_format: A string alias of the object that represents the prompt structure.
            If not None, it will be used to update the prompt format.
        prompt_defaults: A list of dictionaries (one per dialog role) providing default
            slot values for the prompt format. Each item must contain a `role` key and
            a `slots` key. If None, the stored defaults are cleared.

    Raises:
        ValueError: If `prompt_defaults` is not a list of dictionaries, or if any item
            is missing the `role` or `slots` key.
    """
    if prompt_format is not None:
        self.prompt_format = prompt_format

    if prompt_defaults is not None:
        # Perform some assertions on the prompt defaults contents
        # Must be a list-like object
        if not isinstance(prompt_defaults, Sequence):
            raise ValueError("`prompt_defaults` must be a list of dictionaries")

        # Must contain dict-like objects
        for item in prompt_defaults:
            if not isinstance(item, Mapping):
                raise ValueError("`prompt_defaults` must be a list of dictionaries")

            # Each dict item must have a `role` key
            if 'role' not in item:
                raise ValueError(
                    "`prompt_defaults` must have a `role` key for each item in the list of dictionaries"
                )

            if 'slots' not in item:
                raise ValueError(
                    "`prompt_defaults` must have a `slots` key for each item in the list of dictionaries"
                )

        # Cast to OmegaConf if not already, so it can be assigned into the model config.
        if not isinstance(prompt_defaults, ListConfig):
            prompt_defaults = OmegaConf.create(prompt_defaults)

    # Update config BEFORE rebuilding the prompt formatter: the formatter below reads
    # `self.cfg.prompt_defaults`, so updating afterwards would silently build the new
    # prompt from the stale defaults and drop the caller-supplied values.
    with open_dict(self.cfg):
        self.cfg.prompt_format = self.prompt_format
        self.cfg.prompt_defaults = prompt_defaults

    # Resolve the (possibly new) formatter class by its registered name and rebuild
    # the prompt object from the freshly-updated config.
    prompt_cls = PromptFormatter.resolve(self.prompt_format)
    self.prompt = prompt_cls(
        tokenizer=self.tokenizer,
        defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None,
    )

    logging.info(f"Changed prompt format to `{self.prompt_format}`")
390444
@torch.no_grad()
391445
def transcribe(
392446
self,

nemo/collections/common/prompts/canary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class CanaryPromptFormatter(PromptFormatter):
1616
"template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|",
1717
"slots": {
1818
"source_lang": Modality.Text,
19-
"task": Modality.Text,
19+
"task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"),
2020
"target_lang": Modality.Text,
21-
"pnc": Modality.Text,
21+
"pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"),
2222
},
2323
},
2424
OUTPUT_ROLE: {

nemo/collections/common/prompts/formatter.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,38 @@
2020
EOS_SLOT = "|eos|"
2121

2222

23-
class Modality(Enum):
23+
class BaseModalityType:
    """Base type for all modalities accepted as PromptFormatter slot values."""

    @staticmethod
    def matches(value: Any) -> bool:
        """Return True if ``value`` is acceptable for this modality. Subclasses must override."""
        raise NotImplementedError


class Text(BaseModalityType):
    """Modality for text values."""

    @staticmethod
    def matches(value: Any) -> bool:
        # Accepts any string; non-str values (annotation widened to Any — the check
        # is intentionally a filter, not a precondition) return False.
        return isinstance(value, str)


class TextLiteral(BaseModalityType):
    """Modality for text restricted to a closed set of allowed string values."""

    def __init__(self, *items):
        # The permitted literal values, kept in declaration order.
        self.allowed_values = items

    def matches(self, value: Any) -> bool:
        # Must be a string AND one of the allowed literals.
        return isinstance(value, str) and value in self.allowed_values

    def __repr__(self):
        return f"{self.__class__.__name__}({self.allowed_values})"


class Modality:
    """
    Modalities supported as PromptFormatter slot values.
    """

    Text = Text
    TextLiteral = TextLiteral
3955

4056

4157
class PromptFormatter(ABC):

tests/collections/asr/test_asr_multitask_model_bpe.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
2323
from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode
2424
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
25+
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
2526
from nemo.collections.common.tokenizers import CanaryTokenizer
2627

2728

@@ -275,6 +276,51 @@ def test_decoding_change(self, asr_model):
275276
assert isinstance(asr_model.decoding.decoding, beam_decode.TransformerAEDBeamInfer)
276277
assert asr_model.decoding.decoding.search_type == "default"
277278

279+
@pytest.mark.unit
def test_prompt_change(self, asr_model):
    """Changing prompt defaults must be reflected in the model config."""
    # Starting state: the canary prompt formatter is active.
    assert asr_model.prompt_format == 'canary'
    assert isinstance(asr_model.prompt, CanaryPromptFormatter)

    # Calling with no arguments clears the stored defaults.
    asr_model.change_prompt()
    assert asr_model.cfg.prompt_defaults is None

    # Override a single slot default and verify it lands in the config.
    new_defaults = asr_model.prompt.get_default_dialog_slots()
    new_defaults[0]['slots']['pnc'] = 'no'
    asr_model.change_prompt(prompt_defaults=new_defaults)
    assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'
293+
294+
@pytest.mark.unit
def test_prompt_change_subclass(self, asr_model):
    """Switching to a registered PromptFormatter subclass must update config and encoding."""
    # Starting state: the canary prompt formatter is active.
    assert asr_model.prompt_format == 'canary'
    assert isinstance(asr_model.prompt, CanaryPromptFormatter)

    # Defining the subclass registers it under its NAME alias.
    class CanaryPromptFormatterSubclass(CanaryPromptFormatter):
        NAME = "canary2"

    # Calling with no arguments clears the stored defaults.
    asr_model.change_prompt()
    assert asr_model.cfg.prompt_defaults is None

    # Switch to the subclass formatter by its alias, with one slot overridden.
    new_defaults = asr_model.prompt.get_default_dialog_slots()
    new_defaults[0]['slots']['pnc'] = 'no'
    asr_model.change_prompt(prompt_format='canary2', prompt_defaults=new_defaults)

    assert asr_model.cfg.prompt_format == 'canary2'
    assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'
    assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass)

    # Fill every user slot explicitly and check the rendered prompt round-trips
    # through the tokenizer to the expected canary control-token sequence.
    user_turn = asr_model.prompt.get_default_dialog_slots()[0]
    turn_slots = user_turn['slots']
    turn_slots['source_lang'] = 'en'
    turn_slots['target_lang'] = 'en'
    turn_slots['task'] = 'asr'
    turn_slots['pnc'] = 'no'
    encoded = asr_model.prompt.encode_dialog([user_turn])
    decoded = asr_model.tokenizer.ids_to_text(encoded["input_ids"])
    assert decoded == "<|startoftranscript|><|en|><|transcribe|><|en|><|nopnc|>"
323+
278324
@pytest.mark.unit
279325
def test_transcribe_single_file(self, asr_model, test_data_dir):
280326
audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")

0 commit comments

Comments
 (0)