
pruned_ragged_to_lattice #1163

Merged
pkufool merged 19 commits into k2-fsa:master from glynpu:selfalignment
May 12, 2023

Conversation

@glynpu
Contributor

@glynpu glynpu commented Feb 17, 2023

Hi Dan, this is the code we are discussing just now. @danpovey
Could we get the alignment you mentioned from this lattice?

Maybe we need more unit tests to make sure it works as we expect.
Currently, this has only been checked by Xiaoyu's and my eyes, with the following code.

import k2
import torch

ranges = torch.tensor([[[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]],
                       [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]]]).to(torch.int32)
x_lens = torch.tensor([4, 3]).to(torch.int32)
y = torch.tensor([[8, 7, 6, 5, 4, 3, 2],
                  [2, 3, 4, 5, 6, 7, 8]]).to(torch.int32)
logits = torch.tensor([0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]).expand(2, 4, 5, 9).to(torch.float32)
logits = logits + torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]).reshape(2, 4).unsqueeze(-1).unsqueeze(-1)
logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)

ofsa, arc_map = k2.self_alignment(ranges, x_lens, y, logits)
lattice = k2.Fsa(ofsa)
scores_tracked_by_autograd = k2.index_select(logits.reshape(-1), arc_map)
assert torch.all(lattice.scores == scores_tracked_by_autograd)

# Attach the autograd-tracked scores so gradients can flow through the lattice.
lattice.scores = scores_tracked_by_autograd

Lattice generated:
[image: generated lattice]

@danpovey
Collaborator

So I assume best_path would be done after this.
There definitely should be a way to get the alignment info, either by somehow tracing back the arc_maps (I don't remember
whether these are made available), or perhaps more elegantly, by attaching some kind of integer properties
about the frame indexes and text-position indexes to the FSA at the point we create it, and then accessing those after best-path. Perhaps someone can figure out how to do this.
I think this is something we will find a lot of uses for, but we need to have a function in icefall that makes it available in an easy way.
We could perhaps have an option to the RNN-T training code, to return the alignment.
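The traceback idea above can be sketched in plain Python: run a Viterbi best path over a toy DAG whose arcs carry frame-index and text-position attributes, then read the alignment off the winning arcs. This is only an illustration of the idea, not the k2 API; every name below is hypothetical.

```python
def best_path_alignment(num_states, arcs, start, final):
    """Viterbi best path over a DAG whose states are topologically
    ordered (src < dst).  Each arc is a tuple
    (src, dst, score, frame_idx, text_pos); the last two fields are the
    integer attributes we want to recover after best-path."""
    NEG_INF = float("-inf")
    best = [NEG_INF] * num_states
    back = [None] * num_states  # winning incoming arc for each state
    best[start] = 0.0
    for arc in sorted(arcs, key=lambda a: a[0]):
        src, dst, score = arc[0], arc[1], arc[2]
        if best[src] > NEG_INF and best[src] + score > best[dst]:
            best[dst] = best[src] + score
            back[dst] = arc
    # Trace back from the final state, collecting (frame_idx, text_pos).
    align, state = [], final
    while state != start:
        src, _, _, frame_idx, text_pos = back[state]
        align.append((frame_idx, text_pos))
        state = src
    return best[final], list(reversed(align))
```

In k2 itself the cleaner route would be the attribute approach: attach integer tensor attributes to the Fsa when it is created, since (if I recall correctly) attributes propagate through operations such as shortest-path.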

@marcoyang1998

have a function in icefall

Actually, I have been using this feature for a while, and I have a function in Python that converts the alignment to timestamp information. I will make a PR in icefall.
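The icefall helper itself is not shown in this thread; as a rough sketch of the usual conversion, assuming the alignment gives one symbol per (subsampled) frame (the 40 ms frame shift and all names here are assumptions, not the actual icefall code):

```python
def alignment_to_timestamps(frame_symbols, frame_shift_s=0.04, blank=0):
    """Map a per-frame symbol sequence to (symbol, seconds) pairs.

    frame_shift_s is the frame shift after subsampling (a hypothetical
    40 ms default); blank frames produce no timestamp."""
    return [(sym, t * frame_shift_s)
            for t, sym in enumerate(frame_symbols)
            if sym != blank]
```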

@danpovey
Collaborator

Great!
Bear in mind we may need the scores/probabilities as well.

Collaborator

@danpovey danpovey left a comment


Some comments


namespace k2 {

FsaVec SelfAlignment(
Collaborator


I think we would need more documentation here. The names here seem to relate to the RNN-T code, and in addition to explaining what the various variables mean you should point to where the regular RNN-T API is located and give some indication of what the expected usage pattern would be, e.g. to call this after doing the regular RNN-T recursion.

I would have thought that alignment code would require a loop of some kind(?), over time, so I'm surprised to not see that here, at least from glancing at it.

case 1:
next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step];
// blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2;
// next_state_idx1 = blank_connections_data[blank_connections_data_offset];
Collaborator


There is a fair amount of dead code here in comments.

K2_CHECK_EQ(torch::kInt, ranges.scalar_type());
K2_CHECK_EQ(torch::kInt, x_lens.scalar_type()); // int32_t
const float *logits_data = logits.data_ptr<float>();
const int32_t *ranges_data = ranges.data_ptr<int32_t>();
Collaborator


There is another way of accessing these tensors other than extracting the strides etc., which is to use an accessor.
It might be something like:

auto range_accessor = ranges.accessor<int32_t, 3>();

or (and this may be more efficient due to being somehow more specialized or optimized, I'm not sure):

auto range_accessor = ranges.packed_accessor32<int32_t, 3>();

(the 3 is the number of dimensions). And then I think you'd access it from the lambda as
something like:

int32_t i = range_accessor[idx0][idx1][idx2];

as if it were an array. I think you can also do things like range_accessor.size(0) to get the
sizes on each dimension.

HOWEVER: this code may really belong in k2/python/k2/csrc/torch/; this is where we have other code
that makes use of torch APIs, I'm not sure if we really encourage use of torch arrays at this part of the code
(this organization may be a mistake, possibly, but it was intended to separate parts of k2 that
are too deeply entwined with the tensor toolkit).

Contributor Author


There is another way of accessing these tensors other than extracting the strides etc., which is to use an accessor.

I think both methods have their advantages and disadvantages, and I don't know which one to choose. Would you mind giving me a suggestion on which one should be used? @danpovey

For using an accessor:
according to https://pytorch.org/cppdocs/notes/tensor_basics.html#cuda-accessors, "It is thus recommended to use accessors for CPU tensors and packed accessors for CUDA tensors." (I think the accessor is the reason for the crash when using a CUDA tensor as input. It's a shame that the unit test I implemented failed to cover this.)

To support both CPU and CUDA tensors, shall we compile two files, pruned_ragged_to_lattice_cpu.cu and pruned_ragged_to_lattice_cuda.cu (similar to mutual_information_cpu.cu and mutual_information_cuda.cu)?

The two files would be very similar, except that one uses an accessor and the other uses a packed accessor.

For using data_ptr and extracting the strides etc.:
only one file (pruned_ragged_to_lattice.cu) is enough to support both CPU and GPU tensors. But the disadvantage is that we could not use AT_DISPATCH_FLOATING_TYPES_AND_HALF to make it compatible with different input types, as in:

    // Type of logits can be float32, float64, float16. Though float32 is mainly
    // used, float64 and float16 are also supported for future use.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        logits.scalar_type(), "pruned_ranges_to_lattice", ([&] {
          auto logits_a = logits.accessor<scalar_t, 4>();

Since logits.scalar_type() is an enum, it's illegal to define const scalar_t *logits_data = logits.data_ptr() outside the dispatch macro.

Collaborator

@danpovey danpovey Jun 20, 2023


Just use packed_accessor32. I think there is no problem for CPU. They don't state why not to use it for CPU; it may be about extremely large arrays and 32-bit indexes, but that isn't really an issue for us.

@glynpu glynpu changed the title Selfalignment pruned_ragged_to_lattice May 8, 2023
@danpovey
Collaborator

danpovey commented May 8, 2023

Looks OK to me from a brief glance!
So are we OK to merge it? Let's merge today unless there are any immediate objections?

@pkufool pkufool added the ready Ready for review and trigger GitHub actions to run label May 11, 2023
@pkufool
Collaborator

pkufool commented May 12, 2023

@glynpu Could you have a look to check whether the failing cases are related to this PR?

@glynpu
Contributor Author

glynpu commented May 12, 2023

@glynpu Could you have a look to check whether the failing cases are related to this PR?

The failing cases are not related to this PR; they are mainly about a torch 1.13.1 installation problem.

@pkufool
Collaborator

pkufool commented May 12, 2023

I think it is a fairly safe and independent change (i.e. will not affect other functions), merging now.

@pkufool pkufool merged commit 9da8626 into k2-fsa:master May 12, 2023
e.g. arc pointing to super final state.
@return Return an FsaVec, which contains the generated lattice.
*/
FsaVec PrunedRangesToLattice(
Collaborator


Suggested change
FsaVec PrunedRangesToLattice(
static FsaVec PrunedRangesToLattice(

Comment on lines +56 to +61
FsaVec PrunedRangesToLattice(
torch::Tensor ranges, // [B][T][s_range]
torch::Tensor frames, // [B]
torch::Tensor symbols, // [B][S]
torch::Tensor logits, // [B][T][s_range][C]
Array1<int32_t> *arc_map);
Collaborator


Is this function used by any function other than the one in PybindPrunedRangesToLattice?

If not, please remove it from this header file.

int32_t * f2s_row_splits_data = f2s_row_splits.Data();
K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) {
int32_t t = frames_a[fsa_idx0];
K2_CHECK_LE(t, T);
Collaborator


I suggest that you use a debug check here, i.e.,
change K2_CHECK_LE to K2_DCHECK_LE.


// Compute f2s_shape: fsa_to_state_shape.
Array1<int32_t> f2s_row_splits(context, B + 1);
int32_t * f2s_row_splits_data = f2s_row_splits.Data();
Collaborator


int32_t * f2s_row_splits_data = f2s_row_splits.Data();

to be consistent, please replace
int32_t * f2s_row_splits_data
with
int32_t *f2s_row_splits_data.

t = state_idx1 / s_range,
token_idx = state_idx1 % s_range;

K2_CHECK_LE(t, frames_a[fsa_idx0]);
Collaborator


K2_DCHECK_LE


Labels

ready Ready for review and trigger GitHub actions to run


5 participants