Add cross-attention to output hypotheses #15229
Conversation
Signed-off-by: Marco Gaido <mgaido@fbk.eu>
Force-pushed from 2de6160 to 21d5bb8
nithinraok
left a comment
Thanks Marco, great work. Added comments. Also, could you add an option, something like preserve_xattn_scores, so that when enabled through

```python
decoding_cfg = MultiTaskDecodingConfig(
    strategy="beam",  # or "greedy"
    preserve_xattn_scores=True,
)
```

it would only store and return xattn_scores (to save memory by default)?
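The suggested gating could be sketched roughly as follows. This is an illustrative sketch, not NeMo code: the class and method names are made up, and the tensors are assumed to be batched as [B, H, U, T] so that the decoder-step axis grows along dim=2, as in the PR's torch.cat call.

```python
import torch


class GreedyDecoderSketch:
    """Illustrative sketch (not the NeMo implementation) of gating
    cross-attention score storage behind a preserve_xattn_scores-style
    flag, so memory stays flat unless the user opts in."""

    def __init__(self, preserve_xattn_scores: bool = False):
        self.preserve_xattn_scores = preserve_xattn_scores

    def accumulate(self, xatt_scores_list, new_xatt_scores_list):
        if not self.preserve_xattn_scores:
            # Nothing is stored unless the user opted in.
            return None
        if xatt_scores_list is None:
            return list(new_xatt_scores_list)
        # Per layer, grow the decoder-step axis (dim=2 of [B, H, U, T]).
        return [
            torch.cat((old, new), dim=2)
            for old, new in zip(xatt_scores_list, new_xatt_scores_list)
        ]
```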
```
last_frame (Optional): Index of the last decoding step hypothesis was updated including blank token prediction.
xatt_scores (Optional): List of cross-attention scores for each decoder layer. Each element of the list
```
Shouldn't the shape be List[BxHxT1xT2]? Also, best to add that this is used with AED models.
This is for a single hypothesis, so there is no B; the shape would be List[HxT1xT2]. If you prefer, I can rename HxUxT to HxT1xT2. I will also add the note about AED models.
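The per-hypothesis shapes discussed above can be illustrated as follows (the layer/head/length values are made up for the example): for a single hypothesis the batch axis is gone, so the scores are a list over decoder layers of [H, T1, T2] tensors.

```python
import torch

# Illustrative dimensions: decoder layers, attention heads,
# decoder steps (T1), encoder frames (T2).
num_layers, H, T1, T2 = 4, 8, 10, 50

# One hypothesis: a list with one [H, T1, T2] tensor per decoder layer.
xatt_scores = [torch.rand(H, T1, T2) for _ in range(num_layers)]

assert len(xatt_scores) == num_layers
assert all(s.shape == (H, T1, T2) for s in xatt_scores)
```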
```
)
if xatt_scores_list is not None:
    for layer in range(len(xatt_scores_list)):
        xatt_scores_list[layer] = torch.cat((xatt_scores_list[layer], new_xatt_scores_list[layer]), dim=2)
```
What about the case when new_xattn_scores_list is None? cat would fail.
If new_xattn_scores_list is None, xatt_scores_list will stay None, so we never enter this if branch.
I meant to ask about each step, but it's probably fine.
I get it, but if xatt_scores_list is not None, then new_xattn_scores_list must not be None; otherwise something weird is happening (i.e. attention scores are returned for some tokens and are None for others within the same generation). I'd rather add an assert on new_xattn_scores_list, WDYT?
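The assert-guarded accumulation being discussed could look roughly like this. This is a hypothetical helper, not the actual NeMo code; it only illustrates the invariant that once scores start being returned, every subsequent step must return them too.

```python
import torch


def append_xatt_scores(xatt_scores_list, new_xatt_scores_list):
    """Illustrative helper: accumulate per-layer cross-attention scores,
    asserting that score availability is consistent across steps."""
    if xatt_scores_list is None:
        # First step that produced scores (or no scores at all yet).
        return new_xatt_scores_list
    assert new_xatt_scores_list is not None, (
        "cross-attention scores were returned for some decoding steps but not others"
    )
    # Concatenate per layer along the decoder-step axis
    # (dim=2 for tensors shaped [B, H, U, T]).
    return [
        torch.cat((old, new), dim=2)
        for old, new in zip(xatt_scores_list, new_xatt_scores_list)
    ]
```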
```
pos=0,
return_scores: bool = True,
):
log_probs, decoder_mems_list, _ = super()._one_step_forward(
```
Could you update here as well, and also include it in the returned tuple?
@nithinraok thanks for your review. I do have a question, though. You said to add a param in
Very good point. Well, first, the option needs to be passed to

```python
multitask_decoding = MultiTaskDecodingConfig()
multitask_decoding.strategy = "greedy"
multitask_decoding.return_xattn_scores = True
```

and call asr_model.change_decoding_strategy(multitask_decoding) before performing .transcribe(). Let's go with the latter option for now. I will keep thinking about this, as it needs to be changed across the codebase. Are there any other options like these that you are interested in changing through .transcribe()?
OK, I will work on this in the coming days. Maybe it would be worth adding some checks and logs to guide the user, though. I will try to come up with a proposal for that while working on this. Thanks.
Signed-off-by: Marco Gaido <mgaido@fbk.eu>
Signed-off-by: mgaido91 <mgaido91@users.noreply.github.com> Signed-off-by: Marco Gaido <mgaido@fbk.eu>
Force-pushed from 2a1005f to e06e0c4
Signed-off-by: Marco Gaido <mgaido@fbk.eu>
andrusenkoau
left a comment
Hi @mgaido91, thank you for the great work! I have almost no questions. Let's wait for your final changes with the decoding config.
```
 def test_temperature_sampling_decoding(inputs, nnet):
-    gen = GreedySequenceGenerator(*nnet, temperature=10.0, n_samples=2)
+    gen = GreedySequenceGenerator(*nnet, return_xattn_scores=True, temperature=10.0, n_samples=2)
```
Could you add the check for both return_xattn_scores options (as above) here?
Yes, sure. I did not do it to minimize the CI cost; I am updating it, thanks!
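Covering both options could be done with a parametrized test along these lines. This is a sketch only: DummyGenerator is a made-up stand-in for GreedySequenceGenerator (which needs the real NeMo fixtures), kept here just to show the parametrization pattern.

```python
import pytest


class DummyGenerator:
    """Made-up stand-in for GreedySequenceGenerator: it returns
    cross-attention scores only when return_xattn_scores is True."""

    def __init__(self, return_xattn_scores: bool = False):
        self.return_xattn_scores = return_xattn_scores

    def decode(self):
        scores = [[0.1, 0.9]] if self.return_xattn_scores else None
        return "hypothesis", scores


@pytest.mark.parametrize("return_xattn_scores", [False, True])
def test_xattn_scores_options(return_xattn_scores):
    _, scores = DummyGenerator(return_xattn_scores).decode()
    # Scores must be present exactly when the option is enabled.
    assert (scores is not None) == return_xattn_scores
```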
Thanks @andrusenkoau!
I already made them. You can find the
PS: I do not see why
Signed-off-by: Marco Gaido <mgaido@fbk.eu>
@nithinraok I think I addressed your comments; could you please take another look at this?
nithinraok
left a comment
Thanks Marco. Minor comment which you might have missed earlier. LGTM otherwise. Thanks for the PR, great work!
@nithinraok @andrusenkoau thank you for your guidance and reviews!
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do?
The PR adds the encoder-decoder cross-attention to the output hypotheses returned by ASR models.
Collection: ASR
Changelog
Usage
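A minimal usage sketch based on the config fields discussed in this PR (the checkpoint name and the MultiTaskDecodingConfig import path are assumptions and may differ across NeMo versions; this requires a NeMo install and an AED model):

```python
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecodingConfig

# Illustrative checkpoint; any AED (attention encoder-decoder) model applies.
asr_model = ASRModel.from_pretrained("nvidia/canary-1b")

decoding_cfg = MultiTaskDecodingConfig()
decoding_cfg.strategy = "greedy"
decoding_cfg.return_xattn_scores = True
asr_model.change_decoding_strategy(decoding_cfg)

hyps = asr_model.transcribe(["audio.wav"], return_hypotheses=True)
# Each hypothesis carries a list of per-layer [H x T1 x T2] score tensors.
print(hyps[0].xatt_scores)
```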
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
@nithinraok @andrusenkoau
Additional Information