Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,14 @@ def transcribe(
# Check if chunking will be enabled
trcfg.enable_chunking = (is_one_audio or trcfg.batch_size == 1) and self.timestamps_asr_model is not None

if not trcfg.enable_chunking:
if trcfg.enable_chunking:
if self.decoding.cfg.get('return_xattn_scores', False):
logging.warning(
"When chunking is enabled, cross-attention scores will not be returned even though "
"`return_xattn_scores` is set to True. If you want to return the cross-attention scores "
"set `enable_chunking` to False in the MultiTaskTranscriptionConfig in override_config."
)
else:
logging.warning("Chunking is disabled. Please pass a single audio file or set batch_size to 1")

results = super().transcribe(audio=audio, override_config=trcfg)
Expand Down
117 changes: 99 additions & 18 deletions nemo/collections/asr/modules/transformer/transformer_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@

preserve_step_confidence: Bool flag which preserves the history of per-step confidence scores generated
during greedy decoding. When set to True, the results will contain an additional list of float tensors.
return_xattn_scores: Bool flag which indicates whether to keep and return the cross-attention scores
during greedy/beam search decoding. When set to True, the results will contain an additional list of tensors.
confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-step
confidence scores.
name: The method name (str).
Expand Down Expand Up @@ -102,6 +104,7 @@
n_samples=1,
temperature=None,
preserve_step_confidence=False,
return_xattn_scores=False,
confidence_method_cfg: Optional[DictConfig] = None,
):
super().__init__()
Expand All @@ -115,6 +118,7 @@
self.n_samples = n_samples
self.temperature = temperature
self.preserve_step_confidence = preserve_step_confidence
self.return_xattn_scores = return_xattn_scores

# set confidence calculation method
self.num_tokens = getattr(self.classifier.mlp, f'layer{self.classifier.mlp.layers - 1}').out_features
Expand Down Expand Up @@ -226,6 +230,7 @@
step_confidence = None

decoder_mems_list = None
xatt_scores_list = None
for i in range(max_generation_length):

if i == 0:
Expand All @@ -234,14 +239,22 @@
i += tgt_len - 1
input_ids = tgt[:, -1:]

logits, decoder_mems_list, _ = self._one_step_forward(
logits, decoder_mems_list, new_xatt_scores_list = self._one_step_forward(
input_ids,
encoder_hidden_states,
encoder_input_mask,
decoder_mems_list,
i,
return_scores=return_beam_scores,
)
if self.return_xattn_scores:
if xatt_scores_list is not None:
for layer in range(len(xatt_scores_list)):
xatt_scores_list[layer] = torch.cat(
(xatt_scores_list[layer], new_xatt_scores_list[layer]), dim=2
)
else:
xatt_scores_list = new_xatt_scores_list

if self.temperature is None: # Greedy decoding
next_tokens = torch.argmax(logits[:, -1], dim=-1)
Expand Down Expand Up @@ -272,7 +285,7 @@
samples = list(tgt.view(orig_batch_size, self.n_samples, -1))
tgt = tgt[:: self.n_samples]

return tgt, samples, step_confidence_tensor
return tgt, samples, step_confidence_tensor, xatt_scores_list

def __call__(
self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
Expand All @@ -284,12 +297,12 @@
if not return_beam_scores:
return results
else:
prefixes, scores, tgt = results
prefixes, scores, tgt, xatt_scores_list = results
prefixes = prefixes.view(-1, self.beam_size, tgt.size(1)).split(1, dim=0)
scores = scores.view(-1, self.beam_size).split(1, dim=0)
prefixes = [x.squeeze(0) for x in prefixes] # each item is [beam, seq_len]
scores = [x.squeeze(0) for x in scores] # each item is [beam,]
return prefixes, scores, tgt
return prefixes, scores, tgt, xatt_scores_list

def freeze(self) -> None:
"""Freeze weights of embedding, decoder, and classification layers to prevent memory leak."""
Expand Down Expand Up @@ -413,9 +426,11 @@
tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)

# generate initial buffer of beam_size prefixes-hypotheses
log_probs, decoder_mems_list, _ = self._one_step_forward(
log_probs, decoder_mems_list, xatt_scores_list = self._one_step_forward(
tgt, encoder_hidden_states, encoder_input_mask, None, 0
)
if not self.return_xattn_scores:
xatt_scores_list = None
scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)

Expand All @@ -434,6 +449,10 @@
else:
hidden_size = decoder_mems_list[0].size(2)

# repeat xattn scores
if xatt_scores_list is not None:
xatt_scores_list = [xatt_layer.repeat(self.beam_size, 1, 1, 1) for xatt_layer in xatt_scores_list]

# pad_profile tracks finished hypotheses to generate only <pad> tokens
# if <eos> or <pad> has been generated
pad_profile = torch.zeros_like(scores).long()
Expand All @@ -449,7 +468,7 @@
pad_mask = pad_profile.repeat(1, self.beam_size)

# generate and score candidates for prefixes continuation
log_probs, decoder_mems_list, _ = self._one_step_forward(
log_probs, decoder_mems_list, next_xatt_scores_list = self._one_step_forward(
prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
)
scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)
Expand Down Expand Up @@ -478,6 +497,21 @@
prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

# select xatt scores corresponding to chosen hypotheses
if self.return_xattn_scores and next_xatt_scores_list is not None:
num_heads = xatt_scores_list[0].shape[1]
xatt_indices_i = (
indices_i.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, num_heads, p_len - 1, src_length)
// self.beam_size
)
for layer in range(len(next_xatt_scores_list)):
xatt_layer_score_i = torch.cat((xatt_scores_list[layer], next_xatt_scores_list[layer]), dim=2)
xatt_scores_list[layer] = (
xatt_layer_score_i.view(-1, self.beam_size, num_heads, p_len - 1, src_length)
.gather(1, xatt_indices_i)
.view(-1, num_heads, p_len - 1, src_length)
)

