Skip to content

[Ready] track arc_map_token during rnnt decoding#1094

Merged
csukuangfj merged 7 commits intok2-fsa:masterfrom
glynpu:mwer
May 22, 2023
Merged

[Ready] track arc_map_token during rnnt decoding#1094
csukuangfj merged 7 commits intok2-fsa:masterfrom
glynpu:mwer

Conversation

@glynpu
Copy link
Contributor

@glynpu glynpu commented Nov 8, 2022

Updated on Nov. 30, 2022:

k2.mwer is already merged by #1103

This pr will focus on tracking arc_map_token during decoding.

Original:

An implementation of MWER, equation 2 of https://arxiv.org/pdf/2106.02302.pdf

Also solving issue #1061

An example usage is:

  1. slightly modification of fast_beam_search to return log_probs of each time step.

log_probs_list = []
for t in range(T):
    *****(code of decoding)
    log_probs = (logits / temperature).log_softmax(dim=-1) 
    log_probs_list.append(log_probs)
    decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
log_probs = torch.cat(log_probs_list) 
lattice, arc_map_token = decoding_streams.format_output(encoder_out_lens.tolist(), allow_partial=False)
return lattice, log_probs, arc_map_token 
  1. Compute MWER based on the lattice, log_probs, arc_map_token
lattice, log_probs, arc_map_token = fast_beam_search(                                                                                                                       
    model=model,                                                                                                                                                            
    decoding_graph=decoding_graph,                                                                                                                                          
    encoder_out=encoder_out,                                                                                                                                                
    encoder_out_lens=encoder_out_lens,                                                                                                                                      
    beam=beam,                                                                                                                                                              
    max_states=max_states,                                                                                                                                                  
    max_contexts=max_contexts,                                                                                                                                              
    temperature=temperature,                                                                                                                                                
)  
mwer_loss = k2.mwer_loss(nbest_scale, num_paths, log_probs, lattice, arc_map_token, ref_texts)
mwer_loss.backward()

Unittest

Now I only manually check whether the gradient backprogates to the correct tokens, and it seems no obvious bugs.
Need more checks and unittests about this.

        # After loss.backward(),
        # users could manually check if the grad go to the right place,
        # An example:
        # By comapring hyps.labels and torch.where(log_probs.grad != 0)[1],
        # at least we could say that gradients seems to backpropate to the correct token,
        # i.e. 0, 15, 33, 10, 101 in following examples.
        # But why does token 95 in hyps.lables fail to get a gradient?
        # hyps.labels
        # >>>  tensor([  0,   0,  15,   0,  -1,   0,   0,  33,   0,  -1,   0,   0,  95,   0,
        #               -1,   0,   0, 101,   0,  -1,   0,   0,  10,   0,   0,  10,   0,  -1,
        #                0,   0,  10,   0,   0,  10,   0,   0,  10,   0,  -1],
        #             device='cuda:0', dtype=torch.int32)
        #
        # torch.where(log_probs.grad != 0)[0]  # corresopnds to context_index
        # >>>  tensor([ 0,  1,  1,  5,  6,  9, 10, 13, 14, 17, 18, 21, 22, 25, 26, 29, 30, 33,
        #               34, 38, 39, 39, 40, 44, 45, 46, 48, 49, 51, 52], device='cuda:0')
        #
        # torch.where(log_probs.grad != 0)[1]  # corresponds to token
        # >>>  tensor([  0,   0,  15,   0,   0,  33,   0,   0,   0,   0,   0,   0,   0,   0,
        #                0,   0,   0,   0,   0,  10,  10, 101,  10,  10,   0,  10,   0,   0,
        #                0,   0], device='cuda:0')

@danpovey
Copy link
Collaborator

danpovey commented Nov 8, 2022

