diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu
index f62dfc148..72fe18151 100644
--- a/k2/python/csrc/torch.cu
+++ b/k2/python/csrc/torch.cu
@@ -31,6 +31,7 @@
 #include "k2/python/csrc/torch/index_select.h"
 #include "k2/python/csrc/torch/mutual_information.h"
 #include "k2/python/csrc/torch/nbest.h"
+#include "k2/python/csrc/torch/pruned_ranges_to_lattice.h"
 #include "k2/python/csrc/torch/ragged.h"
 #include "k2/python/csrc/torch/ragged_ops.h"
 #include "k2/python/csrc/torch/rnnt_decode.h"
@@ -47,6 +48,7 @@ void PybindTorch(py::module &m) {
   PybindRagged(m);
   PybindRaggedOps(m);
   PybindRnntDecode(m);
+  PybindPrunedRangesToLattice(m);
 
   k2::PybindV2(m);
 }
diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt
index 9a8cc2229..8c6803f8a 100644
--- a/k2/python/csrc/torch/CMakeLists.txt
+++ b/k2/python/csrc/torch/CMakeLists.txt
@@ -8,6 +8,7 @@ set(torch_srcs
   mutual_information.cu
   mutual_information_cpu.cu
   nbest.cu
+  pruned_ranges_to_lattice.cu
   ragged.cu
   ragged_ops.cu
   rnnt_decode.cu
diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu
new file mode 100644
index 000000000..415be116a
--- /dev/null
+++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu
@@ -0,0 +1,282 @@
+/**
+ * @copyright
+ * Copyright      2022  Xiaomi Corporation (authors: Liyong Guo)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "k2/csrc/device_guard.h"
+#include "k2/csrc/fsa.h"
+#include "k2/csrc/torch_util.h"
+#include "k2/python/csrc/torch/pruned_ranges_to_lattice.h"
+#include "k2/python/csrc/torch/v2/ragged_any.h"
+
+
+namespace k2 {
+
+/*
+  Convert pruned ranges to lattice while also supporting autograd.
+
+  The input pruned ranges is normally generated by `get_rnnt_prune_ranges`.
+  See k2/python/k2/rnnt_loss.py for the process of generating ranges and
+  the information it represents.
+
+  When this is implemented, the lattice is used to generate force-alignment.
+  Perhaps you could find other uses for this function.
+
+    @param ranges  A tensor containing the symbol indexes for each frame.
+        Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges`
+        in k2/python/k2/rnnt_loss.py for more details of this tensor.
+        Its type is int32, consistent with that in rnnt_loss.py.
+    @param frames  The number of frames per sample with shape (B).
+        Its type is int32.
+    @param symbols  The symbol sequence, a LongTensor of shape (B, S),
+        and elements in {0..C-1}.
+        Its type is int64(Long), consistent with that in rnnt_loss.py.
+    @param logits  The pruned joiner network (or am/lm)
+        of shape (B, T, s_range, C).
+        Its type can be float32, float64, float16. Though float32 is mainly
+        used, float64 and float16 are also supported for future use.
+    @param [out] arc_map  A map from arcs in generated lattice to global index
+        of logits, or -1 if the arc had no corresponding score in logits,
+        e.g. arc pointing to super final state.
+    @return Return an FsaVec, which contains the generated lattice.
+*/
+FsaVec PrunedRangesToLattice(
+    torch::Tensor ranges,   // [B][T][s_range]
+    torch::Tensor frames,   // [B]
+    torch::Tensor symbols,  // [B][S]
+    torch::Tensor logits,   // [B][T][s_range][C]
+    Array1<int32_t> *arc_map) {
+
+  TORCH_CHECK(ranges.get_device() == frames.get_device());
+  TORCH_CHECK(ranges.get_device() == symbols.get_device());
+  TORCH_CHECK(ranges.get_device() == logits.get_device());
+
+  TORCH_CHECK(ranges.dim() == 3, "ranges should be 3-dimensional");
+  TORCH_CHECK(frames.dim() == 1, "frames should be 1-dimensional");
+  TORCH_CHECK(symbols.dim() == 2, "symbols should be 2-dimensional");
+  TORCH_CHECK(logits.dim() == 4, "logits should be 4-dimensional");
+
+  TORCH_CHECK(torch::kInt == ranges.scalar_type());
+  TORCH_CHECK(torch::kInt == frames.scalar_type());
+  TORCH_CHECK(torch::kLong == symbols.scalar_type());
+
+  ContextPtr context;
+  if (ranges.device().type() == torch::kCPU) {
+    context = GetCpuContext();
+  } else if (ranges.is_cuda()) {
+    context = GetCudaContext(ranges.device().index());
+  } else {
+    K2_LOG(FATAL) << "Unsupported device: " << ranges.device()
+                  << "\nOnly CPU and CUDA are verified";
+  }
+
+  // "_a" is short for accessor.
+  auto ranges_a = ranges.accessor<int32_t, 3>();
+  auto frames_a = frames.accessor<int32_t, 1>();
+  auto symbols_a = symbols.accessor<int64_t, 2>();
+
+  // Typically, s_range is 5.
+  const int32_t B = ranges.size(0),
+                T = ranges.size(1),
+                s_range = ranges.size(2);
+
+  // Compute f2s_shape: fsa_to_state_shape.
+  Array1<int32_t> f2s_row_splits(context, B + 1);
+  int32_t *f2s_row_splits_data = f2s_row_splits.Data();
+  K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) {
+    int32_t t = frames_a[fsa_idx0];
+    K2_CHECK_LE(t, T);
+    // + 1 in "t * s_range + 1" is for super-final state.
+    f2s_row_splits_data[fsa_idx0] = t * s_range + 1;
+  });
+
+  ExclusiveSum(f2s_row_splits, &f2s_row_splits);
+  RaggedShape f2s_shape = RaggedShape2(&f2s_row_splits, nullptr, -1);
+
+  // Compute s2a_shape: state_to_arc_shape.
+  int32_t num_states = f2s_shape.NumElements();
+  Array1<int32_t> s2c_row_splits(context, num_states + 1);
+  int32_t *s2c_row_splits_data = s2c_row_splits.Data();
+  const int32_t *f2s_row_splits1_data = f2s_shape.RowSplits(1).Data(),
+                *f2s_row_ids1_data = f2s_shape.RowIds(1).Data();
+  // Compute number of arcs for each state.
+  K2_EVAL(
+      context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void {
+        int32_t fsa_idx0 = f2s_row_ids1_data[state_idx01],
+                state_idx0x = f2s_row_splits1_data[fsa_idx0],
+                state_idx1 = state_idx01 - state_idx0x,
+                t = state_idx1 / s_range,
+                token_idx = state_idx1 % s_range;
+
+        K2_CHECK_LE(t, frames_a[fsa_idx0]);
+
+        // The state doesn't have leaving arc: super final_state.
+        if (state_idx1 == frames_a[fsa_idx0] * s_range) {
+          s2c_row_splits_data[state_idx01] = 0;
+          return;
+        }
+
+        // States have a leaving arc if not specially processed.
+        s2c_row_splits_data[state_idx01] = 1;
+
+        // States have two leaving arcs.
+        bool has_horizontal_blank_arc = false;
+        if (t < frames_a[fsa_idx0] - 1) {
+          has_horizontal_blank_arc =
+              ranges_a[fsa_idx0][t][token_idx] >= ranges_a[fsa_idx0][t + 1][0];
+        }
+        // Typically, s_range == 5, i.e. 5 states for each time step.
+        // the 4th state(0-based index) ONLY has a horizontal blank arc.
+        // While state [0, 1, 2, 3] have a vertical arc
+        // and MAYBE a horizontal blank arc.
+        if (token_idx != s_range - 1 && has_horizontal_blank_arc) {
+          s2c_row_splits_data[state_idx01] = 2;
+        }
+      });
+  ExclusiveSum(s2c_row_splits, &s2c_row_splits);
+  RaggedShape s2a_shape = RaggedShape2(&s2c_row_splits, nullptr, -1);
+
+  // ofsa_shape: output_fsa_shape.
+  RaggedShape ofsa_shape = ComposeRaggedShapes(f2s_shape, s2a_shape);
+
+  int32_t num_arcs = ofsa_shape.NumElements();
+  Array1<Arc> arcs(context, num_arcs);
+
+  Arc *arcs_data = arcs.Data();
+  const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(),
+                *row_ids1_data = ofsa_shape.RowIds(1).Data(),
+                *row_splits2_data = ofsa_shape.RowSplits(2).Data(),
+                *row_ids2_data = ofsa_shape.RowIds(2).Data();
+
+  Array1<int32_t> out_map(context, num_arcs);
+  int32_t *out_map_data = out_map.Data();
+  // Used to populate out_map.
+  const int32_t lg_stride_0 = logits.stride(0),
+                lg_stride_1 = logits.stride(1),
+                lg_stride_2 = logits.stride(2),
+                lg_stride_3 = logits.stride(3);
+
+  // Type of logits can be float32, float64, float16. Though float32 is mainly
+  // used, float64 and float16 are also supported for future use.
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      logits.scalar_type(), "pruned_ranges_to_lattice", ([&] {
+        auto logits_a = logits.accessor<scalar_t, 4>();
+
+        K2_EVAL(
+            context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void {
+              const int32_t state_idx01 = row_ids2_data[arc_idx012],
+                            fsa_idx0 = row_ids1_data[state_idx01],
+                            state_idx0x = row_splits1_data[fsa_idx0],
+                            arc_idx01x = row_splits2_data[state_idx01],
+                            state_idx1 = state_idx01 - state_idx0x,
+                            arc_idx2 = arc_idx012 - arc_idx01x,
+                            t = state_idx1 / s_range,
+                            // token_idx lies within interval [0, s_range)
+                            // but does not include s_range.
+                            token_idx = state_idx1 % s_range;
+              Arc arc;
+              arc.src_state = state_idx1;
+              // The penultimate state only has a leaving arc
+              // pointing to super final state.
+              if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) {
+                arc.src_state = state_idx1;
+                arc.dest_state = state_idx1 + 1;
+                arc.label = -1;
+                arc.score = 0.0;
+                arcs_data[arc_idx012] = arc;
+                out_map_data[arc_idx012] = -1;
+                return;
+              }
+
+              if (token_idx < s_range - 1) {
+                // States have a vertical arc with non-blank label and
+                // MAYBE a horizontal arc with blank label.
+                const int32_t symbol_idx = ranges_a[fsa_idx0][t][token_idx],
+                              arc_label = symbols_a[fsa_idx0][symbol_idx];
+                K2_CHECK_LE(arc_idx2, 2);
+                switch (arc_idx2) {
+                  // For vertical arc with non-blank label.
+                  case 0:
+                    arc.dest_state = state_idx1 + 1;
+                    arc.label = arc_label;
+                    arc.score = logits_a[fsa_idx0][t][token_idx][arc_label];
+
+                    out_map_data[arc_idx012] =
+                        fsa_idx0 * lg_stride_0 + t * lg_stride_1 +
+                        token_idx * lg_stride_2 + arc_label * lg_stride_3;
+                    break;
+                  // For horizontal arc with blank label.
+                  case 1:
+                    const int32_t dest_state_token_idx =
+                        ranges_a[fsa_idx0][t][token_idx] -
+                        ranges_a[fsa_idx0][t + 1][0];
+                    K2_CHECK_GE(dest_state_token_idx, 0);
+                    arc.dest_state = dest_state_token_idx + (t + 1) * s_range;
+                    arc.label = 0;
+                    arc.score = logits_a[fsa_idx0][t][token_idx][0];
+
+                    out_map_data[arc_idx012] =
+                        fsa_idx0 * lg_stride_0 + t * lg_stride_1 +
+                        token_idx * lg_stride_2;
+                    break;
+                }
+              } else {
+                // States only have a horizontal arc with blank label.
+                K2_CHECK_EQ(arc_idx2, 0);
+                const int32_t dest_state_token_idx =
+                    ranges_a[fsa_idx0][t][token_idx] -
+                    ranges_a[fsa_idx0][t + 1][0];
+                arc.dest_state =
+                    dest_state_token_idx + (t + 1) * s_range;
+                arc.label = 0;
+                arc.score = logits_a[fsa_idx0][t][token_idx][0];
+
+                out_map_data[arc_idx012] =
+                    fsa_idx0 * lg_stride_0 + t * lg_stride_1 +
+                    token_idx * lg_stride_2;
+              }
+              arcs_data[arc_idx012] = arc;
+            });
+        *arc_map = std::move(out_map);
+      }));
+
+  return Ragged<Arc>(ofsa_shape, arcs);
+}
+
+}  // namespace k2
+
+void PybindPrunedRangesToLattice(py::module &m) {
+  m.def(
+      "pruned_ranges_to_lattice",
+      [](torch::Tensor ranges,
+         torch::Tensor frames,
+         torch::Tensor symbols,
+         torch::Tensor logits) -> std::pair<k2::FsaVec, torch::Tensor> {
+        k2::DeviceGuard guard(k2::GetContext(ranges));
+        k2::Array1<int32_t> arc_to_logit_map;
+        k2::FsaVec ofsa = k2::PrunedRangesToLattice(
+            ranges, frames, symbols, logits, &arc_to_logit_map);
+        return std::make_pair(ofsa, ToTorch(arc_to_logit_map));
+      },
+      py::arg("ranges"), py::arg("frames"),
+      py::arg("symbols"), py::arg("logits"));
+}
diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.h b/k2/python/csrc/torch/pruned_ranges_to_lattice.h
new file mode 100644
index 000000000..0e02c3529
--- /dev/null
+++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.h
@@ -0,0 +1,67 @@
+/**
+ * @copyright
+ * Copyright      2022  Xiaomi Corporation (authors: Liyong Guo)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_
+#define K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_
+
+#include "k2/python/csrc/torch.h"
+
+namespace k2 {
+
+/*
+  Convert pruned ranges to lattice while also supporting autograd.
+
+  The input pruned ranges is normally generated by `get_rnnt_prune_ranges`.
+  See k2/python/k2/rnnt_loss.py for the process of generating ranges and
+  the information it represents.
+
+  When this is implemented, the lattice is used to generate force-alignment.
+  Perhaps you could find other uses for this function.
+
+    @param ranges  A tensor containing the symbol indexes for each frame.
+        Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges`
+        in k2/python/k2/rnnt_loss.py for more details of this tensor.
+        Its type is int32, consistent with that in rnnt_loss.py.
+    @param frames  The number of frames per sample with shape (B).
+        Its type is int32.
+    @param symbols  The symbol sequence, a LongTensor of shape (B, S),
+        and elements in {0..C-1}.
+        Its type is int64(Long), consistent with that in rnnt_loss.py.
+    @param logits  The pruned joiner network (or am/lm)
+        of shape (B, T, s_range, C).
+        Its type can be float32, float64, float16. Though float32 is mainly
+        used, float64 and float16 are also supported for future use.
+    @param [out] arc_map  A map from arcs in generated lattice to global index
+        of logits, or -1 if the arc had no corresponding score in logits,
+        e.g. arc pointing to super final state.
+    @return Return an FsaVec, which contains the generated lattice.
+*/
+FsaVec PrunedRangesToLattice(
+    torch::Tensor ranges,   // [B][T][s_range]
+    torch::Tensor frames,   // [B]
+    torch::Tensor symbols,  // [B][S]
+    torch::Tensor logits,   // [B][T][s_range][C]
+    Array1<int32_t> *arc_map);
+
+}  // namespace k2
+
+void PybindPrunedRangesToLattice(py::module &m);
+
+#endif  // K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_
diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py
index fb29fb29a..5d0691c1e 100644
--- a/k2/python/k2/__init__.py
+++ b/k2/python/k2/__init__.py
@@ -119,6 +119,7 @@
 from .utils import random_fsa
 from .utils import random_fsa_vec
 from _k2.version import with_cuda
+from _k2 import pruned_ranges_to_lattice
 
 from .decode import get_aux_labels
 from .decode import get_lattice
diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py
index 8b0067998..965ed5ce2 100644
--- a/k2/python/k2/fsa.py
+++ b/k2/python/k2/fsa.py
@@ -453,7 +453,7 @@ def properties(self) -> int:
                     " fsa.labels = labels"
                 )
                 return properties  # Return cached properties.
-
+        self.labels_version = self.labels._version
         if self.arcs.num_axes() == 2:
             properties = _k2.get_fsa_basic_properties(self.arcs)
diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt
index 429af80b5..400ab930f 100644
--- a/k2/python/tests/CMakeLists.txt
+++ b/k2/python/tests/CMakeLists.txt
@@ -60,6 +60,7 @@ set(py_test_files
   nbest_test.py
   numerical_gradient_check_test.py
   online_dense_intersecter_test.py
+  pruned_ranges_to_lattice_test.py
   ragged_ops_test.py
   ragged_shape_test.py
   ragged_tensor_test.py
diff --git a/k2/python/tests/pruned_ranges_to_lattice_test.py b/k2/python/tests/pruned_ranges_to_lattice_test.py
new file mode 100644
index 000000000..ee95f7e00
--- /dev/null
+++ b/k2/python/tests/pruned_ranges_to_lattice_test.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (authors: Liyong Guo)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# To run this single test, use
+#
+#  ctest --verbose -R pruned_ranges_to_lattice_test_py
+
+import unittest
+
+import k2
+import torch
+
+
+class TestPrunedRangesToLattice(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.devices = [torch.device("cpu")]
+        if torch.cuda.is_available() and k2.with_cuda:
+            cls.devices.append(torch.device("cuda", 0))
+            if torch.cuda.device_count() > 1:
+                torch.cuda.set_device(1)
+                cls.devices.append(torch.device("cuda", 1))
+        cls.float_dtypes = [torch.float32, torch.float64, torch.float16]
+
+    def _common_test_part(self, ranges, frames, symbols, logits):
+        ofsa, arc_map = k2.pruned_ranges_to_lattice(
+            ranges, frames, symbols, logits
+        )
+
+        assert torch.equal(
+            arc_map,
+            torch.tensor(
+                [
+                    8,
+                    16,
+                    9,
+                    24,
+                    18,
+                    32,
+                    27,
+                    36,
+                    52,
+                    60,
+                    54,
+                    68,
+                    63,
+                    76,
+                    72,
+                    81,
+                    96,
+                    104,
+                    99,
+                    112,
+                    108,
+                    120,
+                    117,
+                    126,
+                    140,
+                    148,
+                    156,
+                    164,
+                    -1,
+                    182,
+                    180,
+                    192,
+                    189,
+                    202,
+                    198,
+                    212,
+                    207,
+                    216,
+                    227,
+                    237,
+                    247,
+                    257,
+                    252,
+                    261,
+                    275,
+                    285,
+                    295,
+                    305,
+                    -1,
+                ],
+                dtype=torch.int32,
+            ),
+        )
+        lattice = k2.Fsa(ofsa)
+
+        scores_tracked_by_autograd = k2.index_select(
+            logits.reshape(-1).to(torch.float32), arc_map
+        )
+
+        assert torch.allclose(
+            lattice.scores.to(torch.float32), scores_tracked_by_autograd
+        )
+
+        assert torch.equal(
+            lattice.arcs.values()[:, :3],
+            torch.tensor(
+                [
+                    [0, 1, 8],
+                    [1, 2, 7],
+                    [1, 5, 0],
+                    [2, 3, 6],
+                    [2, 6, 0],
+                    [3, 4, 5],
+                    [3, 7, 0],
+                    [4, 8, 0],
+                    [5, 6, 7],
+                    [6, 7, 6],
+                    [6, 10, 0],
+                    [7, 8, 5],
+                    [7, 11, 0],
+                    [8, 9, 4],
+                    [8, 12, 0],
+                    [9, 13, 0],
+                    [10, 11, 6],
+                    [11, 12, 5],
+                    [11, 15, 0],
+                    [12, 13, 4],
+                    [12, 16, 0],
+                    [13, 14, 3],
+                    [13, 17, 0],
+                    [14, 18, 0],
+                    [15, 16, 5],
+                    [16, 17, 4],
+                    [17, 18, 3],
+                    [18, 19, 2],
+                    [19, 20, -1],
+                    [0, 1, 2],
+                    [0, 5, 0],
+                    [1, 2, 3],
+                    [1, 6, 0],
+                    [2, 3, 4],
+                    [2, 7, 0],
+                    [3, 4, 5],
+                    [3, 8, 0],
+                    [4, 9, 0],
+                    [5, 6, 2],
+                    [6, 7, 3],
+                    [7, 8, 4],
+                    [8, 9, 5],
+                    [8, 10, 0],
+                    [9, 11, 0],
+                    [10, 11, 5],
+                    [11, 12, 6],
+                    [12, 13, 7],
+                    [13, 14, 8],
+                    [14, 15, -1],
+                ],
+                dtype=torch.int32,
+            ),
+        )
+
+    def test(self):
+        ranges = torch.tensor(
+            [
+                [
+                    [0, 1, 2, 3, 4],
+                    [1, 2, 3, 4, 5],
+                    [2, 3, 4, 5, 6],
+                    [3, 4, 5, 6, 7],
+                ],
+                [
+                    [0, 1, 2, 3, 4],
+                    [0, 1, 2, 3, 4],
+                    [3, 4, 5, 6, 7],
+                    [3, 4, 5, 6, 7],
+                ],
+            ],
+            dtype=torch.int32,
+        )
+        B, T, s_range = ranges.size()
+        C = 9
+
+        frames = torch.tensor([4, 3], dtype=torch.int32)
+        symbols = torch.tensor(
+            [[8, 7, 6, 5, 4, 3, 2], [2, 3, 4, 5, 6, 7, 8]], dtype=torch.long
+        )
+        logits = torch.tensor(
+            [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], dtype=torch.float32
+        ).expand(B, T, s_range, C)
+        logits = logits + torch.tensor(
+            [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]
+        ).reshape(B, T, 1, 1)
+        logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(
+            1, 1, s_range, 1
+        )
+        for dtype in self.float_dtypes:
+            tmp_logits = logits.to(dtype)
+            self._common_test_part(ranges, frames, symbols, tmp_logits)
+
+
+if __name__ == "__main__":
+    unittest.main()