# reshuffle cached decoder memory states to restore the order
# of hypotheses broken after top-k selection
mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
Expand All @@ -501,13 +535,26 @@
# select best performing hypotheses in each element of the batch
len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
scores = scores / len_penalties
best_guesses = (
torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
)
tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)
best_guesses = torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True)
tgt_best_guesses = best_guesses.repeat(1, prefixes.size(1)).unsqueeze(1)
tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, tgt_best_guesses).squeeze(1)

# select xatt scores for best hypotheses
if xatt_scores_list is not None:
_, num_heads, tgt_len, src_len = xatt_scores_list[0].shape
xatt_best_guesses = (
best_guesses.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, num_heads, tgt_len, src_len)
)
for layer in range(len(xatt_scores_list)):
xatt_scores_list[layer] = (
xatt_scores_list[layer]
.view(-1, self.beam_size, num_heads, tgt_len, src_len)
.gather(1, xatt_best_guesses)
.squeeze(1)
)

if return_beam_scores:
return prefixes, scores * len_penalties, tgt
return prefixes, scores * len_penalties, tgt, xatt_scores_list
else:
return tgt

Expand Down Expand Up @@ -549,9 +596,11 @@
batch_fusion_states_candidates_list = []

# generate initial buffer of beam_size prefixes-hypotheses
log_probs, decoder_mems_list, _ = self._one_step_forward(
log_probs, decoder_mems_list, xatt_scores_list = self._one_step_forward(
tgt, encoder_hidden_states, encoder_input_mask, None, 0
)
if not self.return_xattn_scores:
xatt_scores_list = None
# get fusion models scores
for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
Expand Down Expand Up @@ -585,6 +634,10 @@
else:
hidden_size = decoder_mems_list[0].size(2)

# repeat xattn scores
if xatt_scores_list is not None:
xatt_scores_list = [xatt_layer.repeat(self.beam_size, 1, 1, 1) for xatt_layer in xatt_scores_list]

# pad_profile tracks finished hypotheses to generate only <pad> tokens
# if <eos> or <pad> has been generated
pad_profile = torch.zeros_like(scores).long()
Expand All @@ -600,7 +653,7 @@
pad_mask = pad_profile.repeat(1, self.beam_size)

# generate and score candidates for prefixes continuation
log_probs, decoder_mems_list, _ = self._one_step_forward(
log_probs, decoder_mems_list, next_xatt_scores_list = self._one_step_forward(
prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
)
for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
Expand Down Expand Up @@ -647,6 +700,21 @@
prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

# select xatt scores corresponding to chosen hypotheses
if self.return_xattn_scores and next_xatt_scores_list is not None:
num_heads = xatt_scores_list[0].shape[1]
xatt_indices_i = (
indices_i.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, num_heads, p_len - 1, src_length)
// self.beam_size
)
for layer in range(len(next_xatt_scores_list)):
xatt_layer_score_i = torch.cat((xatt_scores_list[layer], next_xatt_scores_list[layer]), dim=2)
xatt_scores_list[layer] = (
xatt_layer_score_i.view(-1, self.beam_size, num_heads, p_len - 1, src_length)
.gather(1, xatt_indices_i)
.view(-1, num_heads, p_len - 1, src_length)
)

# reshuffle cached decoder memory states to restore the order
# of hypotheses broken after top-k selection
mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
Expand All @@ -670,13 +738,26 @@
# select best performing hypotheses in each element of the batch
len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
scores = scores / len_penalties
best_guesses = (
torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
)
tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)
best_guesses = torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True)
tgt_best_guesses = best_guesses.repeat(1, prefixes.size(1)).unsqueeze(1)
tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, tgt_best_guesses).squeeze(1)

# select xatt scores for best hypotheses
if xatt_scores_list is not None:
_, num_heads, tgt_len, src_len = xatt_scores_list[0].shape
xatt_best_guesses = (
best_guesses.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, num_heads, tgt_len, src_len)
)
for layer in range(len(xatt_scores_list)):
xatt_scores_list[layer] = (
xatt_scores_list[layer]
.view(-1, self.beam_size, num_heads, tgt_len, src_len)
.gather(1, xatt_best_guesses)
.squeeze(1)
)

