[Ready] track arc_map_token during rnnt decoding #1094

csukuangfj merged 7 commits into k2-fsa:master
Conversation
In my opinion this is not really in the correct spirit of implementing backprop. What we should be doing, I think, is replacing the scores in the lattice with scores looked up, via arc_map_token, from the log_probs tensor and from the decoding graph. [I'm not sure of the details of the backprop; possibly just assigning the tensor would work... see what we do in functions for existing FSA operations.] These scores should have the same numerical values, but should have the correct backprop information.
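The key point of the suggestion above is that looking the scores up with an index-select keeps the autograd connection to log_probs. A minimal PyTorch sketch of that idea (the tensor sizes and the arc map here are made up for illustration; they are not the actual lattice data):

```python
import torch

# Hypothetical tiny setting: 4 arc scores looked up from a flattened
# log-prob tensor via an arc map. torch.index_select keeps the autograd
# connection, so gradients flow back to the selected log-prob entries.
log_probs = torch.randn(6, requires_grad=True)
arc_map = torch.tensor([0, 2, 2, 5])

scores = torch.index_select(log_probs, 0, arc_map)
scores.sum().backward()

# Entry 2 was selected twice, so its gradient accumulates to 2;
# unselected entries get gradient 0.
print(log_probs.grad)  # tensor([1., 0., 2., 0., 0., 1.])
```

The looked-up scores have the same numerical values as the originals, but their backward pass now reaches log_probs, which is exactly the property wanted for the lattice scores.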
It's a good idea and it helped me figure out a way to write a unit test.

Unittest

A simple test for the score lookup is added in k2/python/k2/rnnt_decode.py, lines 273 to 279 (commit e5a8e4e).
k2/python/k2/mwer_loss.py (outdated)

    # Group path_logp into [stream][path] to compute denominator of each stream.
    # A stream here means an input wav.
    ragged_path_prob = k2.RaggedTensor(nbest.shape, path_logp.exp())
You could perhaps write a softmax function accepting RaggedTensor to replace these few lines, and call that instead; it will be clearer.
A function k2.ragged_softmax is added to k2/ops.py.
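For intuition, a ragged softmax normalizes within each row of a ragged tensor. Below is a minimal NumPy sketch of that idea using explicit row splits; it is not the actual k2.ragged_softmax implementation (which operates on k2.RaggedTensor, possibly on GPU), just an illustration of the semantics:

```python
import numpy as np

def ragged_softmax(values, row_splits):
    """Softmax within each ragged row; row i spans
    values[row_splits[i]:row_splits[i + 1]]."""
    out = np.empty_like(values, dtype=np.float64)
    for start, end in zip(row_splits[:-1], row_splits[1:]):
        if end == start:
            continue  # empty row: nothing to normalize
        seg = values[start:end]
        e = np.exp(seg - seg.max())  # subtract max for numerical stability
        out[start:end] = e / e.sum()
    return out

vals = np.array([0.0, 1.0, 2.0, -1.0, 3.0])
splits = np.array([0, 3, 5])  # two rows: [0, 1, 2] and [-1, 3]
probs = ragged_softmax(vals, splits)
# Each row of probs now sums to 1.
```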
k2/python/k2/mwer_loss.py (outdated)

    # Copied from icefall.
    # TODO(liyong) simplify this.
    class Nbest(object):
If this is going to be used it should probably be in a different file. But are you sure it's not already in k2?
Yes, class Nbest also exists in k2.
But it has fewer functions than the one in icefall, namely from_lattice and build_levenshtein_graphs, so I just copied the icefall one here.
Now in the latest PR, these two needed functions are added to k2.Nbest.
k2/python/k2/mwer_loss.py (outdated)

    class MWERLoss(nn.Module):
        # Minimum Word Error Rate.
I am concerned that, especially for long utterances, the dynamic range of scores will be too large and the posteriors will be mostly 0 or 1. To prevent this it might be a good idea to have an extra argument that functions like a temperature: something we can scale the logprobs by before doing the softmax.
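The temperature suggestion can be illustrated with a small NumPy sketch. The score values and the temperature of 20 below are made up; the point is only that dividing the logprobs by a temperature before the softmax flattens a near-one-hot posterior:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical path scores with a large dynamic range,
# as might occur for a long utterance.
scores = np.array([0.0, -40.0, -80.0])

p_raw = softmax(scores)          # essentially one-hot on the best path
p_temp = softmax(scores / 20.0)  # temperature 20 spreads the mass out
# p_raw puts ~all probability on the first path; p_temp keeps the
# ranking but gives the other paths non-negligible posterior.
```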
The most important unit test for this PR is the following assertion: k2/python/k2/rnnt_decode.py, lines 284 to 286 (commit d83d79a).
pkufool left a comment:

LGTM.

I left some comments here. It would be great if you could run fast_beam_search and fast_beam_search_LG in icefall with this change to double-check its correctness.
    int32_t arc_label = 0;
    if (arc.label != -1) {
      arc_label = arc.label;
    }
I think you can set the arc_map_token to -1 for final arcs, and k2.index_select can map -1 to a default value automatically.
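The behavior being suggested is that an index of -1 yields a default value instead of indexing from the end. A small NumPy sketch of that semantics (this is only an illustration of the idea, not the k2.index_select implementation, and the default of 0.0 is an assumption):

```python
import numpy as np

def index_select_with_default(src, indices, default=0.0):
    """Select src[indices], but map index -1 to a default value,
    mimicking the -1 handling described for k2.index_select."""
    out = np.full(indices.shape, default, dtype=src.dtype)
    valid = indices >= 0
    out[valid] = src[indices[valid]]
    return out

src = np.array([10.0, 20.0, 30.0])
idx = np.array([2, -1, 0])  # -1 marks a final arc
result = index_select_with_default(src, idx)  # -1 maps to 0.0
```

With this convention, final arcs get score 0 directly from the lookup, so the separate masking step after the index-select becomes unnecessary.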
@@ -173,16 +175,27 @@ def format_output(
        If false, we only care about the real final state in the
        decoding graph on the last frame when generating lattice.
        Default False.
Please add documentation for log_probs.
k2/python/k2/rnnt_decode.py (outdated)

    scores_tracked_by_autograd = torch.index_select(
        log_probs.reshape(-1), 0, arc_map_token)
    final_arc_index = torch.where(fsa.arcs.values()[:, 2] == -1)
    scores_tracked_by_autograd[final_arc_index] *= 0
See the comment above: if you put -1 into arc_map_token for final arcs and use k2.index_select, you can avoid lines 285-286.
    // i.e. padded beginning frames.
    // We need to subtract number of padded frames when calculating
    // the real time index.
    auto num_padded_frames = Array1<int32_t>(c_, num_frames);
You initialize num_padded_frames with the real frame counts here; it took me a while to get your idea until I read the following lines. Please add some comments here.
k2/python/k2/rnnt_decode.py (outdated)

    final_arc_index = torch.where(fsa.arcs.values()[:, 2] == -1)
    scores_tracked_by_autograd[final_arc_index] *= 0
    # This assertion statement is kind of a unit test.
    assert torch.all(fsa.scores == scores_tracked_by_autograd)
I think it is not right if we are using an LG graph (i.e. there are scores in decoding_graph).
Good idea. Here is a comparison with this model:

with fast_beam_search

with fast_beam_search_LG
@glynpu Could you tell me how you obtain
Found the answer here: k2/python/tests/rnnt_decode_test.py, line 138 (commit 42e92fd).
Updated on Nov. 30, 2022:

k2.mwer has already been merged in #1103.
This PR will focus on tracking arc_map_token during decoding.

Original description:

An implementation of MWER, equation 2 of https://arxiv.org/pdf/2106.02302.pdf
Also solving issue #1061.

An example usage is:

lattice, log_probs, arc_map_token

Unittest

For now I only manually check whether the gradient backpropagates to the correct tokens, and there seem to be no obvious bugs.
More checks and unit tests are needed for this.
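For reference, the MWER objective in the commonly cited form (renormalize the path log-probs over the n-best list, then take the expected word-error count under that posterior) can be sketched in a few lines of NumPy. The scores and error counts below are made-up toy values, and this ignores batching, streams, and the temperature discussed above:

```python
import numpy as np

# Hypothetical n-best list for one utterance:
path_logp = np.array([-1.0, -2.5, -3.0])  # total log-prob of each path
word_errors = np.array([0.0, 2.0, 1.0])   # edit distance to the reference

# Renormalized posterior over the n-best paths (softmax of path log-probs).
e = np.exp(path_logp - path_logp.max())
posterior = e / e.sum()

# MWER loss: expected number of word errors under that posterior.
mwer = (posterior * word_errors).sum()
```

Minimizing this expectation pushes posterior mass toward the low-error paths, which is why the gradient has to reach the per-token log-probs through arc_map_token.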