-
Notifications
You must be signed in to change notification settings - Fork 233
pruned_ragged_to_lattice #1163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pruned_ragged_to_lattice #1163
Changes from all commits
32fbbba
5ef2092
dea8116
1613abc
b448667
8564112
b703774
ecc2db3
dbb590a
d00fd0c
fefd278
2218904
44e1927
15e645b
00a5cbd
d295262
482915c
5012a28
4b9329c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,282 @@ | ||
| /** | ||
| * @copyright | ||
| * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) | ||
| * | ||
| * @copyright | ||
| * See LICENSE for clarification regarding multiple authors | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <utility> | ||
|
|
||
| #include "k2/csrc/device_guard.h" | ||
| #include "k2/csrc/fsa.h" | ||
| #include "k2/csrc/torch_util.h" | ||
| #include "k2/python/csrc/torch/pruned_ranges_to_lattice.h" | ||
| #include "k2/python/csrc/torch/v2/ragged_any.h" | ||
|
|
||
|
|
||
| namespace k2 { | ||
|
|
||
| /* | ||
| Convert pruned ranges to lattice while also supporting autograd. | ||
|
|
||
| The input pruned ranges are normally generated by `get_rnnt_prune_ranges`. | ||
| See k2/python/k2/rnnt_loss.py for the process of generating ranges and | ||
| the information it represents. | ||
|
|
||
| At the time of writing, the lattice is used to generate forced alignments. | ||
| Perhaps you could find other uses for this function. | ||
|
|
||
| @param ranges A tensor containing the symbol indexes for each frame. | ||
| Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges` | ||
| in k2/python/k2/rnnt_loss.py for more details of this tensor. | ||
| Its type is int32, consistent with that in rnnt_loss.py. | ||
| @param frames The number of frames per sample with shape (B). | ||
| Its type is int32. Every entry is checked to be <= T. | ||
| @param symbols The symbol sequence, a LongTensor of shape (B, S), | ||
| and elements in {0..C-1}. | ||
| Its type is int64(Long), consistent with that in rnnt_loss.py. | ||
| @param logits The pruned joiner network (or am/lm) | ||
| of shape (B, T, s_range, C). | ||
| Its type can be float32, float64, float16. Though float32 is mainly | ||
| used, float64 and float16 are also supported for future use. | ||
| @param [out] arc_map A map from arcs in the generated lattice to the linear | ||
| element offset into logits (computed from the strides of logits), | ||
| or -1 if the arc had no corresponding score in logits, | ||
| e.g. arc pointing to super final state. | ||
| @return Return an FsaVec, which contains the generated lattice. | ||
| */ | ||
| FsaVec PrunedRangesToLattice( | ||
| torch::Tensor ranges, // [B][T][s_range] | ||
| torch::Tensor frames, // [B] | ||
| torch::Tensor symbols, // [B][S] | ||
| torch::Tensor logits, // [B][T][s_range][C] | ||
| Array1<int32_t> *arc_map) { | ||
|
|
||
| TORCH_CHECK(ranges.get_device() == frames.get_device()); | ||
| TORCH_CHECK(ranges.get_device() == symbols.get_device()); | ||
| TORCH_CHECK(ranges.get_device() == logits.get_device()); | ||
|
|
||
| TORCH_CHECK(ranges.dim() == 3, "ranges should be 3-dimensional"); | ||
| TORCH_CHECK(frames.dim() == 1, "frames should be 1-dimensional"); | ||
| TORCH_CHECK(symbols.dim() == 2, "symbols should be 2-dimensional"); | ||
| TORCH_CHECK(logits.dim() == 4, "logits should be 4-dimensional"); | ||
|
|
||
| TORCH_CHECK(torch::kInt == ranges.scalar_type()); | ||
| TORCH_CHECK(torch::kInt == frames.scalar_type()); | ||
| TORCH_CHECK(torch::kLong == symbols.scalar_type()); | ||
|
|
||
| ContextPtr context; | ||
| if (ranges.device().type() == torch::kCPU) { | ||
| context = GetCpuContext(); | ||
| } else if (ranges.is_cuda()) { | ||
| context = GetCudaContext(ranges.device().index()); | ||
| } else { | ||
| K2_LOG(FATAL) << "Unsupported device: " << ranges.device() | ||
| << "\nOnly CPU and CUDA are verified"; | ||
| } | ||
|
|
||
| // "_a" is short for accessor. NOTE(review): Tensor::accessor is documented for host-side access; confirm these accessors are usable inside K2_EVAL when context is CUDA. | ||
| auto ranges_a = ranges.accessor<int32_t, 3>(); | ||
| auto frames_a = frames.accessor<int32_t, 1>(); | ||
| auto symbols_a = symbols.accessor<int64_t, 2>(); | ||
|
|
||
| // Typically, s_range is 5. | ||
| const int32_t B = ranges.size(0), | ||
| T = ranges.size(1), | ||
| s_range = ranges.size(2); | ||
|
|
||
| // Compute f2s_shape (fsa_to_state_shape): FSA i gets frames[i] * s_range + 1 states. | ||
| Array1<int32_t> f2s_row_splits(context, B + 1); | ||
| int32_t * f2s_row_splits_data = f2s_row_splits.Data(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to be consistent, please replace |
||
| K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { | ||
| int32_t t = frames_a[fsa_idx0]; | ||
| K2_CHECK_LE(t, T); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suggest that you use a debug check here, i.e., |
||
| // + 1 in "t * s_range + 1" is for super-final state. | ||
| f2s_row_splits_data[fsa_idx0] = t * s_range + 1; | ||
| }); | ||
|
|
||
| ExclusiveSum(f2s_row_splits, &f2s_row_splits); | ||
| RaggedShape f2s_shape = | ||
| RaggedShape2(&f2s_row_splits, nullptr, -1); | ||
|
|
||
| // Compute s2a_shape: state_to_arc_shape. | ||
| int32_t num_states = f2s_shape.NumElements(); | ||
| Array1<int32_t> s2c_row_splits(context, num_states + 1); | ||
| int32_t *s2c_row_splits_data = s2c_row_splits.Data(); | ||
| const int32_t *f2s_row_splits1_data = f2s_shape.RowSplits(1).Data(), | ||
| *f2s_row_ids1_data = f2s_shape.RowIds(1).Data(); | ||
| // Compute number of arcs for each state. | ||
| K2_EVAL( | ||
| context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { | ||
| int32_t fsa_idx0 = f2s_row_ids1_data[state_idx01], | ||
| state_idx0x = f2s_row_splits1_data[fsa_idx0], | ||
| state_idx1 = state_idx01 - state_idx0x, | ||
| t = state_idx1 / s_range, | ||
| token_idx = state_idx1 % s_range; | ||
|
|
||
| K2_CHECK_LE(t, frames_a[fsa_idx0]); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| // The super-final state (the last state of each FSA) has no leaving arcs. | ||
| if (state_idx1 == frames_a[fsa_idx0] * s_range) { | ||
| s2c_row_splits_data[state_idx01] = 0; | ||
| return; | ||
| } | ||
|
|
||
| // Every other state has one leaving arc unless it is upgraded to two below. | ||
| s2c_row_splits_data[state_idx01] = 1; | ||
|
|
||
| // Decide whether the state also has a second (horizontal blank) leaving arc. | ||
| bool has_horizontal_blank_arc = false; | ||
| if (t < frames_a[fsa_idx0] - 1) { | ||
| has_horizontal_blank_arc = | ||
| ranges_a[fsa_idx0][t][token_idx] >= ranges_a[fsa_idx0][t + 1][0]; | ||
| } | ||
| // Typically, s_range == 5, i.e. 5 states for each time step. | ||
| // The 4th state (0-based index) ONLY has a horizontal blank arc, | ||
| // while states [0, 1, 2, 3] have a vertical arc | ||
| // and MAYBE a horizontal blank arc. | ||
| if (token_idx != s_range - 1 && has_horizontal_blank_arc) { | ||
| s2c_row_splits_data[state_idx01] = 2; | ||
| } | ||
| }); | ||
| ExclusiveSum(s2c_row_splits, &s2c_row_splits); | ||
| RaggedShape s2a_shape = | ||
| RaggedShape2(&s2c_row_splits, nullptr, -1); | ||
|
|
||
| // ofsa_shape: output_fsa_shape. | ||
| RaggedShape ofsa_shape = ComposeRaggedShapes(f2s_shape, s2a_shape); | ||
|
|
||
| int32_t num_arcs = ofsa_shape.NumElements(); | ||
| Array1<Arc> arcs(context, num_arcs); | ||
|
|
||
| Arc *arcs_data = arcs.Data(); | ||
| const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(), | ||
| *row_ids1_data = ofsa_shape.RowIds(1).Data(), | ||
| *row_splits2_data = ofsa_shape.RowSplits(2).Data(), | ||
| *row_ids2_data = ofsa_shape.RowIds(2).Data(); | ||
|
|
||
| Array1<int32_t> out_map(context, num_arcs); | ||
| int32_t* out_map_data = out_map.Data(); | ||
| // Strides of logits, used to turn a 4-D index into the linear offset stored in out_map. | ||
| const int32_t lg_stride_0 = logits.stride(0), | ||
| lg_stride_1 = logits.stride(1), | ||
| lg_stride_2 = logits.stride(2), | ||
| lg_stride_3 = logits.stride(3); | ||
|
|
||
| // Type of logits can be float32, float64, float16. Though float32 is mainly | ||
| // used, float64 and float16 are also supported for future use. | ||
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( | ||
| logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { | ||
| auto logits_a = logits.accessor<scalar_t, 4>(); | ||
|
|
||
| K2_EVAL( | ||
| context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { | ||
| const int32_t state_idx01 = row_ids2_data[arc_idx012], | ||
| fsa_idx0 = row_ids1_data[state_idx01], | ||
| state_idx0x = row_splits1_data[fsa_idx0], | ||
| arc_idx01x = row_splits2_data[state_idx01], | ||
| state_idx1 = state_idx01 - state_idx0x, | ||
| arc_idx2 = arc_idx012 - arc_idx01x, | ||
| t = state_idx1 / s_range, | ||
| // token_idx lies within interval [0, s_range) | ||
| // but does not include s_range. | ||
| token_idx = state_idx1 % s_range; | ||
| Arc arc; | ||
| arc.src_state = state_idx1; | ||
| // The penultimate state only has one leaving arc, | ||
| // pointing to the super final state (label -1, score 0, no logit). | ||
| if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { | ||
| arc.src_state = state_idx1; | ||
| arc.dest_state = state_idx1 + 1; | ||
| arc.label = -1; | ||
| arc.score = 0.0; | ||
| arcs_data[arc_idx012] = arc; | ||
| out_map_data[arc_idx012] = -1; | ||
| return; | ||
| } | ||
|
|
||
| if (token_idx < s_range - 1) { | ||
| // States have a vertical arc with non-blank label and | ||
| // MAYBE a horizontal arc with blank label. NOTE(review): a state has at most two arcs, so arc_idx2 should be 0 or 1; the LE-2 check below looks one too loose -- confirm. | ||
| const int32_t symbol_idx = ranges_a[fsa_idx0][t][token_idx], | ||
| arc_label = symbols_a[fsa_idx0][symbol_idx]; | ||
| K2_CHECK_LE(arc_idx2, 2); | ||
| switch (arc_idx2) { | ||
| // For vertical arc with non-blank label. | ||
| case 0: | ||
| arc.dest_state = state_idx1 + 1; | ||
| arc.label = arc_label; | ||
| arc.score = logits_a[fsa_idx0][t][token_idx][arc_label]; | ||
|
|
||
| out_map_data[arc_idx012] = | ||
| fsa_idx0 * lg_stride_0 + t * lg_stride_1 + | ||
| token_idx * lg_stride_2 + arc_label * lg_stride_3; | ||
| break; | ||
| // For horizontal arc with blank label (label 0, logit index 0 on the last axis). | ||
| case 1: | ||
| const int32_t dest_state_token_idx = | ||
| ranges_a[fsa_idx0][t][token_idx] - | ||
| ranges_a[fsa_idx0][t + 1][0]; | ||
| K2_CHECK_GE(dest_state_token_idx, 0); | ||
| arc.dest_state = dest_state_token_idx + (t + 1) * s_range; | ||
| arc.label = 0; | ||
| arc.score = logits_a[fsa_idx0][t][token_idx][0]; | ||
|
|
||
| out_map_data[arc_idx012] = | ||
| fsa_idx0 * lg_stride_0 + t * lg_stride_1 + | ||
| token_idx * lg_stride_2; | ||
| break; | ||
| } | ||
| } else { | ||
| // The last state of a frame only has a horizontal arc with blank label. | ||
| K2_CHECK_EQ(arc_idx2, 0); | ||
| const int32_t dest_state_token_idx = | ||
| ranges_a[fsa_idx0][t][token_idx] - | ||
| ranges_a[fsa_idx0][t + 1][0]; | ||
| arc.dest_state = | ||
| dest_state_token_idx + (t + 1) * s_range; | ||
| arc.label = 0; | ||
| arc.score = logits_a[fsa_idx0][t][token_idx][0]; | ||
|
|
||
| out_map_data[arc_idx012] = | ||
| fsa_idx0 * lg_stride_0 + t * lg_stride_1 + | ||
| token_idx * lg_stride_2; | ||
| } | ||
| arcs_data[arc_idx012] = arc; | ||
| }); | ||
| *arc_map = std::move(out_map); | ||
| })); | ||
|
|
||
| return Ragged<Arc>(ofsa_shape, arcs); | ||
| } | ||
|
|
||
| } // namespace k2 | ||
|
|
||
// Registers the function `pruned_ranges_to_lattice` on module `m`.
// The bound callable takes (ranges, frames, symbols, logits), forwards them to
// k2::PrunedRangesToLattice under a DeviceGuard built from `ranges`, and returns
// a pair: the resulting FsaVec and the arc-to-logit map converted with ToTorch.
| void PybindPrunedRangesToLattice(py::module &m) { | ||
| m.def( | ||
| "pruned_ranges_to_lattice", | ||
| [](torch::Tensor ranges, | ||
| torch::Tensor frames, | ||
| torch::Tensor symbols, | ||
| torch::Tensor logits) -> std::pair<k2::FsaVec, torch::Tensor> { | ||
| k2::DeviceGuard guard(k2::GetContext(ranges)); | ||
| k2::Array1<int32_t> arc_to_logit_map; | ||
| k2::FsaVec ofsa = k2::PrunedRangesToLattice( | ||
| ranges, frames, symbols, logits, &arc_to_logit_map); | ||
| return std::make_pair(ofsa, ToTorch(arc_to_logit_map)); | ||
| }, | ||
| py::arg("ranges"), py::arg("frames"), | ||
| py::arg("symbols"), py::arg("logits")); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| /** | ||
| * @copyright | ||
| * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) | ||
| * | ||
| * @copyright | ||
| * See LICENSE for clarification regarding multiple authors | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #ifndef K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_ | ||
| #define K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_ | ||
|
|
||
| #include "k2/python/csrc/torch.h" | ||
|
|
||
| namespace k2 { | ||
|
|
||
| /* | ||
| Convert pruned ranges to lattice while also supporting autograd. | ||
|
|
||
| The input pruned ranges are normally generated by `get_rnnt_prune_ranges`. | ||
| See k2/python/k2/rnnt_loss.py for the process of generating ranges and | ||
| the information it represents. | ||
|
|
||
| At the time of writing, the lattice is used to generate forced alignments. | ||
| Perhaps you could find other uses for this function. | ||
|
|
||
| @param ranges A tensor containing the symbol indexes for each frame. | ||
| Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges` | ||
| in k2/python/k2/rnnt_loss.py for more details of this tensor. | ||
| Its type is int32, consistent with that in rnnt_loss.py. | ||
| @param frames The number of frames per sample with shape (B). | ||
| Its type is int32. | ||
| @param symbols The symbol sequence, a LongTensor of shape (B, S), | ||
| and elements in {0..C-1}. | ||
| Its type is int64(Long), consistent with that in rnnt_loss.py. | ||
| @param logits The pruned joiner network (or am/lm) | ||
| of shape (B, T, s_range, C). | ||
| Its type can be float32, float64, float16. Though float32 is mainly | ||
| used, float64 and float16 are also supported for future use. | ||
| @param [out] arc_map A map from arcs in the generated lattice to the linear | ||
| element offset into logits (computed from the strides of logits), | ||
| or -1 if the arc had no corresponding score in logits, | ||
| e.g. arc pointing to super final state. | ||
| @return Return an FsaVec, which contains the generated lattice. | ||
| */ | ||
| FsaVec PrunedRangesToLattice( | ||
| torch::Tensor ranges, // [B][T][s_range] | ||
| torch::Tensor frames, // [B] | ||
| torch::Tensor symbols, // [B][S] | ||
| torch::Tensor logits, // [B][T][s_range][C] | ||
| Array1<int32_t> *arc_map); | ||
|
Comment on lines
+56
to
+61
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this function used by functions other than the one? If not, please remove it from this header file |
||
|
|
||
| } // namespace k2 | ||
|
|
||
// Registers the Python-visible function `pruned_ranges_to_lattice` on module `m`.
| void PybindPrunedRangesToLattice(py::module &m); | ||
|
|
||
| #endif // K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.