if return_beam_scores:
return prefixes, scores * len_penalties, tgt
return prefixes, scores * len_penalties, tgt, xatt_scores_list
else:
return tgt

Expand Down Expand Up @@ -969,15 +1050,15 @@
len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
scores = scores / len_penalties
best_guesses = (
torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
)
tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)

if return_beam_scores:
return prefixes, scores * len_penalties, tgt
else:
return tgt

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method — Warning

This method does not accept the keyword argument `return_scores`, which the overridden
method `GreedySequenceGenerator._one_step_forward` does accept.
This call correctly invokes the base method, but its signature does not match that of the
overriding method.
def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_beam_scores=False):
    """Entry point for decoding: delegate to ``self._forward``.

    Runs the whole generation pass under ``torch.inference_mode()`` so that no
    autograd state (gradients / version counters) is recorded during decoding.

    Args:
        src_ids: source token ids for the encoder input.
        encoder_input_mask: padding mask for the encoder input.
        decoder_input_ids: optional decoder prompt/prefix ids (default None).
        return_beam_scores: if True, also return per-hypothesis scores
            (exact return shape is determined by ``_forward``).

    Returns:
        Whatever ``self._forward`` returns for these arguments.
    """
    with torch.inference_mode():
        return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores)
Expand Down
17 changes: 14 additions & 3 deletions nemo/collections/asr/parts/submodules/multitask_beam_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@


def pack_hypotheses(
hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]]
hypotheses: List[Hypothesis],
beam_hypotheses: torch.Tensor,
scores: List[Optional[float]],
xatt_scores_list: List[torch.Tensor] = None,
) -> List[Hypothesis]:

for idx, hyp in enumerate(hypotheses): # type: Hypothesis
Expand All @@ -49,6 +52,9 @@ def pack_hypotheses(
if hyp.dec_state is not None:
hyp.dec_state = _states_to_device(hyp.dec_state)

if xatt_scores_list is not None:
hyp.xatt_scores = [xatt_layer[idx] for xatt_layer in xatt_scores_list]

return hypotheses


Expand Down Expand Up @@ -139,6 +145,7 @@ def __init__(
ngram_lm_alpha: float = 0.0,
boosting_tree: BoostingTreeModelConfig | None = None,
boosting_tree_alpha: float = 0.0,
return_xattn_scores: bool = False,
):
super().__init__(
transformer_decoder=transformer_decoder,
Expand Down Expand Up @@ -181,6 +188,7 @@ def __init__(
eos=self.eos,
len_pen=length_penalty,
max_delta_length=max_generation_delta,
return_xattn_scores=return_xattn_scores,
)
else:
self.beam_search = BeamSearchSequenceGeneratorWithFusionModels(
Expand All @@ -196,6 +204,7 @@ def __init__(
max_delta_length=max_generation_delta,
fusion_models=fusion_models,
fusion_models_alpha=fusion_models_alpha,
return_xattn_scores=return_xattn_scores,
)

self.preserve_alignments = preserve_alignments
Expand Down Expand Up @@ -229,7 +238,7 @@ def forward(
self.transformer_decoder.eval()
self.log_softmax_module.eval()

topk_hypotheses, beam_scores, best_hypo = self.beam_search(
topk_hypotheses, beam_scores, best_hypo, xatt_scores_list = self.beam_search(
encoder_hidden_states=encoder_hidden_states,
encoder_input_mask=encoder_input_mask,
decoder_input_ids=decoder_input_ids,
Expand All @@ -249,11 +258,13 @@ def forward(
else:
beam_scores = [None for _ in range(len(best_hypo))]
best_hypo = best_hypo.detach().cpu()
if xatt_scores_list is not None:
xatt_scores_list = [xatt_layer.detach().cpu() for xatt_layer in xatt_scores_list]
hypotheses = [
Hypothesis(score=0.0, y_sequence=[], timestamp=[]) for _ in range(encoder_hidden_states.shape[0])
]
# Pack results into Hypotheses
packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores)
packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores, xatt_scores_list)
self.format_hypotheses(packed_result, decoder_input_ids)

self.transformer_decoder.train()
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/parts/submodules/multitask_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding":
confidence_method_cfg=self.confidence_method_cfg,
temperature=self.cfg.greedy.temperature,
n_samples=self.cfg.greedy.n_samples,
return_xattn_scores=self.cfg.get('return_xattn_scores', False),
)

elif strategy == 'beam':
Expand All @@ -206,6 +207,7 @@ def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding":
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0),
boosting_tree=self.cfg.beam.get('boosting_tree', None),
boosting_tree_alpha=self.cfg.beam.get('boosting_tree_alpha', 0.0),
return_xattn_scores=self.cfg.get('return_xattn_scores', False),
)

else:
Expand Down Expand Up @@ -658,3 +660,6 @@ class MultiTaskDecodingConfig:

# can be used to change temperature for decoding
temperature: float = 1.0

# If set to True, return cross-attention scores; otherwise they are discarded to save memory.
return_xattn_scores: bool = False
Loading
Loading