In my opinion this is not really in the correct spirit of implementing backprop. What we should be doing, I think, is replacing the scores in the lattice with scores looked-up, via arc_map_token, from the log_probs tensor, and from the decoding graph. [I'm not sure of the details of the backprop; possibly just assigning the tensor would work... see what we do in functions for existing FSA operations] These scores should have the same numerical values, but should have the correct backprop information.
Then all kinds of lattice/FSA operations-- not just this specific one-- would work correctly.

@glynpu
Copy link
Contributor Author

glynpu commented Nov 9, 2022

replacing the scores in the lattice with scores looked-up, via arc_map_token, from the log_probs tensor, and from the decoding graph

It's a good idea and helps me figure out a way to do "unittest".

An example usage now:

  1. slightly modification of fast_beam_search to get a lattice whose scores are already tracked by auto-grad machine.

log_probs_list = []
t2stream_row_splits = [0]
stream2context_row_splits = [0]
num_log_probs = 0 
for t in range(T):
    *****(code of decoding)
    log_probs = (logits / temperature).log_softmax(dim=-1) 
    log_probs_list.append(log_probs)
    t2stream_row_splits += [shape.tot_size(0) + t2stream_row_splits[-1]]                                                                                                
    stream2context_row_splits += (shape.row_splits(1) + num_log_probs)[1:].tolist()                                                                                     
    num_log_probs = stream2context_row_splits[-1] 
    decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
log_probs = torch.cat(log_probs_list) 
# NOTE: now lattice.scores is already tracked by auto-grad machine.
lattice = decoding_streams.format_output(encoder_out_lens.tolist(), log_probs=log_probs, t2s2c_shape=t2stream2context_shape3)
return lattice
  1. lattice and ref_texts are enough to compute MWER.
    i.e. current k2.mwer loss doesn't care how the lattice generated.
    It could be generated from a rnnt model by fast_beam_search,
    or from a ctc model by intersect_dense_pruned(we need to make the lattice.scores tracked by auto-grad machine).
mwer_loss = k2.mwer_loss(nbest_scale, num_paths, lattice, ref_texts)
mwer_loss.backward()

Unittest

Now a simple "test" is added for scores looking up by arc_map_token from log_probs.
With line-277, at least we could say arc_map_token is generated correctly.

# Make fsa.scores tracked by auto grad.
scores_tracked_by_auto_grad = torch.index_select(log_probs.reshape(-1), 0, arc_map_token)
final_arc_index = torch.where(fsa.arcs.values()[:, 2] == -1)
scores_tracked_by_auto_grad[final_arc_index] *= 0
assert torch.all(fsa.scores == scores_tracked_by_auto_grad)
fsa.scores = scores_tracked_by_auto_grad
return fsa

Copy link
Collaborator

@danpovey danpovey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some comments..


# Group path_logp into [stream][path] to compute denominator of each stream.
# A stream here means an input wav.
ragged_path_prob = k2.RaggedTensor(nbest.shape, path_logp.exp())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could perhaps write a softmax function accepting RaggedTensor to replace these few lines, and call that instead, it will be clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A function k2.ragged_softmax is added to k2/ops.py.


# Copied from icefall.
# TODO(liyong) simplify this.
class Nbest(object):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is going to be used it should probably be in a different file. But are you sure it's not already in k2?

Copy link
Contributor Author

@glynpu glynpu Nov 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, class Nbest also exists in k2.
But it has fewer functions than that in icefall, i.e. from_lattice and build_levenshtein_graphs. So I just copied the icefall one here.
Now in the latest pr, these two needed functions are added into the k2.Nbest.



class MWERLoss(nn.Module):
# Minimus Word Error Rate.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am concerned that, especially for long utterances, the dynamic range of scores will be too large and the posteriors will be mostly 0 or 1. To prevent this it might be a good idea to have an extra argument that functions like a temperature-- something we can scale the logprobs by before doing the softmax.

@glynpu glynpu changed the title [WIP] draft MWER [WIP] track arc_map_token during rnnt decoding Nov 30, 2022
@glynpu
Copy link
Contributor Author

glynpu commented Dec 2, 2022

The most important unit test for this pr is following assertion:

scores_tracked_by_autograd[final_arc_index] *= 0
assert torch.all(fsa.scores == scores_tracked_by_autograd)
fsa.scores = scores_tracked_by_autograd

fsa.scores is generated during decoding.
scores_tracked_by_autograd is selected by arc_map_token

If fsa.scores is identical to scores_tracked_by_autograd, it means arc_map_token is tracked correctly.

@glynpu glynpu changed the title [WIP] track arc_map_token during rnnt decoding [Ready] track arc_map_token during rnnt decoding Dec 6, 2022
Copy link
Collaborator

@pkufool pkufool left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Leave some comments here. It would be great if you can run the fast_beam_search and fast_beam_search_LG in icefall with this change to double check its correctness.

int32_t arc_label = 0;
if (arc.label != -1) {
arc_label = arc.label;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can set the arc_map_token to -1 for final arcs, and k2.index_select can map -1 to default value automatically.

@@ -173,16 +175,27 @@ def format_output(
If false, we only care about the real final state in the
decoding graph on the last frame when generating lattice.
Default False.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add documents for log_probs.

scores_tracked_by_autograd = torch.index_select(
log_probs.reshape(-1), 0, arc_map_token)
final_arc_index = torch.where(fsa.arcs.values()[:, 2] == -1)
scores_tracked_by_autograd[final_arc_index] *= 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comment above, if you add -1 to arc_map_token and use k2.index_select you can avoid line 285-286.

// i.e. padded beginning frames.
// We need to subtract number of padded frames when calculating
// the real time index.
auto num_padded_frames = Array1<int32_t>(c_, num_frames);
Copy link
Collaborator

@pkufool pkufool Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You initialize num_padded_frames with real frames here, it takes me a while to get your idea till I read the following lines. Please add some comments here.

final_arc_index = torch.where(fsa.arcs.values()[:, 2] == -1)
scores_tracked_by_autograd[final_arc_index] *= 0
# This assertion statement is kind of unit test.
assert torch.all(fsa.scores == scores_tracked_by_autograd)
Copy link
Collaborator

@pkufool pkufool Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is not right if we are using an LG graph (i.e. there are scores in decoding_graph).

@glynpu
Copy link
Contributor Author

glynpu commented Dec 13, 2022

It would be great if you can run the fast_beam_search and fast_beam_search_LG in icefall with this change to double check its correctness.

Good idea. Here is a comparison with this model :

with fast_beam_search

- test-clean test-other
master branch 1.77 4.23
current pr 1.77 4.23

with fast_beam_search_LG

- test-clean test-other
master branch 2.15 4.45
current pr 2.15 4.45

Copy link
Collaborator

@pkufool pkufool left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+2

@pkufool pkufool added the ready Ready for review and trigger GitHub actions to run label Dec 14, 2022
@csukuangfj csukuangfj merged commit 3ebea1a into k2-fsa:master May 22, 2023
@desh2608
Copy link
Contributor

desh2608 commented Jul 7, 2023

t2stream2context_shape3

@glynpu Could you tell me how you obtain t2stream2context_shape3?

@desh2608
Copy link
Contributor

desh2608 commented Jul 7, 2023

t2stream2context_shape3

@glynpu Could you tell me how you obtain t2stream2context_shape3?

Found the answer here:

t2stream2context_shape3 = t2s_shape.compose(s2c_shape).to(device)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready Ready for review and trigger GitHub actions to run

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants