From 32fbbba2bf7d3e7e2e73f9ae32d73cdb42eb1120 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 17 Dec 2022 00:53:39 +0800 Subject: [PATCH 01/18] draft version --- k2/csrc/CMakeLists.txt | 1 + k2/csrc/self_alignment.cu | 220 +++++++++++++++++++++++++ k2/csrc/self_alignment.h | 45 +++++ k2/python/csrc/torch.cu | 2 + k2/python/csrc/torch/CMakeLists.txt | 1 + k2/python/csrc/torch/self_alignment.cu | 47 ++++++ k2/python/csrc/torch/self_alignment.h | 28 ++++ k2/python/k2/__init__.py | 1 + 8 files changed, 345 insertions(+) create mode 100644 k2/csrc/self_alignment.cu create mode 100644 k2/csrc/self_alignment.h create mode 100644 k2/python/csrc/torch/self_alignment.cu create mode 100644 k2/python/csrc/torch/self_alignment.h diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 736668e9b..b61ee8eb8 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -74,6 +74,7 @@ set(context_srcs reverse.cu rm_epsilon.cu rnnt_decode.cu + self_alignment.cu tensor.cu tensor_ops.cu thread_pool.cu diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/self_alignment.cu new file mode 100644 index 000000000..1f6d4e3a5 --- /dev/null +++ b/k2/csrc/self_alignment.cu @@ -0,0 +1,220 @@ +/** + * @copyright + * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_CSRC_SELF_ALIGNMENT_H_ +#define K2_CSRC_SELF_ALIGNMENT_H_ + +#include + +#include + +#include "k2/python/csrc/torch.h" + +namespace k2 { + +FsaVec SelfAlignment( + torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. + torch::Tensor x_lens, // [B][S+1][T] + torch::Tensor blank_connections, + torch::Tensor y, + // const Ragged &y, + torch::optional boundary, // [B][4], int64_t. + torch::Tensor p, + Array1 *arc_map) { + ContextPtr context; + if (ranges.device().type() == torch::kCPU) { + context = GetCpuContext(); + } else if (ranges.is_cuda()) { + context = GetCudaContext(ranges.device().index()); + } else { + K2_LOG(FATAL) << "Unsupported device: " << ranges.device() + << "\nOnly CPU and CUDA are supported"; + } + + TORCH_CHECK(ranges.dim() == 3, "ranges must be 3-dimensional"); + // U is always 5 + const int32_t B = ranges.size(0), T = ranges.size(1), U = ranges.size(2); + + + Array1 out_map(context, 1); + *arc_map = std::move(out_map); + K2_LOG(INFO) << "hello" << B << T << U; + // K2_LOG(INFO) << ranges[0][5]; + // K2_LOG(INFO) << ranges[0]; + Dtype t = ScalarTypeToDtype(ranges.scalar_type()); + K2_CHECK_EQ(kInt64Dtype, t); + K2_CHECK_EQ(torch::kLong, ranges.scalar_type()); + K2_CHECK_EQ(torch::kInt, x_lens.scalar_type()); // int32_t + // Dtype t = ScalarTypeToDtype(ranges.GetDtype()); + // std::is_same::value; + // static_assert(std::is_same::value, "ranges is not kInt64Dtype") + // const int32_t *row_ids_data = row_ids.data_ptr(); + // FOR_REAL_TYPES(t, t, { + const int64_t *ranges_data = ranges.data_ptr(); + const int32_t *x_lens_data = x_lens.data_ptr(); + const int32_t *blank_connections_data = blank_connections.data_ptr(); + const int32_t *y_data = y.data_ptr(); + int32_t stride_0 = ranges.stride(0), + stride_1 = ranges.stride(1), + stride_2 = ranges.stride(2); + int32_t blk_stride_0 = blank_connections.stride(0), + blk_stride_1 = blank_connections.stride(1), + blk_stride_2 = blank_connections.stride(2); + K2_LOG(INFO) << "stride_0: " << stride_0; + K2_LOG(INFO) << "stride_1: " << stride_1; + K2_LOG(INFO) << "stride_2: " << stride_2; + int64_t numel = ranges.numel(); + Array1 re_ranges(context, numel); + // f2s: fsa to states + K2_CHECK_EQ(x_lens.numel(), B); + Array1 f2s_row_splits(context, B + 1); + int32_t * f2s_row_splits_data = f2s_row_splits.Data(); + K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { + // f2s_row_splits_data[0] = 0; + int32_t t = x_lens_data[fsa_idx0]; + // + 1 in "t * U + 1" is for super-final state. + f2s_row_splits_data[fsa_idx0] = t * U + 1; + }); + ExclusiveSum(f2s_row_splits, &f2s_row_splits); + + RaggedShape fsa_to_states = + RaggedShape2(&f2s_row_splits, nullptr, -1); + int32_t num_states = fsa_to_states.NumElements(); + Array1 s2c_row_splits(context, num_states + 1); + int32_t *s2c_row_splits_data = s2c_row_splits.Data(); + const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(), + *fts_row_ids1_data = fsa_to_states.RowIds(1).Data(); + + // set the arcs number for each state + K2_EVAL( + context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { + int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], + state_idx0x = fts_row_splits1_data[fsa_idx0], + state_idx0x_next = fts_row_splits1_data[fsa_idx0 + 1], + state_idx1 = state_idx01 - state_idx0x, + t = state_idx1 / U, + token_index = state_idx1 % U; + if (state_idx1 == x_lens_data[fsa_idx0] * U) { + // final arc to super final state. + s2c_row_splits_data[state_idx01] = 1; + return; + } + int32_t range_offset = fsa_idx0 * stride_0 + t * stride_1 + token_index * stride_2; + int32_t blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; + int32_t next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + if (token_index < U - 1) { + s2c_row_splits_data[state_idx01] = 1; + if (next_state_idx1 >= 0) { + s2c_row_splits_data[state_idx01] = 2; + } + } else { + s2c_row_splits_data[state_idx01] = 0; + if (next_state_idx1 >= 0) { + s2c_row_splits_data[state_idx01] = 1; + } + } + }); + + ExclusiveSum(s2c_row_splits, &s2c_row_splits); + RaggedShape states_to_arcs = + RaggedShape2(&s2c_row_splits, nullptr, -1); + + RaggedShape ofsa_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); + int32_t num_arcs = ofsa_shape.NumElements(); + Array1 arcs(context, num_arcs); + Arc *arcs_data = arcs.Data(); + const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(), + *row_ids1_data = ofsa_shape.RowIds(1).Data(), + *row_splits2_data = ofsa_shape.RowSplits(2).Data(), + *row_ids2_data = ofsa_shape.RowIds(2).Data(); + + // auto y_shape = y.shape; + // const int32_t * y_data = y.values.Data(); + int32_t y_stride_0 = y.stride(0), + y_stride_1 = y.stride(1); + K2_EVAL( + context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { + int32_t state_idx01 = row_ids2_data[arc_idx012], + fsa_idx0 = row_ids1_data[state_idx01], + state_idx0x = row_splits1_data[fsa_idx0], + state_idx0x_next = row_splits1_data[fsa_idx0 + 1], + arc_idx01x = row_splits2_data[state_idx01], + state_idx1 = state_idx01 - state_idx0x, + arc_idx2 = arc_idx012 - arc_idx01x, + t = state_idx1 / U, + token_index = state_idx1 % U; // token_index is belong to [0, U) + Arc arc; + if (state_idx1 == x_lens_data[fsa_idx0] * U) { + arc.src_state = state_idx1; + arc.dest_state = state_idx1 + 1; + arc.label = -1; + arc.score = 0.0; + arcs_data[arc_idx012] = arc; + return; + } + int32_t rangged_offset = fsa_idx0 * stride_0 + t * stride_1 + token_index * stride_2; + int32_t actual_u = ranges_data[rangged_offset]; + int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u; + int32_t arc_label = y_data[y_offset]; + int32_t blank_connections_data_offset, next_state_idx1; + arc.src_state = state_idx1; + // arc. + if (token_index < U - 1) { + switch (arc_idx2) { + case 0: + arc.dest_state = state_idx1 + 1; + arc.label = arc_label; + break; + case 1: + blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; + next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + K2_CHECK_GE(next_state_idx1, 0); + arc.dest_state = next_state_idx1 + (t + 1) * U; + arc.label = 0; + break; + default: + K2_LOG(FATAL) << "Arc index must be less than 3"; + } + } else { + K2_CHECK_EQ(arc_idx2, 0); + blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; + next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + K2_CHECK_GE(next_state_idx1, 0); + arc.dest_state = next_state_idx1 + (t + 1) * U; + arc.label = 0; + } + arcs_data[arc_idx012] = arc; + }); + return Ragged(ofsa_shape, arcs); + + + + int64_t * re_reanges_data = re_ranges.Data(); + K2_LOG(INFO) << "numel: " << numel; + K2_EVAL(context, numel, lambda_set_range, (int32_t i) { + re_reanges_data[i] = ranges_data[i]; + }); + auto cpu_range = re_ranges.To(GetCpuContext()); + +} + +} // namespace k2 + +#endif // K2_CSRC_SELF_ALIGNMENT_H_ diff --git a/k2/csrc/self_alignment.h b/k2/csrc/self_alignment.h new file mode 100644 index 000000000..888fc435a --- /dev/null +++ b/k2/csrc/self_alignment.h @@ -0,0 +1,45 @@ +/** + * @copyright + * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_CSRC_SELF_ALIGNMENT_H_ +#define K2_CSRC_SELF_ALIGNMENT_H_ + +#include + +#include + +#include "k2/python/csrc/torch.h" + +namespace k2 { + +FsaVec SelfAlignment( + torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. + torch::Tensor x_lens, // [B][S+1][T] + torch::Tensor blank_connections, + torch::Tensor y, + // const Ragged &y, + torch::optional boundary, // [B][4], int64_t. + torch::Tensor p, + // FsaVec * ofsa, + Array1 *arc_map); + +} // namespace k2 + +#endif // K2_CSRC_SELF_ALIGNMENT_H_ diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu index f62dfc148..0b35d04b3 100644 --- a/k2/python/csrc/torch.cu +++ b/k2/python/csrc/torch.cu @@ -34,6 +34,7 @@ #include "k2/python/csrc/torch/ragged.h" #include "k2/python/csrc/torch/ragged_ops.h" #include "k2/python/csrc/torch/rnnt_decode.h" +#include "k2/python/csrc/torch/self_alignment.h" #include "k2/python/csrc/torch/v2/k2.h" void PybindTorch(py::module &m) { @@ -47,6 +48,7 @@ void PybindTorch(py::module &m) { PybindRagged(m); PybindRaggedOps(m); PybindRnntDecode(m); + PybindSelfAlignment(m); k2::PybindV2(m); } diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 100ad8ce0..8be261ca5 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -1,5 +1,6 @@ # please keep the list sorted set(torch_srcs + self_alignment.cu arc.cu fsa.cu fsa_algo.cu diff --git a/k2/python/csrc/torch/self_alignment.cu b/k2/python/csrc/torch/self_alignment.cu new file mode 100644 index 000000000..f3de61cce --- /dev/null +++ b/k2/python/csrc/torch/self_alignment.cu @@ -0,0 +1,47 @@ +/** + * @copyright + * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/device_guard.h" +#include "k2/csrc/fsa.h" +#include "k2/csrc/torch_util.h" +#include "k2/csrc/self_alignment.h" +#include "k2/python/csrc/torch/self_alignment.h" +#include "k2/python/csrc/torch/v2/ragged_any.h" + + +void PybindSelfAlignment(py::module &m) { + m.def( + "self_alignment", + [](torch::Tensor ranges, torch::Tensor x_lens, + torch::Tensor blank_connections, + torch::Tensor y, + torch::optional boundary, + torch::Tensor p) -> std::pair { + k2::DeviceGuard guard(k2::GetContext(ranges)); + k2::Array1 label_map; + k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, blank_connections, y, boundary, p, &label_map); + // k2::SelfAlignment(ranges, x_lens, blank_connections, y.any.Specialize(), boundary, p, &ofsa, &label_map); + // k2::SelfAlignment(ranges, x_lens, blank_connections, boundary, p, &ofsa, &label_map); + torch::Tensor tensor = ToTorch(label_map); + return std::make_pair(ofsa, tensor); + }, + py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("y"), py::arg("boundary"), py::arg("p")); + // py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("boundary"), py::arg("p")); +} diff --git a/k2/python/csrc/torch/self_alignment.h b/k2/python/csrc/torch/self_alignment.h new file mode 100644 index 000000000..30033979a --- /dev/null +++ b/k2/python/csrc/torch/self_alignment.h @@ -0,0 +1,28 @@ +/** + * @copyright + * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_PYTHON_CSRC_TORCH_SELF_ALIGNMENT_H_ +#define K2_PYTHON_CSRC_TORCH_SELF_ALIGNMENT_H_ + +#include "k2/python/csrc/torch.h" + +void PybindSelfAlignment(py::module &m); + +#endif // K2_PYTHON_CSRC_TORCH_SELF_ALIGNMENT_H_ diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index f4e04be10..cc9e2d612 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -122,6 +122,7 @@ from .utils import random_fsa from .utils import random_fsa_vec from _k2.version import with_cuda +from _k2 import self_alignment from .decode import get_aux_labels from .decode import get_lattice From 5ef209203e8f6d9451a9fde493979bdccbcddd23 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 19 Dec 2022 19:31:05 +0800 Subject: [PATCH 02/18] version 1 --- k2/csrc/self_alignment.cu | 119 +++++++++++++++++-------- k2/csrc/self_alignment.h | 3 +- k2/python/csrc/torch/self_alignment.cu | 7 +- 3 files changed, 86 insertions(+), 43 deletions(-) diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/self_alignment.cu index 1f6d4e3a5..c9a17023a 100644 --- a/k2/csrc/self_alignment.cu +++ b/k2/csrc/self_alignment.cu @@ -30,13 +30,14 @@ namespace k2 { FsaVec SelfAlignment( - torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. - torch::Tensor x_lens, // [B][S+1][T] + // Normally, ranges is with shape [B][S][T+1] if !modified, [B][S][T] if modified. + // Currently, only [B][S][T] is supported. + torch::Tensor ranges, + torch::Tensor x_lens, // [B][T] torch::Tensor blank_connections, torch::Tensor y, // const Ragged &y, - torch::optional boundary, // [B][4], int64_t. - torch::Tensor p, + torch::Tensor logits, Array1 *arc_map) { ContextPtr context; if (ranges.device().type() == torch::kCPU) { @@ -53,42 +54,45 @@ FsaVec SelfAlignment( const int32_t B = ranges.size(0), T = ranges.size(1), U = ranges.size(2); - Array1 out_map(context, 1); - *arc_map = std::move(out_map); - K2_LOG(INFO) << "hello" << B << T << U; // K2_LOG(INFO) << ranges[0][5]; // K2_LOG(INFO) << ranges[0]; Dtype t = ScalarTypeToDtype(ranges.scalar_type()); - K2_CHECK_EQ(kInt64Dtype, t); - K2_CHECK_EQ(torch::kLong, ranges.scalar_type()); + // K2_CHECK_EQ(kInt64Dtype, t); + // K2_CHECK_EQ(torch::kLong, ranges.scalar_type()); + K2_CHECK_EQ(torch::kInt, ranges.scalar_type()); K2_CHECK_EQ(torch::kInt, x_lens.scalar_type()); // int32_t // Dtype t = ScalarTypeToDtype(ranges.GetDtype()); // std::is_same::value; // static_assert(std::is_same::value, "ranges is not kInt64Dtype") // const int32_t *row_ids_data = row_ids.data_ptr(); // FOR_REAL_TYPES(t, t, { - const int64_t *ranges_data = ranges.data_ptr(); + const float *logits_data = logits.data_ptr(); + const int32_t *ranges_data = ranges.data_ptr(); const int32_t *x_lens_data = x_lens.data_ptr(); const int32_t *blank_connections_data = blank_connections.data_ptr(); const int32_t *y_data = y.data_ptr(); - int32_t stride_0 = ranges.stride(0), - stride_1 = ranges.stride(1), - stride_2 = ranges.stride(2); + int32_t rng_stride_0 = ranges.stride(0), + rng_stride_1 = ranges.stride(1), + rng_stride_2 = ranges.stride(2); int32_t blk_stride_0 = blank_connections.stride(0), blk_stride_1 = blank_connections.stride(1), blk_stride_2 = blank_connections.stride(2); - K2_LOG(INFO) << "stride_0: " << stride_0; - K2_LOG(INFO) << "stride_1: " << stride_1; - K2_LOG(INFO) << "stride_2: " << stride_2; - int64_t numel = ranges.numel(); - Array1 re_ranges(context, numel); + int32_t lg_stride_0 = logits.stride(0), + lg_stride_1 = logits.stride(1), + lg_stride_2 = logits.stride(2), + lg_stride_3 = logits.stride(3); + // K2_LOG(INFO) << "stride_0: " << stride_0; + // K2_LOG(INFO) << "stride_1: " << stride_1; + // K2_LOG(INFO) << "stride_2: " << stride_2; + // int32_t numel = ranges.numel(); + // Array1 re_ranges(context, numel); // f2s: fsa to states K2_CHECK_EQ(x_lens.numel(), B); Array1 f2s_row_splits(context, B + 1); int32_t * f2s_row_splits_data = f2s_row_splits.Data(); K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { - // f2s_row_splits_data[0] = 0; int32_t t = x_lens_data[fsa_idx0]; + K2_CHECK_LE(t, T); // + 1 in "t * U + 1" is for super-final state. f2s_row_splits_data[fsa_idx0] = t * U + 1; }); @@ -111,24 +115,45 @@ FsaVec SelfAlignment( state_idx1 = state_idx01 - state_idx0x, t = state_idx1 / U, token_index = state_idx1 % U; - if (state_idx1 == x_lens_data[fsa_idx0] * U) { + + K2_CHECK_LE(t, x_lens_data[fsa_idx0]); + if (state_idx1 == x_lens_data[fsa_idx0] * U - 1) { + // x_lens[fsa_idx0] * U is the state_idx1 of super final-state. + // x_lens[fsa_idx0] * U - 1 is the state pointing to super final-state. // final arc to super final state. s2c_row_splits_data[state_idx01] = 1; return; } - int32_t range_offset = fsa_idx0 * stride_0 + t * stride_1 + token_index * stride_2; + if (state_idx1 == x_lens_data[fsa_idx0] * U) { + // x_lens[fsa_idx0] * U is the state_idx1 of super final-state. + // final state has no leaving arcs. + s2c_row_splits_data[state_idx01] = 0; + return; + } + int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; int32_t blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; int32_t next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + + int32_t next_state_idx1_tmp = -1; + // blank connections of last frame is -1 + // So we need process t == x_lens_data[fsa_idx0] + if (t < x_lens_data[fsa_idx0] - 1) { + int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; + next_state_idx1_tmp = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; + } + // K2_CHECK_EQ(next_state_idx1_tmp, next_state_idx1); if (token_index < U - 1) { + // Typically, U == 5, + // the [0, 1, 2, 3] states for each time step may have a vertial arc plus an optional horizontal blank arc. s2c_row_splits_data[state_idx01] = 1; if (next_state_idx1 >= 0) { s2c_row_splits_data[state_idx01] = 2; } } else { - s2c_row_splits_data[state_idx01] = 0; - if (next_state_idx1 >= 0) { - s2c_row_splits_data[state_idx01] = 1; - } + // Typically, U == 5, + // the [4] state for each time step have and only have an horizontal blank arc. + // K2_CHECK_GE(next_state_idx1, 0); + s2c_row_splits_data[state_idx01] = 1; } }); @@ -139,6 +164,8 @@ FsaVec SelfAlignment( RaggedShape ofsa_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); int32_t num_arcs = ofsa_shape.NumElements(); Array1 arcs(context, num_arcs); + Array1 out_map(context, num_arcs); + int32_t* out_map_data = out_map.Data(); Arc *arcs_data = arcs.Data(); const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(), *row_ids1_data = ofsa_shape.RowIds(1).Data(), @@ -161,19 +188,20 @@ FsaVec SelfAlignment( t = state_idx1 / U, token_index = state_idx1 % U; // token_index is belong to [0, U) Arc arc; - if (state_idx1 == x_lens_data[fsa_idx0] * U) { + if (state_idx1 == x_lens_data[fsa_idx0] * U - 1) { arc.src_state = state_idx1; arc.dest_state = state_idx1 + 1; arc.label = -1; arc.score = 0.0; arcs_data[arc_idx012] = arc; + out_map_data[arc_idx012] = -1; return; } - int32_t rangged_offset = fsa_idx0 * stride_0 + t * stride_1 + token_index * stride_2; - int32_t actual_u = ranges_data[rangged_offset]; + int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; + int32_t actual_u = ranges_data[range_offset]; int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u; int32_t arc_label = y_data[y_offset]; - int32_t blank_connections_data_offset, next_state_idx1; + int32_t blank_connections_data_offset, next_state_idx1, logits_offset; arc.src_state = state_idx1; // arc. if (token_index < U - 1) { @@ -181,13 +209,25 @@ FsaVec SelfAlignment( case 0: arc.dest_state = state_idx1 + 1; arc.label = arc_label; + logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2 + arc_label * lg_stride_3; + // K2_CHECK_LE(logits_offset, 135); + arc.score = logits_data[logits_offset]; + out_map_data[arc_idx012] = logits_offset; + // arc.score = 0; break; case 1: blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; next_state_idx1 = blank_connections_data[blank_connections_data_offset]; - K2_CHECK_GE(next_state_idx1, 0); + // blank connections of last frame is always -1, + // So states with num_arcs > 2 (i.e. with arc_idx2==1) could not belong to last frame, i.e. x_lens_data[fsa_idx0] - 1. + // K2_CHECK_LE(t, x_lens_data[fsa_idx0] - 1); + // K2_CHECK_GE(next_state_idx1, 0); arc.dest_state = next_state_idx1 + (t + 1) * U; arc.label = 0; + logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; + arc.score = logits_data[logits_offset]; + out_map_data[arc_idx012] = logits_offset; + // arc.score = 0.0; break; default: K2_LOG(FATAL) << "Arc index must be less than 3"; @@ -196,22 +236,27 @@ FsaVec SelfAlignment( K2_CHECK_EQ(arc_idx2, 0); blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; next_state_idx1 = blank_connections_data[blank_connections_data_offset]; - K2_CHECK_GE(next_state_idx1, 0); + // K2_CHECK_GE(next_state_idx1, 0); arc.dest_state = next_state_idx1 + (t + 1) * U; arc.label = 0; + logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; + arc.score = logits_data[logits_offset]; + out_map_data[arc_idx012] = logits_offset; + // arc.score = 0.0; } arcs_data[arc_idx012] = arc; }); + *arc_map = std::move(out_map); return Ragged(ofsa_shape, arcs); - int64_t * re_reanges_data = re_ranges.Data(); - K2_LOG(INFO) << "numel: " << numel; - K2_EVAL(context, numel, lambda_set_range, (int32_t i) { - re_reanges_data[i] = ranges_data[i]; - }); - auto cpu_range = re_ranges.To(GetCpuContext()); + // int32_t * re_reanges_data = re_ranges.Data(); + // K2_LOG(INFO) << "numel: " << numel; + // K2_EVAL(context, numel, lambda_set_range, (int32_t i) { + // re_reanges_data[i] = ranges_data[i]; + // }); + // auto cpu_range = re_ranges.To(GetCpuContext()); } diff --git a/k2/csrc/self_alignment.h b/k2/csrc/self_alignment.h index 888fc435a..652d11d41 100644 --- a/k2/csrc/self_alignment.h +++ b/k2/csrc/self_alignment.h @@ -35,8 +35,7 @@ FsaVec SelfAlignment( torch::Tensor blank_connections, torch::Tensor y, // const Ragged &y, - torch::optional boundary, // [B][4], int64_t. - torch::Tensor p, + torch::Tensor logits, // FsaVec * ofsa, Array1 *arc_map); diff --git a/k2/python/csrc/torch/self_alignment.cu b/k2/python/csrc/torch/self_alignment.cu index f3de61cce..00fc11420 100644 --- a/k2/python/csrc/torch/self_alignment.cu +++ b/k2/python/csrc/torch/self_alignment.cu @@ -32,16 +32,15 @@ void PybindSelfAlignment(py::module &m) { [](torch::Tensor ranges, torch::Tensor x_lens, torch::Tensor blank_connections, torch::Tensor y, - torch::optional boundary, - torch::Tensor p) -> std::pair { + torch::Tensor logits) -> std::pair { k2::DeviceGuard guard(k2::GetContext(ranges)); k2::Array1 label_map; - k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, blank_connections, y, boundary, p, &label_map); + k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, blank_connections, y, logits, &label_map); // k2::SelfAlignment(ranges, x_lens, blank_connections, y.any.Specialize(), boundary, p, &ofsa, &label_map); // k2::SelfAlignment(ranges, x_lens, blank_connections, boundary, p, &ofsa, &label_map); torch::Tensor tensor = ToTorch(label_map); return std::make_pair(ofsa, tensor); }, - py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("y"), py::arg("boundary"), py::arg("p")); + py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("y"), py::arg("logits")); // py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("boundary"), py::arg("p")); } From dea8116118cd08323582ae5358c5eec36ba98f00 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 19 Dec 2022 20:30:19 +0800 Subject: [PATCH 03/18] integrate blank_connection computation --- k2/csrc/self_alignment.cu | 31 ++++++++++++++------------ k2/csrc/self_alignment.h | 2 +- k2/python/csrc/torch/self_alignment.cu | 6 ++--- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/self_alignment.cu index c9a17023a..683432c5e 100644 --- a/k2/csrc/self_alignment.cu +++ b/k2/csrc/self_alignment.cu @@ -34,7 +34,7 @@ FsaVec SelfAlignment( // Currently, only [B][S][T] is supported. torch::Tensor ranges, torch::Tensor x_lens, // [B][T] - torch::Tensor blank_connections, + // torch::Tensor blank_connections, torch::Tensor y, // const Ragged &y, torch::Tensor logits, @@ -69,14 +69,14 @@ FsaVec SelfAlignment( const float *logits_data = logits.data_ptr(); const int32_t *ranges_data = ranges.data_ptr(); const int32_t *x_lens_data = x_lens.data_ptr(); - const int32_t *blank_connections_data = blank_connections.data_ptr(); + // const int32_t *blank_connections_data = blank_connections.data_ptr(); const int32_t *y_data = y.data_ptr(); int32_t rng_stride_0 = ranges.stride(0), rng_stride_1 = ranges.stride(1), rng_stride_2 = ranges.stride(2); - int32_t blk_stride_0 = blank_connections.stride(0), - blk_stride_1 = blank_connections.stride(1), - blk_stride_2 = blank_connections.stride(2); + // int32_t blk_stride_0 = blank_connections.stride(0), + // blk_stride_1 = blank_connections.stride(1), + // blk_stride_2 = blank_connections.stride(2); int32_t lg_stride_0 = logits.stride(0), lg_stride_1 = logits.stride(1), lg_stride_2 = logits.stride(2), @@ -131,15 +131,15 @@ FsaVec SelfAlignment( return; } int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; - int32_t blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; - int32_t next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + // int32_t blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; + // int32_t next_state_idx1 = blank_connections_data[blank_connections_data_offset]; - int32_t next_state_idx1_tmp = -1; + int32_t next_state_idx1 = -1; // blank connections of last frame is -1 // So we need process t == x_lens_data[fsa_idx0] if (t < x_lens_data[fsa_idx0] - 1) { int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; - next_state_idx1_tmp = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; + next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; } // K2_CHECK_EQ(next_state_idx1_tmp, next_state_idx1); if (token_index < U - 1) { @@ -198,10 +198,11 @@ FsaVec SelfAlignment( return; } int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; + int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; int32_t actual_u = ranges_data[range_offset]; int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u; int32_t arc_label = y_data[y_offset]; - int32_t blank_connections_data_offset, next_state_idx1, logits_offset; + int32_t next_state_idx1, logits_offset; arc.src_state = state_idx1; // arc. if (token_index < U - 1) { @@ -216,8 +217,9 @@ FsaVec SelfAlignment( // arc.score = 0; break; case 1: - blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; - next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; + // blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; + // next_state_idx1 = blank_connections_data[blank_connections_data_offset]; // blank connections of last frame is always -1, // So states with num_arcs > 2 (i.e. with arc_idx2==1) could not belong to last frame, i.e. x_lens_data[fsa_idx0] - 1. // K2_CHECK_LE(t, x_lens_data[fsa_idx0] - 1); @@ -234,8 +236,9 @@ FsaVec SelfAlignment( } } else { K2_CHECK_EQ(arc_idx2, 0); - blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; - next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + // blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; + // next_state_idx1 = blank_connections_data[blank_connections_data_offset]; + next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; // K2_CHECK_GE(next_state_idx1, 0); arc.dest_state = next_state_idx1 + (t + 1) * U; arc.label = 0; diff --git a/k2/csrc/self_alignment.h b/k2/csrc/self_alignment.h index 652d11d41..d460b12b8 100644 --- a/k2/csrc/self_alignment.h +++ b/k2/csrc/self_alignment.h @@ -32,7 +32,7 @@ namespace k2 { FsaVec SelfAlignment( torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. torch::Tensor x_lens, // [B][S+1][T] - torch::Tensor blank_connections, + // torch::Tensor blank_connections, torch::Tensor y, // const Ragged &y, torch::Tensor logits, diff --git a/k2/python/csrc/torch/self_alignment.cu b/k2/python/csrc/torch/self_alignment.cu index 00fc11420..7e3356ed5 100644 --- a/k2/python/csrc/torch/self_alignment.cu +++ b/k2/python/csrc/torch/self_alignment.cu @@ -30,17 +30,17 @@ void PybindSelfAlignment(py::module &m) { m.def( "self_alignment", [](torch::Tensor ranges, torch::Tensor x_lens, - torch::Tensor blank_connections, torch::Tensor y, torch::Tensor logits) -> std::pair { k2::DeviceGuard guard(k2::GetContext(ranges)); k2::Array1 label_map; - k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, blank_connections, y, logits, &label_map); + k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, y, logits, &label_map); + // k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, blank_connections, y, logits, &label_map); // k2::SelfAlignment(ranges, x_lens, blank_connections, y.any.Specialize(), boundary, p, &ofsa, &label_map); // k2::SelfAlignment(ranges, x_lens, blank_connections, boundary, p, &ofsa, &label_map); torch::Tensor tensor = ToTorch(label_map); return std::make_pair(ofsa, tensor); }, - py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("y"), py::arg("logits")); + py::arg("ranges"), py::arg("x_lens"), py::arg("y"), py::arg("logits")); // py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("boundary"), py::arg("p")); } From 1613abc8e36858c1054c0c0671d6a4c97dfcf50c Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 14 Jan 2023 16:26:05 +0800 Subject: [PATCH 04/18] fix y_stride_1 --- k2/csrc/self_alignment.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/self_alignment.cu index 683432c5e..014ee7717 100644 --- a/k2/csrc/self_alignment.cu +++ b/k2/csrc/self_alignment.cu @@ -200,7 +200,7 @@ FsaVec SelfAlignment( int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; int32_t actual_u = ranges_data[range_offset]; - int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u; + int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u * y_stride_1; int32_t arc_label = y_data[y_offset]; int32_t next_state_idx1, logits_offset; arc.src_state = state_idx1; From b448667ee565677b0334eba2ba3d086d43accaee Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Fri, 17 Feb 2023 11:05:01 +0800 Subject: [PATCH 05/18] fix crash --- k2/csrc/self_alignment.cu | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/self_alignment.cu index 014ee7717..588a181c5 100644 --- a/k2/csrc/self_alignment.cu +++ b/k2/csrc/self_alignment.cu @@ -44,6 +44,11 @@ FsaVec SelfAlignment( context = GetCpuContext(); } else if (ranges.is_cuda()) { context = GetCudaContext(ranges.device().index()); + + TORCH_CHECK(ranges.get_device() == x_lens.get_device(), "x_lens is on a different device"); + TORCH_CHECK(ranges.get_device() == y.get_device(), "y device"); + TORCH_CHECK(ranges.get_device() == logits.get_device(), "logits device"); + } else { K2_LOG(FATAL) << "Unsupported device: " << ranges.device() << "\nOnly CPU and CUDA are supported"; @@ -199,13 +204,13 @@ FsaVec SelfAlignment( } int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; - int32_t actual_u = ranges_data[range_offset]; - int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u * y_stride_1; - int32_t arc_label = y_data[y_offset]; int32_t next_state_idx1, logits_offset; arc.src_state = state_idx1; // arc. if (token_index < U - 1) { + int32_t actual_u = ranges_data[range_offset]; + int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u * y_stride_1; + int32_t arc_label = y_data[y_offset]; switch (arc_idx2) { case 0: arc.dest_state = state_idx1 + 1; From 8564112d13a326d175ac9a1f279b0c3478204207 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Fri, 17 Feb 2023 11:25:40 +0800 Subject: [PATCH 06/18] remove unused comments --- k2/csrc/self_alignment.cu | 54 ++++++-------------------- k2/csrc/self_alignment.h | 3 -- k2/python/csrc/torch/self_alignment.cu | 4 -- 3 files changed, 11 insertions(+), 50 deletions(-) diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/self_alignment.cu index 588a181c5..ba9e06dd2 100644 --- a/k2/csrc/self_alignment.cu +++ b/k2/csrc/self_alignment.cu @@ -1,8 +1,6 @@ /** - * @copyright * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) * - * @copyright * See LICENSE for clarification regarding multiple authors * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -59,39 +57,20 @@ FsaVec SelfAlignment( const int32_t B = ranges.size(0), T = ranges.size(1), U = ranges.size(2); - // K2_LOG(INFO) << ranges[0][5]; - // K2_LOG(INFO) << ranges[0]; Dtype t = ScalarTypeToDtype(ranges.scalar_type()); - // K2_CHECK_EQ(kInt64Dtype, t); - // K2_CHECK_EQ(torch::kLong, ranges.scalar_type()); K2_CHECK_EQ(torch::kInt, ranges.scalar_type()); K2_CHECK_EQ(torch::kInt, x_lens.scalar_type()); // int32_t - // Dtype t = ScalarTypeToDtype(ranges.GetDtype()); - // std::is_same::value; - // static_assert(std::is_same::value, "ranges is not kInt64Dtype") - // const int32_t *row_ids_data = row_ids.data_ptr(); - // FOR_REAL_TYPES(t, t, { - const float *logits_data = logits.data_ptr(); - const int32_t *ranges_data = ranges.data_ptr(); - const int32_t *x_lens_data = x_lens.data_ptr(); - // const int32_t *blank_connections_data = blank_connections.data_ptr(); - const int32_t *y_data = y.data_ptr(); - int32_t rng_stride_0 = ranges.stride(0), - rng_stride_1 = ranges.stride(1), - rng_stride_2 = ranges.stride(2); - // int32_t blk_stride_0 = blank_connections.stride(0), - // blk_stride_1 = blank_connections.stride(1), - // blk_stride_2 = blank_connections.stride(2); - int32_t lg_stride_0 = logits.stride(0), - lg_stride_1 = logits.stride(1), - lg_stride_2 = logits.stride(2), - lg_stride_3 = logits.stride(3); - // K2_LOG(INFO) << "stride_0: " << stride_0; - // K2_LOG(INFO) << "stride_1: " << stride_1; - // K2_LOG(INFO) << "stride_2: " << stride_2; - // int32_t numel = ranges.numel(); - // Array1 re_ranges(context, numel); - // f2s: fsa to states + const float *logits_data = logits.data_ptr(); + const int32_t *ranges_data = ranges.data_ptr(); + const int32_t *x_lens_data = x_lens.data_ptr(); + const int32_t *y_data = y.data_ptr(); + const int32_t rng_stride_0 = ranges.stride(0), + rng_stride_1 = ranges.stride(1), + rng_stride_2 = ranges.stride(2); + const int32_t lg_stride_0 = logits.stride(0), + lg_stride_1 = logits.stride(1), + lg_stride_2 = logits.stride(2), + lg_stride_3 = logits.stride(3); K2_CHECK_EQ(x_lens.numel(), B); Array1 f2s_row_splits(context, B + 1); int32_t * f2s_row_splits_data = f2s_row_splits.Data(); @@ -136,8 +115,6 @@ FsaVec SelfAlignment( return; } int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; - // int32_t blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; - // int32_t next_state_idx1 = blank_connections_data[blank_connections_data_offset]; int32_t next_state_idx1 = -1; // blank connections of last frame is -1 @@ -257,15 +234,6 @@ FsaVec SelfAlignment( *arc_map = std::move(out_map); return Ragged(ofsa_shape, arcs); - - - // int32_t * re_reanges_data = re_ranges.Data(); - // K2_LOG(INFO) << "numel: " << numel; - // K2_EVAL(context, numel, lambda_set_range, (int32_t i) { - // re_reanges_data[i] = ranges_data[i]; - // }); - // auto cpu_range = re_ranges.To(GetCpuContext()); - } } // namespace k2 diff --git a/k2/csrc/self_alignment.h b/k2/csrc/self_alignment.h index d460b12b8..e0d5222be 100644 --- a/k2/csrc/self_alignment.h +++ b/k2/csrc/self_alignment.h @@ -32,11 +32,8 @@ namespace k2 { FsaVec SelfAlignment( torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. torch::Tensor x_lens, // [B][S+1][T] - // torch::Tensor blank_connections, torch::Tensor y, - // const Ragged &y, torch::Tensor logits, - // FsaVec * ofsa, Array1 *arc_map); } // namespace k2 diff --git a/k2/python/csrc/torch/self_alignment.cu b/k2/python/csrc/torch/self_alignment.cu index 7e3356ed5..516e7d81a 100644 --- a/k2/python/csrc/torch/self_alignment.cu +++ b/k2/python/csrc/torch/self_alignment.cu @@ -35,12 +35,8 @@ void PybindSelfAlignment(py::module &m) { k2::DeviceGuard guard(k2::GetContext(ranges)); k2::Array1 label_map; k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, y, logits, &label_map); - // k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, blank_connections, y, logits, &label_map); - // k2::SelfAlignment(ranges, x_lens, blank_connections, y.any.Specialize(), boundary, p, &ofsa, &label_map); - // k2::SelfAlignment(ranges, x_lens, blank_connections, boundary, p, &ofsa, &label_map); torch::Tensor tensor = ToTorch(label_map); return std::make_pair(ofsa, tensor); }, py::arg("ranges"), py::arg("x_lens"), py::arg("y"), py::arg("logits")); - // py::arg("ranges"), py::arg("x_lens"), py::arg("blank_connections"), py::arg("boundary"), py::arg("p")); } From b703774916ab4575622707ca62765ceffbb358db Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 6 May 2023 10:20:43 +0800 Subject: [PATCH 07/18] rename to pruned_ranges_to_lattice --- k2/csrc/CMakeLists.txt | 2 +- .../{self_alignment.cu => pruned_ranges_to_lattice.cu} | 8 ++++---- .../{self_alignment.h => pruned_ranges_to_lattice.h} | 8 ++++---- k2/python/csrc/torch.cu | 4 ++-- k2/python/csrc/torch/CMakeLists.txt | 2 +- .../{self_alignment.cu => pruned_ranges_to_lattice.cu} | 10 +++++----- .../{self_alignment.h => pruned_ranges_to_lattice.h} | 8 ++++---- k2/python/k2/__init__.py | 2 +- 8 files changed, 22 insertions(+), 22 deletions(-) rename k2/csrc/{self_alignment.cu => pruned_ranges_to_lattice.cu} (98%) rename k2/csrc/{self_alignment.h => pruned_ranges_to_lattice.h} (86%) rename k2/python/csrc/torch/{self_alignment.cu => pruned_ranges_to_lattice.cu} (82%) rename k2/python/csrc/torch/{self_alignment.h => pruned_ranges_to_lattice.h} (77%) diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index b61ee8eb8..73b75eb4d 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -74,7 +74,7 @@ set(context_srcs reverse.cu rm_epsilon.cu rnnt_decode.cu - self_alignment.cu + pruned_ranges_to_lattice.cu tensor.cu tensor_ops.cu thread_pool.cu diff --git a/k2/csrc/self_alignment.cu b/k2/csrc/pruned_ranges_to_lattice.cu similarity index 98% rename from k2/csrc/self_alignment.cu rename to k2/csrc/pruned_ranges_to_lattice.cu index ba9e06dd2..fa48c5b81 100644 --- a/k2/csrc/self_alignment.cu +++ b/k2/csrc/pruned_ranges_to_lattice.cu @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef K2_CSRC_SELF_ALIGNMENT_H_ -#define K2_CSRC_SELF_ALIGNMENT_H_ +#ifndef K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ +#define K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ #include @@ -27,7 +27,7 @@ namespace k2 { -FsaVec SelfAlignment( +FsaVec PrunedRangesToLattice( // Normally, ranges is with shape [B][S][T+1] if !modified, [B][S][T] if modified. // Currently, only [B][S][T] is supported. torch::Tensor ranges, @@ -238,4 +238,4 @@ FsaVec SelfAlignment( } // namespace k2 -#endif // K2_CSRC_SELF_ALIGNMENT_H_ +#endif // K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ diff --git a/k2/csrc/self_alignment.h b/k2/csrc/pruned_ranges_to_lattice.h similarity index 86% rename from k2/csrc/self_alignment.h rename to k2/csrc/pruned_ranges_to_lattice.h index e0d5222be..845d670ac 100644 --- a/k2/csrc/self_alignment.h +++ b/k2/csrc/pruned_ranges_to_lattice.h @@ -18,8 +18,8 @@ * limitations under the License. */ -#ifndef K2_CSRC_SELF_ALIGNMENT_H_ -#define K2_CSRC_SELF_ALIGNMENT_H_ +#ifndef K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ +#define K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ #include @@ -29,7 +29,7 @@ namespace k2 { -FsaVec SelfAlignment( +FsaVec PrunedRangesToLattice( torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. torch::Tensor x_lens, // [B][S+1][T] torch::Tensor y, @@ -38,4 +38,4 @@ FsaVec SelfAlignment( } // namespace k2 -#endif // K2_CSRC_SELF_ALIGNMENT_H_ +#endif // K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu index 0b35d04b3..72fe18151 100644 --- a/k2/python/csrc/torch.cu +++ b/k2/python/csrc/torch.cu @@ -31,10 +31,10 @@ #include "k2/python/csrc/torch/index_select.h" #include "k2/python/csrc/torch/mutual_information.h" #include "k2/python/csrc/torch/nbest.h" +#include "k2/python/csrc/torch/pruned_ranges_to_lattice.h" #include "k2/python/csrc/torch/ragged.h" #include "k2/python/csrc/torch/ragged_ops.h" #include "k2/python/csrc/torch/rnnt_decode.h" -#include "k2/python/csrc/torch/self_alignment.h" #include "k2/python/csrc/torch/v2/k2.h" void PybindTorch(py::module &m) { @@ -48,7 +48,7 @@ void PybindTorch(py::module &m) { PybindRagged(m); PybindRaggedOps(m); PybindRnntDecode(m); - PybindSelfAlignment(m); + PybindPrunedRangesToLattice(m); k2::PybindV2(m); } diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 8be261ca5..788db692f 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -1,6 +1,6 @@ # please keep the list sorted set(torch_srcs - self_alignment.cu + pruned_ranges_to_lattice.cu arc.cu fsa.cu fsa_algo.cu diff --git a/k2/python/csrc/torch/self_alignment.cu b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu similarity index 82% rename from k2/python/csrc/torch/self_alignment.cu rename to k2/python/csrc/torch/pruned_ranges_to_lattice.cu index 516e7d81a..53c3ed02e 100644 --- a/k2/python/csrc/torch/self_alignment.cu +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu @@ -21,20 +21,20 @@ #include "k2/csrc/device_guard.h" #include "k2/csrc/fsa.h" #include "k2/csrc/torch_util.h" -#include "k2/csrc/self_alignment.h" -#include "k2/python/csrc/torch/self_alignment.h" +#include "k2/csrc/pruned_ranges_to_lattice.h" +#include "k2/python/csrc/torch/pruned_ranges_to_lattice.h" #include "k2/python/csrc/torch/v2/ragged_any.h" -void PybindSelfAlignment(py::module &m) { +void PybindPrunedRangesToLattice(py::module &m) { m.def( - "self_alignment", + "pruned_ranges_to_lattice", [](torch::Tensor ranges, torch::Tensor x_lens, torch::Tensor y, torch::Tensor logits) -> std::pair { k2::DeviceGuard guard(k2::GetContext(ranges)); k2::Array1 label_map; - k2::FsaVec ofsa = k2::SelfAlignment(ranges, x_lens, y, logits, &label_map); + k2::FsaVec ofsa = k2::PrunedRangesToLattice(ranges, x_lens, y, logits, &label_map); torch::Tensor tensor = ToTorch(label_map); return std::make_pair(ofsa, tensor); }, diff --git a/k2/python/csrc/torch/self_alignment.h b/k2/python/csrc/torch/pruned_ranges_to_lattice.h similarity index 77% rename from k2/python/csrc/torch/self_alignment.h rename to k2/python/csrc/torch/pruned_ranges_to_lattice.h index 30033979a..b64f59bf0 100644 --- a/k2/python/csrc/torch/self_alignment.h +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.h @@ -18,11 +18,11 @@ * limitations under the License. */ -#ifndef K2_PYTHON_CSRC_TORCH_SELF_ALIGNMENT_H_ -#define K2_PYTHON_CSRC_TORCH_SELF_ALIGNMENT_H_ +#ifndef K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ +#define K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ #include "k2/python/csrc/torch.h" -void PybindSelfAlignment(py::module &m); +void PybindPrunedRangesToLattice(py::module &m); -#endif // K2_PYTHON_CSRC_TORCH_SELF_ALIGNMENT_H_ +#endif // K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index cc9e2d612..7ddf738fa 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -122,7 +122,7 @@ from .utils import random_fsa from .utils import random_fsa_vec from _k2.version import with_cuda -from _k2 import self_alignment +from _k2 import pruned_ranges_to_lattice from .decode import get_aux_labels from .decode import get_lattice From ecc2db3fc49ceb9de8f1a813e0fe7e0e2829b3c3 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 6 May 2023 15:47:50 +0800 Subject: [PATCH 08/18] rename x_lens and y to frames and symbols --- k2/csrc/pruned_ranges_to_lattice.cu | 107 ++++++++++++---------------- k2/csrc/pruned_ranges_to_lattice.h | 14 ++-- 2 files changed, 54 insertions(+), 67 deletions(-) diff --git a/k2/csrc/pruned_ranges_to_lattice.cu b/k2/csrc/pruned_ranges_to_lattice.cu index fa48c5b81..c877d1684 100644 --- a/k2/csrc/pruned_ranges_to_lattice.cu +++ b/k2/csrc/pruned_ranges_to_lattice.cu @@ -16,9 +16,6 @@ * limitations under the License. */ -#ifndef K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ -#define K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ - #include #include @@ -28,14 +25,10 @@ namespace k2 { FsaVec PrunedRangesToLattice( - // Normally, ranges is with shape [B][S][T+1] if !modified, [B][S][T] if modified. - // Currently, only [B][S][T] is supported. - torch::Tensor ranges, - torch::Tensor x_lens, // [B][T] - // torch::Tensor blank_connections, - torch::Tensor y, - // const Ragged &y, - torch::Tensor logits, + torch::Tensor ranges, // [B][T][s_range] + torch::Tensor frames, // [B] + torch::Tensor symbols, // [B][S] + torch::Tensor logits, // [B][S][s_range][C] Array1 *arc_map) { ContextPtr context; if (ranges.device().type() == torch::kCPU) { @@ -43,27 +36,41 @@ FsaVec PrunedRangesToLattice( } else if (ranges.is_cuda()) { context = GetCudaContext(ranges.device().index()); - TORCH_CHECK(ranges.get_device() == x_lens.get_device(), "x_lens is on a different device"); - TORCH_CHECK(ranges.get_device() == y.get_device(), "y device"); - TORCH_CHECK(ranges.get_device() == logits.get_device(), "logits device"); + TORCH_CHECK(ranges.get_device() == frames.get_device()); + TORCH_CHECK(ranges.get_device() == symbols.get_device()); + TORCH_CHECK(ranges.get_device() == logits.get_device()); } else { K2_LOG(FATAL) << "Unsupported device: " << ranges.device() - << "\nOnly CPU and CUDA are supported"; + << "\nOnly CPU and CUDA are verified"; } - - TORCH_CHECK(ranges.dim() == 3, "ranges must be 3-dimensional"); - // U is always 5 - const int32_t B = ranges.size(0), T = ranges.size(1), U = ranges.size(2); - - Dtype t = ScalarTypeToDtype(ranges.scalar_type()); - K2_CHECK_EQ(torch::kInt, ranges.scalar_type()); - K2_CHECK_EQ(torch::kInt, x_lens.scalar_type()); // int32_t + TORCH_CHECK(ranges.dim() == 3, "ranges should be 3-dimensional"); + TORCH_CHECK(frames.dim() == 1, "frames should be 1-dimensional"); + TORCH_CHECK(symbols.dim() == 2, "symbols should be 2-dimensional"); + TORCH_CHECK(logits.dim() == 4, "logits should be 4-dimensional"); + + TORCH_CHECK(torch::kInt == ranges.scalar_type()); + TORCH_CHECK(torch::kInt == frames.scalar_type()); + TORCH_CHECK(torch::kInt == symbols.scalar_type()); + + // TODO: Support double and half. + // currently only float type logits is verified. + TORCH_CHECK(torch::kFloat == logits.scalar_type()); + + AT_DISPATCH_FLOATING_TYPES( + logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { + auto ranges_a = ranges.accessor(); + auto frames_a = frames.accessor(); + auto symbols_a = symbols.accessor(); + auto logits_a = logits.accessor(); + })); + // Typically, U is 5. + const int32_t B = ranges.size(0), T = ranges.size(1), U = ranges.size(2); const float *logits_data = logits.data_ptr(); const int32_t *ranges_data = ranges.data_ptr(); - const int32_t *x_lens_data = x_lens.data_ptr(); - const int32_t *y_data = y.data_ptr(); + const int32_t *frames_data = frames.data_ptr(); + const int32_t *symbols_data = symbols.data_ptr(); const int32_t rng_stride_0 = ranges.stride(0), rng_stride_1 = ranges.stride(1), rng_stride_2 = ranges.stride(2); @@ -71,11 +78,11 @@ FsaVec PrunedRangesToLattice( lg_stride_1 = logits.stride(1), lg_stride_2 = logits.stride(2), lg_stride_3 = logits.stride(3); - K2_CHECK_EQ(x_lens.numel(), B); + K2_CHECK_EQ(frames.numel(), B); Array1 f2s_row_splits(context, B + 1); int32_t * f2s_row_splits_data = f2s_row_splits.Data(); K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { - int32_t t = x_lens_data[fsa_idx0]; + int32_t t = frames_data[fsa_idx0]; K2_CHECK_LE(t, T); // + 1 in "t * U + 1" is for super-final state. f2s_row_splits_data[fsa_idx0] = t * U + 1; @@ -95,21 +102,20 @@ FsaVec PrunedRangesToLattice( context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], state_idx0x = fts_row_splits1_data[fsa_idx0], - state_idx0x_next = fts_row_splits1_data[fsa_idx0 + 1], state_idx1 = state_idx01 - state_idx0x, t = state_idx1 / U, token_index = state_idx1 % U; - K2_CHECK_LE(t, x_lens_data[fsa_idx0]); - if (state_idx1 == x_lens_data[fsa_idx0] * U - 1) { - // x_lens[fsa_idx0] * U is the state_idx1 of super final-state. - // x_lens[fsa_idx0] * U - 1 is the state pointing to super final-state. + K2_CHECK_LE(t, frames_data[fsa_idx0]); + if (state_idx1 == frames_data[fsa_idx0] * U - 1) { + // frames[fsa_idx0] * U is the state_idx1 of super final-state. + // frames[fsa_idx0] * U - 1 is the state pointing to super final-state. // final arc to super final state. s2c_row_splits_data[state_idx01] = 1; return; } - if (state_idx1 == x_lens_data[fsa_idx0] * U) { - // x_lens[fsa_idx0] * U is the state_idx1 of super final-state. + if (state_idx1 == frames_data[fsa_idx0] * U) { + // frames[fsa_idx0] * U is the state_idx1 of super final-state. // final state has no leaving arcs. s2c_row_splits_data[state_idx01] = 0; return; @@ -118,8 +124,8 @@ FsaVec PrunedRangesToLattice( int32_t next_state_idx1 = -1; // blank connections of last frame is -1 - // So we need process t == x_lens_data[fsa_idx0] - if (t < x_lens_data[fsa_idx0] - 1) { + // So we need process t == frames_data[fsa_idx0] + if (t < frames_data[fsa_idx0] - 1) { int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; } @@ -154,23 +160,20 @@ FsaVec PrunedRangesToLattice( *row_splits2_data = ofsa_shape.RowSplits(2).Data(), *row_ids2_data = ofsa_shape.RowIds(2).Data(); - // auto y_shape = y.shape; - // const int32_t * y_data = y.values.Data(); - int32_t y_stride_0 = y.stride(0), - y_stride_1 = y.stride(1); + int32_t symbols_stride_0 = symbols.stride(0), + symbols_stride_1 = symbols.stride(1); K2_EVAL( context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { int32_t state_idx01 = row_ids2_data[arc_idx012], fsa_idx0 = row_ids1_data[state_idx01], state_idx0x = row_splits1_data[fsa_idx0], - state_idx0x_next = row_splits1_data[fsa_idx0 + 1], arc_idx01x = row_splits2_data[state_idx01], state_idx1 = state_idx01 - state_idx0x, arc_idx2 = arc_idx012 - arc_idx01x, t = state_idx1 / U, token_index = state_idx1 % U; // token_index is belong to [0, U) Arc arc; - if (state_idx1 == x_lens_data[fsa_idx0] * U - 1) { + if (state_idx1 == frames_data[fsa_idx0] * U - 1) { arc.src_state = state_idx1; arc.dest_state = state_idx1 + 1; arc.label = -1; @@ -186,56 +189,40 @@ FsaVec PrunedRangesToLattice( // arc. if (token_index < U - 1) { int32_t actual_u = ranges_data[range_offset]; - int32_t y_offset = fsa_idx0 * y_stride_0 + actual_u * y_stride_1; - int32_t arc_label = y_data[y_offset]; + int32_t symbols_offset = fsa_idx0 * symbols_stride_0 + actual_u * symbols_stride_1; + int32_t arc_label = symbols_data[symbols_offset]; switch (arc_idx2) { case 0: arc.dest_state = state_idx1 + 1; arc.label = arc_label; logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2 + arc_label * lg_stride_3; - // K2_CHECK_LE(logits_offset, 135); arc.score = logits_data[logits_offset]; out_map_data[arc_idx012] = logits_offset; - // arc.score = 0; break; case 1: next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; - // blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; - // next_state_idx1 = blank_connections_data[blank_connections_data_offset]; - // blank connections of last frame is always -1, - // So states with num_arcs > 2 (i.e. with arc_idx2==1) could not belong to last frame, i.e. x_lens_data[fsa_idx0] - 1. - // K2_CHECK_LE(t, x_lens_data[fsa_idx0] - 1); - // K2_CHECK_GE(next_state_idx1, 0); arc.dest_state = next_state_idx1 + (t + 1) * U; arc.label = 0; logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; arc.score = logits_data[logits_offset]; out_map_data[arc_idx012] = logits_offset; - // arc.score = 0.0; break; default: K2_LOG(FATAL) << "Arc index must be less than 3"; } } else { K2_CHECK_EQ(arc_idx2, 0); - // blank_connections_data_offset = fsa_idx0 * blk_stride_0 + t * blk_stride_1 + token_index * blk_stride_2; - // next_state_idx1 = blank_connections_data[blank_connections_data_offset]; next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; - // K2_CHECK_GE(next_state_idx1, 0); arc.dest_state = next_state_idx1 + (t + 1) * U; arc.label = 0; logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; arc.score = logits_data[logits_offset]; out_map_data[arc_idx012] = logits_offset; - // arc.score = 0.0; } arcs_data[arc_idx012] = arc; }); *arc_map = std::move(out_map); return Ragged(ofsa_shape, arcs); - } } // namespace k2 - -#endif // K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ diff --git a/k2/csrc/pruned_ranges_to_lattice.h b/k2/csrc/pruned_ranges_to_lattice.h index 845d670ac..c48499e27 100644 --- a/k2/csrc/pruned_ranges_to_lattice.h +++ b/k2/csrc/pruned_ranges_to_lattice.h @@ -18,8 +18,8 @@ * limitations under the License. */ -#ifndef K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ -#define K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ +#ifndef K2_CSRC_PRUNED_RANGES_TO_LATTICE_H_ +#define K2_CSRC_PRUNED_RANGES_TO_LATTICE_H_ #include @@ -30,12 +30,12 @@ namespace k2 { FsaVec PrunedRangesToLattice( - torch::Tensor ranges, // [B][S][T+1] if !modified, [B][S][T] if modified. - torch::Tensor x_lens, // [B][S+1][T] - torch::Tensor y, - torch::Tensor logits, + torch::Tensor ranges, // [B][T][s_range] + torch::Tensor frames, // [B][T] + torch::Tensor symbols, // [B][S] + torch::Tensor logits, // [B][T][s_range][C] Array1 *arc_map); } // namespace k2 -#endif // K2_CSRC_PRUNE_RANGE_TO_LATTICE_H_ +#endif // K2_CSRC_PRUNED_RANGES_TO_LATTICE_H_ From dbb590a217d3be6ad8309f03b0a015940580438b Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 6 May 2023 17:42:57 +0800 Subject: [PATCH 09/18] use range accessor --- k2/csrc/pruned_ranges_to_lattice.cu | 84 +++++++++++++---------------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/k2/csrc/pruned_ranges_to_lattice.cu b/k2/csrc/pruned_ranges_to_lattice.cu index c877d1684..dc9a962f3 100644 --- a/k2/csrc/pruned_ranges_to_lattice.cu +++ b/k2/csrc/pruned_ranges_to_lattice.cu @@ -58,44 +58,44 @@ FsaVec PrunedRangesToLattice( // currently only float type logits is verified. TORCH_CHECK(torch::kFloat == logits.scalar_type()); - AT_DISPATCH_FLOATING_TYPES( - logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { + // AT_DISPATCH_FLOATING_TYPES( + // logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { auto ranges_a = ranges.accessor(); auto frames_a = frames.accessor(); auto symbols_a = symbols.accessor(); - auto logits_a = logits.accessor(); - })); - // Typically, U is 5. - const int32_t B = ranges.size(0), T = ranges.size(1), U = ranges.size(2); + // auto logits_a = logits.accessor(); + auto logits_a = logits.accessor(); + // })); + + // Typically, s_range is 5. + const int32_t B = ranges.size(0), T = ranges.size(1), s_range = ranges.size(2); const float *logits_data = logits.data_ptr(); const int32_t *ranges_data = ranges.data_ptr(); const int32_t *frames_data = frames.data_ptr(); const int32_t *symbols_data = symbols.data_ptr(); - const int32_t rng_stride_0 = ranges.stride(0), - rng_stride_1 = ranges.stride(1), - rng_stride_2 = ranges.stride(2); const int32_t lg_stride_0 = logits.stride(0), lg_stride_1 = logits.stride(1), lg_stride_2 = logits.stride(2), lg_stride_3 = logits.stride(3); K2_CHECK_EQ(frames.numel(), B); + // f2s is short for fsa_to_state. Array1 f2s_row_splits(context, B + 1); int32_t * f2s_row_splits_data = f2s_row_splits.Data(); K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { int32_t t = frames_data[fsa_idx0]; K2_CHECK_LE(t, T); // + 1 in "t * U + 1" is for super-final state. - f2s_row_splits_data[fsa_idx0] = t * U + 1; + f2s_row_splits_data[fsa_idx0] = t * s_range + 1; }); ExclusiveSum(f2s_row_splits, &f2s_row_splits); - RaggedShape fsa_to_states = + RaggedShape f2s_shape = RaggedShape2(&f2s_row_splits, nullptr, -1); - int32_t num_states = fsa_to_states.NumElements(); + int32_t num_states = f2s_shape.NumElements(); Array1 s2c_row_splits(context, num_states + 1); int32_t *s2c_row_splits_data = s2c_row_splits.Data(); - const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(), - *fts_row_ids1_data = fsa_to_states.RowIds(1).Data(); + const int32_t *fts_row_splits1_data = f2s_shape.RowSplits(1).Data(), + *fts_row_ids1_data = f2s_shape.RowIds(1).Data(); // set the arcs number for each state K2_EVAL( @@ -103,53 +103,47 @@ FsaVec PrunedRangesToLattice( int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], state_idx0x = fts_row_splits1_data[fsa_idx0], state_idx1 = state_idx01 - state_idx0x, - t = state_idx1 / U, - token_index = state_idx1 % U; + t = state_idx1 / s_range, + token_index = state_idx1 % s_range; K2_CHECK_LE(t, frames_data[fsa_idx0]); - if (state_idx1 == frames_data[fsa_idx0] * U - 1) { - // frames[fsa_idx0] * U is the state_idx1 of super final-state. - // frames[fsa_idx0] * U - 1 is the state pointing to super final-state. + if (state_idx1 == frames_data[fsa_idx0] * s_range - 1) { + // frames[fsa_idx0] * s_range is the state_idx1 of super final-state. + // frames[fsa_idx0] * s_range - 1 is the state pointing to super final-state. // final arc to super final state. s2c_row_splits_data[state_idx01] = 1; return; } - if (state_idx1 == frames_data[fsa_idx0] * U) { + if (state_idx1 == frames_data[fsa_idx0] * s_range) { // frames[fsa_idx0] * U is the state_idx1 of super final-state. // final state has no leaving arcs. s2c_row_splits_data[state_idx01] = 0; return; } - int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; - int32_t next_state_idx1 = -1; - // blank connections of last frame is -1 - // So we need process t == frames_data[fsa_idx0] + bool has_horizontal_blank_arc = false; if (t < frames_data[fsa_idx0] - 1) { - int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; - next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; + has_horizontal_blank_arc = ranges_a[fsa_idx0][t][token_index] >= ranges_a[fsa_idx0][t + 1][0]; } - // K2_CHECK_EQ(next_state_idx1_tmp, next_state_idx1); - if (token_index < U - 1) { - // Typically, U == 5, + if (token_index < s_range - 1) { + // Typically, s_range == 5, // the [0, 1, 2, 3] states for each time step may have a vertial arc plus an optional horizontal blank arc. s2c_row_splits_data[state_idx01] = 1; - if (next_state_idx1 >= 0) { + if (has_horizontal_blank_arc) { s2c_row_splits_data[state_idx01] = 2; } } else { - // Typically, U == 5, - // the [4] state for each time step have and only have an horizontal blank arc. - // K2_CHECK_GE(next_state_idx1, 0); + // Typically, s_range == 5, + // the [4] state for each time step have and only have a horizontal blank arc. s2c_row_splits_data[state_idx01] = 1; } }); ExclusiveSum(s2c_row_splits, &s2c_row_splits); - RaggedShape states_to_arcs = + RaggedShape s2a_shape = RaggedShape2(&s2c_row_splits, nullptr, -1); - RaggedShape ofsa_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); + RaggedShape ofsa_shape = ComposeRaggedShapes(f2s_shape, s2a_shape); int32_t num_arcs = ofsa_shape.NumElements(); Array1 arcs(context, num_arcs); Array1 out_map(context, num_arcs); @@ -170,10 +164,10 @@ FsaVec PrunedRangesToLattice( arc_idx01x = row_splits2_data[state_idx01], state_idx1 = state_idx01 - state_idx0x, arc_idx2 = arc_idx012 - arc_idx01x, - t = state_idx1 / U, - token_index = state_idx1 % U; // token_index is belong to [0, U) + t = state_idx1 / s_range, + token_index = state_idx1 % s_range; // token_index is belong to [0, U) Arc arc; - if (state_idx1 == frames_data[fsa_idx0] * U - 1) { + if (state_idx1 == frames_data[fsa_idx0] * s_range - 1) { arc.src_state = state_idx1; arc.dest_state = state_idx1 + 1; arc.label = -1; @@ -182,13 +176,11 @@ FsaVec PrunedRangesToLattice( out_map_data[arc_idx012] = -1; return; } - int32_t range_offset = fsa_idx0 * rng_stride_0 + t * rng_stride_1 + token_index * rng_stride_2; - int32_t range_offset_of_lower_bound_of_next_time_step = fsa_idx0 * rng_stride_0 + (t + 1) * rng_stride_1; int32_t next_state_idx1, logits_offset; arc.src_state = state_idx1; // arc. - if (token_index < U - 1) { - int32_t actual_u = ranges_data[range_offset]; + if (token_index < s_range - 1) { + int32_t actual_u = ranges_a[fsa_idx0][t][token_index]; int32_t symbols_offset = fsa_idx0 * symbols_stride_0 + actual_u * symbols_stride_1; int32_t arc_label = symbols_data[symbols_offset]; switch (arc_idx2) { @@ -200,8 +192,8 @@ FsaVec PrunedRangesToLattice( out_map_data[arc_idx012] = logits_offset; break; case 1: - next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; - arc.dest_state = next_state_idx1 + (t + 1) * U; + next_state_idx1 = ranges_a[fsa_idx0][t][token_index] - ranges_a[fsa_idx0][t + 1][0]; + arc.dest_state = next_state_idx1 + (t + 1) * s_range; arc.label = 0; logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; arc.score = logits_data[logits_offset]; @@ -212,8 +204,8 @@ FsaVec PrunedRangesToLattice( } } else { K2_CHECK_EQ(arc_idx2, 0); - next_state_idx1 = ranges_data[range_offset] - ranges_data[range_offset_of_lower_bound_of_next_time_step]; - arc.dest_state = next_state_idx1 + (t + 1) * U; + next_state_idx1 = ranges_a[fsa_idx0][t][token_index] - ranges_a[fsa_idx0][t + 1][0]; + arc.dest_state = next_state_idx1 + (t + 1) * s_range; arc.label = 0; logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; arc.score = logits_data[logits_offset]; From d00fd0c154e63a1f82e503ba5dfedee0fdc0d31d Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 6 May 2023 19:03:13 +0800 Subject: [PATCH 10/18] use frame accessor --- k2/csrc/pruned_ranges_to_lattice.cu | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/k2/csrc/pruned_ranges_to_lattice.cu b/k2/csrc/pruned_ranges_to_lattice.cu index dc9a962f3..78504a235 100644 --- a/k2/csrc/pruned_ranges_to_lattice.cu +++ b/k2/csrc/pruned_ranges_to_lattice.cu @@ -69,10 +69,9 @@ FsaVec PrunedRangesToLattice( // Typically, s_range is 5. const int32_t B = ranges.size(0), T = ranges.size(1), s_range = ranges.size(2); - const float *logits_data = logits.data_ptr(); - const int32_t *ranges_data = ranges.data_ptr(); - const int32_t *frames_data = frames.data_ptr(); + // const int32_t *frames_data = frames.data_ptr(); const int32_t *symbols_data = symbols.data_ptr(); + // Used to compute out_map. const int32_t lg_stride_0 = logits.stride(0), lg_stride_1 = logits.stride(1), lg_stride_2 = logits.stride(2), @@ -82,9 +81,9 @@ FsaVec PrunedRangesToLattice( Array1 f2s_row_splits(context, B + 1); int32_t * f2s_row_splits_data = f2s_row_splits.Data(); K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { - int32_t t = frames_data[fsa_idx0]; + int32_t t = frames_a[fsa_idx0]; K2_CHECK_LE(t, T); - // + 1 in "t * U + 1" is for super-final state. + // + 1 in "t * s_range + 1" is for super-final state. f2s_row_splits_data[fsa_idx0] = t * s_range + 1; }); ExclusiveSum(f2s_row_splits, &f2s_row_splits); @@ -106,15 +105,15 @@ FsaVec PrunedRangesToLattice( t = state_idx1 / s_range, token_index = state_idx1 % s_range; - K2_CHECK_LE(t, frames_data[fsa_idx0]); - if (state_idx1 == frames_data[fsa_idx0] * s_range - 1) { + K2_CHECK_LE(t, frames_a[fsa_idx0]); + if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { // frames[fsa_idx0] * s_range is the state_idx1 of super final-state. // frames[fsa_idx0] * s_range - 1 is the state pointing to super final-state. // final arc to super final state. s2c_row_splits_data[state_idx01] = 1; return; } - if (state_idx1 == frames_data[fsa_idx0] * s_range) { + if (state_idx1 == frames_a[fsa_idx0] * s_range) { // frames[fsa_idx0] * U is the state_idx1 of super final-state. // final state has no leaving arcs. s2c_row_splits_data[state_idx01] = 0; @@ -122,7 +121,7 @@ FsaVec PrunedRangesToLattice( } bool has_horizontal_blank_arc = false; - if (t < frames_data[fsa_idx0] - 1) { + if (t < frames_a[fsa_idx0] - 1) { has_horizontal_blank_arc = ranges_a[fsa_idx0][t][token_index] >= ranges_a[fsa_idx0][t + 1][0]; } if (token_index < s_range - 1) { @@ -167,7 +166,7 @@ FsaVec PrunedRangesToLattice( t = state_idx1 / s_range, token_index = state_idx1 % s_range; // token_index is belong to [0, U) Arc arc; - if (state_idx1 == frames_data[fsa_idx0] * s_range - 1) { + if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { arc.src_state = state_idx1; arc.dest_state = state_idx1 + 1; arc.label = -1; @@ -187,16 +186,18 @@ FsaVec PrunedRangesToLattice( case 0: arc.dest_state = state_idx1 + 1; arc.label = arc_label; + arc.score = logits_a[fsa_idx0][t][token_index][arc_label]; + logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2 + arc_label * lg_stride_3; - arc.score = logits_data[logits_offset]; out_map_data[arc_idx012] = logits_offset; break; case 1: next_state_idx1 = ranges_a[fsa_idx0][t][token_index] - ranges_a[fsa_idx0][t + 1][0]; arc.dest_state = next_state_idx1 + (t + 1) * s_range; arc.label = 0; + arc.score = logits_a[fsa_idx0][t][token_index][0]; + logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; - arc.score = logits_data[logits_offset]; out_map_data[arc_idx012] = logits_offset; break; default: @@ -207,8 +208,9 @@ FsaVec PrunedRangesToLattice( next_state_idx1 = ranges_a[fsa_idx0][t][token_index] - ranges_a[fsa_idx0][t + 1][0]; arc.dest_state = next_state_idx1 + (t + 1) * s_range; arc.label = 0; + arc.score = logits_a[fsa_idx0][t][token_index][0]; + logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; - arc.score = logits_data[logits_offset]; out_map_data[arc_idx012] = logits_offset; } arcs_data[arc_idx012] = arc; From fefd2780dd7369ae3ecf517460f2941d0abc40e2 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 6 May 2023 19:07:01 +0800 Subject: [PATCH 11/18] use symbols_accessor --- k2/csrc/pruned_ranges_to_lattice.cu | 264 +++++++++++++++------------- 1 file changed, 137 insertions(+), 127 deletions(-) diff --git a/k2/csrc/pruned_ranges_to_lattice.cu b/k2/csrc/pruned_ranges_to_lattice.cu index 78504a235..480b61ff0 100644 --- a/k2/csrc/pruned_ranges_to_lattice.cu +++ b/k2/csrc/pruned_ranges_to_lattice.cu @@ -19,7 +19,9 @@ #include #include +#include +#include "k2/csrc/pruned_ranges_to_lattice.h" #include "k2/python/csrc/torch.h" namespace k2 { @@ -30,20 +32,10 @@ FsaVec PrunedRangesToLattice( torch::Tensor symbols, // [B][S] torch::Tensor logits, // [B][S][s_range][C] Array1 *arc_map) { - ContextPtr context; - if (ranges.device().type() == torch::kCPU) { - context = GetCpuContext(); - } else if (ranges.is_cuda()) { - context = GetCudaContext(ranges.device().index()); - TORCH_CHECK(ranges.get_device() == frames.get_device()); - TORCH_CHECK(ranges.get_device() == symbols.get_device()); - TORCH_CHECK(ranges.get_device() == logits.get_device()); - - } else { - K2_LOG(FATAL) << "Unsupported device: " << ranges.device() - << "\nOnly CPU and CUDA are verified"; - } + TORCH_CHECK(ranges.get_device() == frames.get_device()); + TORCH_CHECK(ranges.get_device() == symbols.get_device()); + TORCH_CHECK(ranges.get_device() == logits.get_device()); TORCH_CHECK(ranges.dim() == 3, "ranges should be 3-dimensional"); TORCH_CHECK(frames.dim() == 1, "frames should be 1-dimensional"); @@ -54,30 +46,27 @@ FsaVec PrunedRangesToLattice( TORCH_CHECK(torch::kInt == frames.scalar_type()); TORCH_CHECK(torch::kInt == symbols.scalar_type()); - // TODO: Support double and half. - // currently only float type logits is verified. - TORCH_CHECK(torch::kFloat == logits.scalar_type()); + ContextPtr context; + if (ranges.device().type() == torch::kCPU) { + context = GetCpuContext(); + } else if (ranges.is_cuda()) { + context = GetCudaContext(ranges.device().index()); + } else { + K2_LOG(FATAL) << "Unsupported device: " << ranges.device() + << "\nOnly CPU and CUDA are verified"; + } - // AT_DISPATCH_FLOATING_TYPES( - // logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { - auto ranges_a = ranges.accessor(); - auto frames_a = frames.accessor(); - auto symbols_a = symbols.accessor(); - // auto logits_a = logits.accessor(); - auto logits_a = logits.accessor(); - // })); + // "_a" is short for accessor. + auto ranges_a = ranges.accessor(); + auto frames_a = frames.accessor(); + auto symbols_a = symbols.accessor(); // Typically, s_range is 5. - const int32_t B = ranges.size(0), T = ranges.size(1), s_range = ranges.size(2); - // const int32_t *frames_data = frames.data_ptr(); - const int32_t *symbols_data = symbols.data_ptr(); - // Used to compute out_map. - const int32_t lg_stride_0 = logits.stride(0), - lg_stride_1 = logits.stride(1), - lg_stride_2 = logits.stride(2), - lg_stride_3 = logits.stride(3); - K2_CHECK_EQ(frames.numel(), B); - // f2s is short for fsa_to_state. + const int32_t B = ranges.size(0), + T = ranges.size(1), + s_range = ranges.size(2); + + // Compute f2s_shape: fsa_to_state_shape. Array1 f2s_row_splits(context, B + 1); int32_t * f2s_row_splits_data = f2s_row_splits.Data(); K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { @@ -86,136 +75,157 @@ FsaVec PrunedRangesToLattice( // + 1 in "t * s_range + 1" is for super-final state. f2s_row_splits_data[fsa_idx0] = t * s_range + 1; }); - ExclusiveSum(f2s_row_splits, &f2s_row_splits); + ExclusiveSum(f2s_row_splits, &f2s_row_splits); RaggedShape f2s_shape = RaggedShape2(&f2s_row_splits, nullptr, -1); + + // Compute s2a_shape: state_to_arc_shape. int32_t num_states = f2s_shape.NumElements(); Array1 s2c_row_splits(context, num_states + 1); int32_t *s2c_row_splits_data = s2c_row_splits.Data(); - const int32_t *fts_row_splits1_data = f2s_shape.RowSplits(1).Data(), - *fts_row_ids1_data = f2s_shape.RowIds(1).Data(); - - // set the arcs number for each state + const int32_t *f2s_row_splits1_data = f2s_shape.RowSplits(1).Data(), + *f2s_row_ids1_data = f2s_shape.RowIds(1).Data(); + // Compute number of arcs for each state. K2_EVAL( context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { - int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], - state_idx0x = fts_row_splits1_data[fsa_idx0], + int32_t fsa_idx0 = f2s_row_ids1_data[state_idx01], + state_idx0x = f2s_row_splits1_data[fsa_idx0], state_idx1 = state_idx01 - state_idx0x, t = state_idx1 / s_range, - token_index = state_idx1 % s_range; + token_idx = state_idx1 % s_range; K2_CHECK_LE(t, frames_a[fsa_idx0]); - if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { - // frames[fsa_idx0] * s_range is the state_idx1 of super final-state. - // frames[fsa_idx0] * s_range - 1 is the state pointing to super final-state. - // final arc to super final state. - s2c_row_splits_data[state_idx01] = 1; - return; - } + + // The state doesn't have leaving arc: super final_state. if (state_idx1 == frames_a[fsa_idx0] * s_range) { - // frames[fsa_idx0] * U is the state_idx1 of super final-state. - // final state has no leaving arcs. s2c_row_splits_data[state_idx01] = 0; return; } + // States have a leaving arc if no specially processed. + s2c_row_splits_data[state_idx01] = 1; + + // States have two leaving arcs. bool has_horizontal_blank_arc = false; if (t < frames_a[fsa_idx0] - 1) { - has_horizontal_blank_arc = ranges_a[fsa_idx0][t][token_index] >= ranges_a[fsa_idx0][t + 1][0]; + has_horizontal_blank_arc = + ranges_a[fsa_idx0][t][token_idx] >= ranges_a[fsa_idx0][t + 1][0]; } - if (token_index < s_range - 1) { - // Typically, s_range == 5, - // the [0, 1, 2, 3] states for each time step may have a vertial arc plus an optional horizontal blank arc. - s2c_row_splits_data[state_idx01] = 1; - if (has_horizontal_blank_arc) { + // Typically, s_range == 5, i.e. 5 states for each time step. + // the 4th state(0-based index) ONLY has a horizontal blank arc. + // While state [0, 1, 2, 3] have a vertial arc + // and MAYBE a horizontal blank arc. + if (token_idx != s_range - 1 && has_horizontal_blank_arc) { s2c_row_splits_data[state_idx01] = 2; - } - } else { - // Typically, s_range == 5, - // the [4] state for each time step have and only have a horizontal blank arc. - s2c_row_splits_data[state_idx01] = 1; } - }); - + }); ExclusiveSum(s2c_row_splits, &s2c_row_splits); RaggedShape s2a_shape = RaggedShape2(&s2c_row_splits, nullptr, -1); + // ofsa_shape: output_fsa_shape. RaggedShape ofsa_shape = ComposeRaggedShapes(f2s_shape, s2a_shape); + int32_t num_arcs = ofsa_shape.NumElements(); Array1 arcs(context, num_arcs); - Array1 out_map(context, num_arcs); - int32_t* out_map_data = out_map.Data(); + Arc *arcs_data = arcs.Data(); const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(), *row_ids1_data = ofsa_shape.RowIds(1).Data(), *row_splits2_data = ofsa_shape.RowSplits(2).Data(), *row_ids2_data = ofsa_shape.RowIds(2).Data(); - int32_t symbols_stride_0 = symbols.stride(0), - symbols_stride_1 = symbols.stride(1); - K2_EVAL( - context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { - int32_t state_idx01 = row_ids2_data[arc_idx012], - fsa_idx0 = row_ids1_data[state_idx01], - state_idx0x = row_splits1_data[fsa_idx0], - arc_idx01x = row_splits2_data[state_idx01], - state_idx1 = state_idx01 - state_idx0x, - arc_idx2 = arc_idx012 - arc_idx01x, - t = state_idx1 / s_range, - token_index = state_idx1 % s_range; // token_index is belong to [0, U) - Arc arc; - if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { - arc.src_state = state_idx1; - arc.dest_state = state_idx1 + 1; - arc.label = -1; - arc.score = 0.0; - arcs_data[arc_idx012] = arc; - out_map_data[arc_idx012] = -1; - return; - } - int32_t next_state_idx1, logits_offset; - arc.src_state = state_idx1; - // arc. - if (token_index < s_range - 1) { - int32_t actual_u = ranges_a[fsa_idx0][t][token_index]; - int32_t symbols_offset = fsa_idx0 * symbols_stride_0 + actual_u * symbols_stride_1; - int32_t arc_label = symbols_data[symbols_offset]; - switch (arc_idx2) { - case 0: + Array1 out_map(context, num_arcs); + int32_t* out_map_data = out_map.Data(); + // Used to populate out_map. + const int32_t lg_stride_0 = logits.stride(0), + lg_stride_1 = logits.stride(1), + lg_stride_2 = logits.stride(2), + lg_stride_3 = logits.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { + auto logits_a = logits.accessor(); + + K2_EVAL( + context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { + const int32_t state_idx01 = row_ids2_data[arc_idx012], + fsa_idx0 = row_ids1_data[state_idx01], + state_idx0x = row_splits1_data[fsa_idx0], + arc_idx01x = row_splits2_data[state_idx01], + state_idx1 = state_idx01 - state_idx0x, + arc_idx2 = arc_idx012 - arc_idx01x, + t = state_idx1 / s_range, + // token_idx lies within interval [0, s_range) + // but does not include s_range. + token_idx = state_idx1 % s_range; + Arc arc; + arc.src_state = state_idx1; + // The penultimate state only has a leaving arc + // pointing to super final state. + if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { + arc.src_state = state_idx1; arc.dest_state = state_idx1 + 1; - arc.label = arc_label; - arc.score = logits_a[fsa_idx0][t][token_index][arc_label]; - - logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2 + arc_label * lg_stride_3; - out_map_data[arc_idx012] = logits_offset; - break; - case 1: - next_state_idx1 = ranges_a[fsa_idx0][t][token_index] - ranges_a[fsa_idx0][t + 1][0]; - arc.dest_state = next_state_idx1 + (t + 1) * s_range; + arc.label = -1; + arc.score = 0.0; + arcs_data[arc_idx012] = arc; + out_map_data[arc_idx012] = -1; + return; + } + + if (token_idx < s_range - 1) { + // States have a vertal arc with non-blank label and + // MAYBE a horizontal arc with blank label. + const int32_t symbol_idx = ranges_a[fsa_idx0][t][token_idx], + arc_label = symbols_a[fsa_idx0][symbol_idx]; + K2_CHECK_LE(arc_idx2, 2); + switch (arc_idx2) { + // For vertial arc with non-blank label. + case 0: + arc.dest_state = state_idx1 + 1; + arc.label = arc_label; + arc.score = logits_a[fsa_idx0][t][token_idx][arc_label]; + + out_map_data[arc_idx012] = + fsa_idx0 * lg_stride_0 + t * lg_stride_1 + + token_idx * lg_stride_2 + arc_label * lg_stride_3; + break; + // For horizontal arc with blank label. + case 1: + const int32_t dest_state_token_idx = + ranges_a[fsa_idx0][t][token_idx] - + ranges_a[fsa_idx0][t + 1][0]; + K2_CHECK_GE(dest_state_token_idx, 0); + arc.dest_state = dest_state_token_idx + (t + 1) * s_range; + arc.label = 0; + arc.score = logits_a[fsa_idx0][t][token_idx][0]; + + out_map_data[arc_idx012] = + fsa_idx0 * lg_stride_0 + t * lg_stride_1 + + token_idx * lg_stride_2; + break; + } + } else { + // States only have a horizontal arc with blank label. + K2_CHECK_EQ(arc_idx2, 0); + const int32_t dest_state_token_idx = + ranges_a[fsa_idx0][t][token_idx] - + ranges_a[fsa_idx0][t + 1][0]; + arc.dest_state = + dest_state_token_idx + (t + 1) * s_range; arc.label = 0; - arc.score = logits_a[fsa_idx0][t][token_index][0]; - - logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; - out_map_data[arc_idx012] = logits_offset; - break; - default: - K2_LOG(FATAL) << "Arc index must be less than 3"; - } - } else { - K2_CHECK_EQ(arc_idx2, 0); - next_state_idx1 = ranges_a[fsa_idx0][t][token_index] - ranges_a[fsa_idx0][t + 1][0]; - arc.dest_state = next_state_idx1 + (t + 1) * s_range; - arc.label = 0; - arc.score = logits_a[fsa_idx0][t][token_index][0]; - - logits_offset = fsa_idx0 * lg_stride_0 + t * lg_stride_1 + token_index * lg_stride_2; - out_map_data[arc_idx012] = logits_offset; - } - arcs_data[arc_idx012] = arc; - }); - *arc_map = std::move(out_map); + arc.score = logits_a[fsa_idx0][t][token_idx][0]; + + out_map_data[arc_idx012] = + fsa_idx0 * lg_stride_0 + t * lg_stride_1 + + token_idx * lg_stride_2; + } + arcs_data[arc_idx012] = arc; + }); + *arc_map = std::move(out_map); + })); + return Ragged(ofsa_shape, arcs); } From 22189049d2c07f6058c59d389999deb8c6299e49 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 10:41:02 +0800 Subject: [PATCH 12/18] move to python/csrc/torch --- k2/csrc/CMakeLists.txt | 1 - k2/csrc/pruned_ranges_to_lattice.cu | 232 ------------------ k2/csrc/pruned_ranges_to_lattice.h | 41 ---- .../csrc/torch/pruned_ranges_to_lattice.cu | 208 +++++++++++++++- .../csrc/torch/pruned_ranges_to_lattice.h | 11 + 5 files changed, 218 insertions(+), 275 deletions(-) delete mode 100644 k2/csrc/pruned_ranges_to_lattice.cu delete mode 100644 k2/csrc/pruned_ranges_to_lattice.h diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 73b75eb4d..736668e9b 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -74,7 +74,6 @@ set(context_srcs reverse.cu rm_epsilon.cu rnnt_decode.cu - pruned_ranges_to_lattice.cu tensor.cu tensor_ops.cu thread_pool.cu diff --git a/k2/csrc/pruned_ranges_to_lattice.cu b/k2/csrc/pruned_ranges_to_lattice.cu deleted file mode 100644 index 480b61ff0..000000000 --- a/k2/csrc/pruned_ranges_to_lattice.cu +++ /dev/null @@ -1,232 +0,0 @@ -/** - * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include "k2/csrc/pruned_ranges_to_lattice.h" -#include "k2/python/csrc/torch.h" - -namespace k2 { - -FsaVec PrunedRangesToLattice( - torch::Tensor ranges, // [B][T][s_range] - torch::Tensor frames, // [B] - torch::Tensor symbols, // [B][S] - torch::Tensor logits, // [B][S][s_range][C] - Array1 *arc_map) { - - TORCH_CHECK(ranges.get_device() == frames.get_device()); - TORCH_CHECK(ranges.get_device() == symbols.get_device()); - TORCH_CHECK(ranges.get_device() == logits.get_device()); - - TORCH_CHECK(ranges.dim() == 3, "ranges should be 3-dimensional"); - TORCH_CHECK(frames.dim() == 1, "frames should be 1-dimensional"); - TORCH_CHECK(symbols.dim() == 2, "symbols should be 2-dimensional"); - TORCH_CHECK(logits.dim() == 4, "logits should be 4-dimensional"); - - TORCH_CHECK(torch::kInt == ranges.scalar_type()); - TORCH_CHECK(torch::kInt == frames.scalar_type()); - TORCH_CHECK(torch::kInt == symbols.scalar_type()); - - ContextPtr context; - if (ranges.device().type() == torch::kCPU) { - context = GetCpuContext(); - } else if (ranges.is_cuda()) { - context = GetCudaContext(ranges.device().index()); - } else { - K2_LOG(FATAL) << "Unsupported device: " << ranges.device() - << "\nOnly CPU and CUDA are verified"; - } - - // "_a" is short for accessor. - auto ranges_a = ranges.accessor(); - auto frames_a = frames.accessor(); - auto symbols_a = symbols.accessor(); - - // Typically, s_range is 5. - const int32_t B = ranges.size(0), - T = ranges.size(1), - s_range = ranges.size(2); - - // Compute f2s_shape: fsa_to_state_shape. - Array1 f2s_row_splits(context, B + 1); - int32_t * f2s_row_splits_data = f2s_row_splits.Data(); - K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { - int32_t t = frames_a[fsa_idx0]; - K2_CHECK_LE(t, T); - // + 1 in "t * s_range + 1" is for super-final state. - f2s_row_splits_data[fsa_idx0] = t * s_range + 1; - }); - - ExclusiveSum(f2s_row_splits, &f2s_row_splits); - RaggedShape f2s_shape = - RaggedShape2(&f2s_row_splits, nullptr, -1); - - // Compute s2a_shape: state_to_arc_shape. - int32_t num_states = f2s_shape.NumElements(); - Array1 s2c_row_splits(context, num_states + 1); - int32_t *s2c_row_splits_data = s2c_row_splits.Data(); - const int32_t *f2s_row_splits1_data = f2s_shape.RowSplits(1).Data(), - *f2s_row_ids1_data = f2s_shape.RowIds(1).Data(); - // Compute number of arcs for each state. - K2_EVAL( - context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { - int32_t fsa_idx0 = f2s_row_ids1_data[state_idx01], - state_idx0x = f2s_row_splits1_data[fsa_idx0], - state_idx1 = state_idx01 - state_idx0x, - t = state_idx1 / s_range, - token_idx = state_idx1 % s_range; - - K2_CHECK_LE(t, frames_a[fsa_idx0]); - - // The state doesn't have leaving arc: super final_state. - if (state_idx1 == frames_a[fsa_idx0] * s_range) { - s2c_row_splits_data[state_idx01] = 0; - return; - } - - // States have a leaving arc if no specially processed. - s2c_row_splits_data[state_idx01] = 1; - - // States have two leaving arcs. - bool has_horizontal_blank_arc = false; - if (t < frames_a[fsa_idx0] - 1) { - has_horizontal_blank_arc = - ranges_a[fsa_idx0][t][token_idx] >= ranges_a[fsa_idx0][t + 1][0]; - } - // Typically, s_range == 5, i.e. 5 states for each time step. - // the 4th state(0-based index) ONLY has a horizontal blank arc. - // While state [0, 1, 2, 3] have a vertial arc - // and MAYBE a horizontal blank arc. - if (token_idx != s_range - 1 && has_horizontal_blank_arc) { - s2c_row_splits_data[state_idx01] = 2; - } - }); - ExclusiveSum(s2c_row_splits, &s2c_row_splits); - RaggedShape s2a_shape = - RaggedShape2(&s2c_row_splits, nullptr, -1); - - // ofsa_shape: output_fsa_shape. - RaggedShape ofsa_shape = ComposeRaggedShapes(f2s_shape, s2a_shape); - - int32_t num_arcs = ofsa_shape.NumElements(); - Array1 arcs(context, num_arcs); - - Arc *arcs_data = arcs.Data(); - const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(), - *row_ids1_data = ofsa_shape.RowIds(1).Data(), - *row_splits2_data = ofsa_shape.RowSplits(2).Data(), - *row_ids2_data = ofsa_shape.RowIds(2).Data(); - - Array1 out_map(context, num_arcs); - int32_t* out_map_data = out_map.Data(); - // Used to populate out_map. - const int32_t lg_stride_0 = logits.stride(0), - lg_stride_1 = logits.stride(1), - lg_stride_2 = logits.stride(2), - lg_stride_3 = logits.stride(3); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { - auto logits_a = logits.accessor(); - - K2_EVAL( - context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { - const int32_t state_idx01 = row_ids2_data[arc_idx012], - fsa_idx0 = row_ids1_data[state_idx01], - state_idx0x = row_splits1_data[fsa_idx0], - arc_idx01x = row_splits2_data[state_idx01], - state_idx1 = state_idx01 - state_idx0x, - arc_idx2 = arc_idx012 - arc_idx01x, - t = state_idx1 / s_range, - // token_idx lies within interval [0, s_range) - // but does not include s_range. - token_idx = state_idx1 % s_range; - Arc arc; - arc.src_state = state_idx1; - // The penultimate state only has a leaving arc - // pointing to super final state. - if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { - arc.src_state = state_idx1; - arc.dest_state = state_idx1 + 1; - arc.label = -1; - arc.score = 0.0; - arcs_data[arc_idx012] = arc; - out_map_data[arc_idx012] = -1; - return; - } - - if (token_idx < s_range - 1) { - // States have a vertal arc with non-blank label and - // MAYBE a horizontal arc with blank label. - const int32_t symbol_idx = ranges_a[fsa_idx0][t][token_idx], - arc_label = symbols_a[fsa_idx0][symbol_idx]; - K2_CHECK_LE(arc_idx2, 2); - switch (arc_idx2) { - // For vertial arc with non-blank label. - case 0: - arc.dest_state = state_idx1 + 1; - arc.label = arc_label; - arc.score = logits_a[fsa_idx0][t][token_idx][arc_label]; - - out_map_data[arc_idx012] = - fsa_idx0 * lg_stride_0 + t * lg_stride_1 + - token_idx * lg_stride_2 + arc_label * lg_stride_3; - break; - // For horizontal arc with blank label. - case 1: - const int32_t dest_state_token_idx = - ranges_a[fsa_idx0][t][token_idx] - - ranges_a[fsa_idx0][t + 1][0]; - K2_CHECK_GE(dest_state_token_idx, 0); - arc.dest_state = dest_state_token_idx + (t + 1) * s_range; - arc.label = 0; - arc.score = logits_a[fsa_idx0][t][token_idx][0]; - - out_map_data[arc_idx012] = - fsa_idx0 * lg_stride_0 + t * lg_stride_1 + - token_idx * lg_stride_2; - break; - } - } else { - // States only have a horizontal arc with blank label. - K2_CHECK_EQ(arc_idx2, 0); - const int32_t dest_state_token_idx = - ranges_a[fsa_idx0][t][token_idx] - - ranges_a[fsa_idx0][t + 1][0]; - arc.dest_state = - dest_state_token_idx + (t + 1) * s_range; - arc.label = 0; - arc.score = logits_a[fsa_idx0][t][token_idx][0]; - - out_map_data[arc_idx012] = - fsa_idx0 * lg_stride_0 + t * lg_stride_1 + - token_idx * lg_stride_2; - } - arcs_data[arc_idx012] = arc; - }); - *arc_map = std::move(out_map); - })); - - return Ragged(ofsa_shape, arcs); -} - -} // namespace k2 diff --git a/k2/csrc/pruned_ranges_to_lattice.h b/k2/csrc/pruned_ranges_to_lattice.h deleted file mode 100644 index c48499e27..000000000 --- a/k2/csrc/pruned_ranges_to_lattice.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * @copyright - * Copyright 2022 Xiaomi Corporation (authors: Liyong Guo) - * - * @copyright - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef K2_CSRC_PRUNED_RANGES_TO_LATTICE_H_ -#define K2_CSRC_PRUNED_RANGES_TO_LATTICE_H_ - -#include - -#include - -#include "k2/python/csrc/torch.h" - -namespace k2 { - -FsaVec PrunedRangesToLattice( - torch::Tensor ranges, // [B][T][s_range] - torch::Tensor frames, // [B][T] - torch::Tensor symbols, // [B][S] - torch::Tensor logits, // [B][T][s_range][C] - Array1 *arc_map); - -} // namespace k2 - -#endif // K2_CSRC_PRUNED_RANGES_TO_LATTICE_H_ diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu index 53c3ed02e..d0b7431ba 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu @@ -21,11 +21,217 @@ #include "k2/csrc/device_guard.h" #include "k2/csrc/fsa.h" #include "k2/csrc/torch_util.h" -#include "k2/csrc/pruned_ranges_to_lattice.h" #include "k2/python/csrc/torch/pruned_ranges_to_lattice.h" #include "k2/python/csrc/torch/v2/ragged_any.h" +namespace k2 { + +FsaVec PrunedRangesToLattice( + torch::Tensor ranges, // [B][T][s_range] + torch::Tensor frames, // [B] + torch::Tensor symbols, // [B][S] + torch::Tensor logits, // [B][S][s_range][C] + Array1 *arc_map) { + + TORCH_CHECK(ranges.get_device() == frames.get_device()); + TORCH_CHECK(ranges.get_device() == symbols.get_device()); + TORCH_CHECK(ranges.get_device() == logits.get_device()); + + TORCH_CHECK(ranges.dim() == 3, "ranges should be 3-dimensional"); + TORCH_CHECK(frames.dim() == 1, "frames should be 1-dimensional"); + TORCH_CHECK(symbols.dim() == 2, "symbols should be 2-dimensional"); + TORCH_CHECK(logits.dim() == 4, "logits should be 4-dimensional"); + + TORCH_CHECK(torch::kInt == ranges.scalar_type()); + TORCH_CHECK(torch::kInt == frames.scalar_type()); + TORCH_CHECK(torch::kInt == symbols.scalar_type()); + + ContextPtr context; + if (ranges.device().type() == torch::kCPU) { + context = GetCpuContext(); + } else if (ranges.is_cuda()) { + context = GetCudaContext(ranges.device().index()); + } else { + K2_LOG(FATAL) << "Unsupported device: " << ranges.device() + << "\nOnly CPU and CUDA are verified"; + } + + // "_a" is short for accessor. + auto ranges_a = ranges.accessor(); + auto frames_a = frames.accessor(); + auto symbols_a = symbols.accessor(); + + // Typically, s_range is 5. + const int32_t B = ranges.size(0), + T = ranges.size(1), + s_range = ranges.size(2); + + // Compute f2s_shape: fsa_to_state_shape. + Array1 f2s_row_splits(context, B + 1); + int32_t * f2s_row_splits_data = f2s_row_splits.Data(); + K2_EVAL(context, B, lambda_set_f2s_row_splits, (int32_t fsa_idx0) { + int32_t t = frames_a[fsa_idx0]; + K2_CHECK_LE(t, T); + // + 1 in "t * s_range + 1" is for super-final state. + f2s_row_splits_data[fsa_idx0] = t * s_range + 1; + }); + + ExclusiveSum(f2s_row_splits, &f2s_row_splits); + RaggedShape f2s_shape = + RaggedShape2(&f2s_row_splits, nullptr, -1); + + // Compute s2a_shape: state_to_arc_shape. + int32_t num_states = f2s_shape.NumElements(); + Array1 s2c_row_splits(context, num_states + 1); + int32_t *s2c_row_splits_data = s2c_row_splits.Data(); + const int32_t *f2s_row_splits1_data = f2s_shape.RowSplits(1).Data(), + *f2s_row_ids1_data = f2s_shape.RowIds(1).Data(); + // Compute number of arcs for each state. + K2_EVAL( + context, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { + int32_t fsa_idx0 = f2s_row_ids1_data[state_idx01], + state_idx0x = f2s_row_splits1_data[fsa_idx0], + state_idx1 = state_idx01 - state_idx0x, + t = state_idx1 / s_range, + token_idx = state_idx1 % s_range; + + K2_CHECK_LE(t, frames_a[fsa_idx0]); + + // The state doesn't have leaving arc: super final_state. + if (state_idx1 == frames_a[fsa_idx0] * s_range) { + s2c_row_splits_data[state_idx01] = 0; + return; + } + + // States have a leaving arc if no specially processed. + s2c_row_splits_data[state_idx01] = 1; + + // States have two leaving arcs. + bool has_horizontal_blank_arc = false; + if (t < frames_a[fsa_idx0] - 1) { + has_horizontal_blank_arc = + ranges_a[fsa_idx0][t][token_idx] >= ranges_a[fsa_idx0][t + 1][0]; + } + // Typically, s_range == 5, i.e. 5 states for each time step. + // the 4th state(0-based index) ONLY has a horizontal blank arc. + // While state [0, 1, 2, 3] have a vertial arc + // and MAYBE a horizontal blank arc. + if (token_idx != s_range - 1 && has_horizontal_blank_arc) { + s2c_row_splits_data[state_idx01] = 2; + } + }); + ExclusiveSum(s2c_row_splits, &s2c_row_splits); + RaggedShape s2a_shape = + RaggedShape2(&s2c_row_splits, nullptr, -1); + + // ofsa_shape: output_fsa_shape. + RaggedShape ofsa_shape = ComposeRaggedShapes(f2s_shape, s2a_shape); + + int32_t num_arcs = ofsa_shape.NumElements(); + Array1 arcs(context, num_arcs); + + Arc *arcs_data = arcs.Data(); + const int32_t *row_splits1_data = ofsa_shape.RowSplits(1).Data(), + *row_ids1_data = ofsa_shape.RowIds(1).Data(), + *row_splits2_data = ofsa_shape.RowSplits(2).Data(), + *row_ids2_data = ofsa_shape.RowIds(2).Data(); + + Array1 out_map(context, num_arcs); + int32_t* out_map_data = out_map.Data(); + // Used to populate out_map. + const int32_t lg_stride_0 = logits.stride(0), + lg_stride_1 = logits.stride(1), + lg_stride_2 = logits.stride(2), + lg_stride_3 = logits.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { + auto logits_a = logits.accessor(); + + K2_EVAL( + context, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { + const int32_t state_idx01 = row_ids2_data[arc_idx012], + fsa_idx0 = row_ids1_data[state_idx01], + state_idx0x = row_splits1_data[fsa_idx0], + arc_idx01x = row_splits2_data[state_idx01], + state_idx1 = state_idx01 - state_idx0x, + arc_idx2 = arc_idx012 - arc_idx01x, + t = state_idx1 / s_range, + // token_idx lies within interval [0, s_range) + // but does not include s_range. + token_idx = state_idx1 % s_range; + Arc arc; + arc.src_state = state_idx1; + // The penultimate state only has a leaving arc + // pointing to super final state. + if (state_idx1 == frames_a[fsa_idx0] * s_range - 1) { + arc.src_state = state_idx1; + arc.dest_state = state_idx1 + 1; + arc.label = -1; + arc.score = 0.0; + arcs_data[arc_idx012] = arc; + out_map_data[arc_idx012] = -1; + return; + } + + if (token_idx < s_range - 1) { + // States have a vertal arc with non-blank label and + // MAYBE a horizontal arc with blank label. + const int32_t symbol_idx = ranges_a[fsa_idx0][t][token_idx], + arc_label = symbols_a[fsa_idx0][symbol_idx]; + K2_CHECK_LE(arc_idx2, 2); + switch (arc_idx2) { + // For vertial arc with non-blank label. + case 0: + arc.dest_state = state_idx1 + 1; + arc.label = arc_label; + arc.score = logits_a[fsa_idx0][t][token_idx][arc_label]; + + out_map_data[arc_idx012] = + fsa_idx0 * lg_stride_0 + t * lg_stride_1 + + token_idx * lg_stride_2 + arc_label * lg_stride_3; + break; + // For horizontal arc with blank label. + case 1: + const int32_t dest_state_token_idx = + ranges_a[fsa_idx0][t][token_idx] - + ranges_a[fsa_idx0][t + 1][0]; + K2_CHECK_GE(dest_state_token_idx, 0); + arc.dest_state = dest_state_token_idx + (t + 1) * s_range; + arc.label = 0; + arc.score = logits_a[fsa_idx0][t][token_idx][0]; + + out_map_data[arc_idx012] = + fsa_idx0 * lg_stride_0 + t * lg_stride_1 + + token_idx * lg_stride_2; + break; + } + } else { + // States only have a horizontal arc with blank label. + K2_CHECK_EQ(arc_idx2, 0); + const int32_t dest_state_token_idx = + ranges_a[fsa_idx0][t][token_idx] - + ranges_a[fsa_idx0][t + 1][0]; + arc.dest_state = + dest_state_token_idx + (t + 1) * s_range; + arc.label = 0; + arc.score = logits_a[fsa_idx0][t][token_idx][0]; + + out_map_data[arc_idx012] = + fsa_idx0 * lg_stride_0 + t * lg_stride_1 + + token_idx * lg_stride_2; + } + arcs_data[arc_idx012] = arc; + }); + *arc_map = std::move(out_map); + })); + + return Ragged(ofsa_shape, arcs); +} + +} // namespace k2 + void PybindPrunedRangesToLattice(py::module &m) { m.def( "pruned_ranges_to_lattice", diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.h b/k2/python/csrc/torch/pruned_ranges_to_lattice.h index b64f59bf0..2a0c3430b 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.h +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.h @@ -23,6 +23,17 @@ #include "k2/python/csrc/torch.h" +namespace k2 { + +FsaVec PrunedRangesToLattice( + torch::Tensor ranges, // [B][T][s_range] + torch::Tensor frames, // [B][T] + torch::Tensor symbols, // [B][S] + torch::Tensor logits, // [B][T][s_range][C] + Array1 *arc_map); + +} // namespace k2 + void PybindPrunedRangesToLattice(py::module &m); #endif // K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ From 44e192762a23dfd8431ebab96ed17370f870e4ef Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 12:15:58 +0800 Subject: [PATCH 13/18] comments of PrunedRangesToLattice --- k2/python/csrc/torch/CMakeLists.txt | 2 +- .../csrc/torch/pruned_ranges_to_lattice.cu | 22 +++++++------ .../csrc/torch/pruned_ranges_to_lattice.h | 31 ++++++++++++++++--- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 788db692f..73b93e6cc 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -1,6 +1,5 @@ # please keep the list sorted set(torch_srcs - pruned_ranges_to_lattice.cu arc.cu fsa.cu fsa_algo.cu @@ -9,6 +8,7 @@ set(torch_srcs mutual_information.cu mutual_information_cpu.cu nbest.cu + pruned_ranges_to_lattice.cu ragged.cu ragged_ops.cu rnnt_decode.cu diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu index d0b7431ba..a33e781e7 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu @@ -18,6 +18,8 @@ * limitations under the License. */ +#include + #include "k2/csrc/device_guard.h" #include "k2/csrc/fsa.h" #include "k2/csrc/torch_util.h" @@ -45,7 +47,7 @@ FsaVec PrunedRangesToLattice( TORCH_CHECK(torch::kInt == ranges.scalar_type()); TORCH_CHECK(torch::kInt == frames.scalar_type()); - TORCH_CHECK(torch::kInt == symbols.scalar_type()); + TORCH_CHECK(torch::kLong == symbols.scalar_type()); ContextPtr context; if (ranges.device().type() == torch::kCPU) { @@ -60,7 +62,7 @@ FsaVec PrunedRangesToLattice( // "_a" is short for accessor. auto ranges_a = ranges.accessor(); auto frames_a = frames.accessor(); - auto symbols_a = symbols.accessor(); + auto symbols_a = symbols.accessor(); // Typically, s_range is 5. const int32_t B = ranges.size(0), @@ -235,14 +237,16 @@ FsaVec PrunedRangesToLattice( void PybindPrunedRangesToLattice(py::module &m) { m.def( "pruned_ranges_to_lattice", - [](torch::Tensor ranges, torch::Tensor x_lens, - torch::Tensor y, + [](torch::Tensor ranges, + torch::Tensor frames, + torch::Tensor symbols, torch::Tensor logits) -> std::pair { k2::DeviceGuard guard(k2::GetContext(ranges)); - k2::Array1 label_map; - k2::FsaVec ofsa = k2::PrunedRangesToLattice(ranges, x_lens, y, logits, &label_map); - torch::Tensor tensor = ToTorch(label_map); - return std::make_pair(ofsa, tensor); + k2::Array1 arc_to_logit_map; + k2::FsaVec ofsa = k2::PrunedRangesToLattice( + ranges, frames, symbols, logits, &arc_to_logit_map); + return std::make_pair(ofsa, ToTorch(arc_to_logit_map)); }, - py::arg("ranges"), py::arg("x_lens"), py::arg("y"), py::arg("logits")); + py::arg("ranges"), py::arg("frames"), + py::arg("symbols"), py::arg("logits")); } diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.h b/k2/python/csrc/torch/pruned_ranges_to_lattice.h index 2a0c3430b..e66c7243e 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.h +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.h @@ -18,16 +18,39 @@ * limitations under the License. */ -#ifndef K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ -#define K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ +#ifndef K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_ +#define K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_ #include "k2/python/csrc/torch.h" namespace k2 { +/* + Convert pruned ranges to lattice while also supporting autograd. + + The input pruned ranges is normally generated by `get_rnnt_prune_ranges`. + See k2/python/k2/rnnt_loss.py for the process of generating ranges and + the information it represents. + + When this is implemented, the lattice is used to generate force-alignment. + Perhaps you could find other uses for this function. + + @param ranges A tensor containing the symbol indexes for each frame. + Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges` + in k2/python/k2/rnnt_loss.py for more details of this tensor. + @param frames The number of frames per sample with shape (B). + @param symbols The symbol sequence, a LongTensor of shape (B, S), + and elements in {0..C-1}. + @param logits The pruned joiner network (or am/lm) + of shape (B, T, s_range, C) + @param [out] arc_map A map from arcs in generated lattice to global index + of logits, or -1 if the arc had no corresponding score in logits, + e.g. arc pointing to super final state. + @return Return an FsaVec, which contains the generated lattice. +*/ FsaVec PrunedRangesToLattice( torch::Tensor ranges, // [B][T][s_range] - torch::Tensor frames, // [B][T] + torch::Tensor frames, // [B] torch::Tensor symbols, // [B][S] torch::Tensor logits, // [B][T][s_range][C] Array1 *arc_map); @@ -36,4 +59,4 @@ FsaVec PrunedRangesToLattice( void PybindPrunedRangesToLattice(py::module &m); -#endif // K2_PYTHON_CSRC_TORCH_PRUNE_RANGE_TO_LATTICE_H_ +#endif // K2_PYTHON_CSRC_TORCH_PRUNED_RANGES_TO_LATTICE_H_ From 15e645b403b341514316741c0e686ff7fe6d2c5b Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 15:16:53 +0800 Subject: [PATCH 14/18] unittest of pruned_ranges_to_lattice --- .../csrc/torch/pruned_ranges_to_lattice.h | 7 +- .../tests/pruned_ranges_to_lattice_test.py | 192 ++++++++++++++++++ 2 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 k2/python/tests/pruned_ranges_to_lattice_test.py diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.h b/k2/python/csrc/torch/pruned_ranges_to_lattice.h index e66c7243e..c34ae0f13 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.h +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.h @@ -38,11 +38,16 @@ namespace k2 { @param ranges A tensor containing the symbol indexes for each frame. Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges` in k2/python/k2/rnnt_loss.py for more details of this tensor. + Its type is int32, consistent with that in rnnt_loss.py. @param frames The number of frames per sample with shape (B). + Its type is int32. @param symbols The symbol sequence, a LongTensor of shape (B, S), and elements in {0..C-1}. + Its type is int64(Long), consistent with that in rnnt_loss.py. @param logits The pruned joiner network (or am/lm) - of shape (B, T, s_range, C) + of shape (B, T, s_range, C). + Its type can be float32, float64, float16. Though float32 is mainly + used. float64 and float16 are also supported for future use. @param [out] arc_map A map from arcs in generated lattice to global index of logits, or -1 if the arc had no corresponding score in logits, e.g. arc pointing to super final state. diff --git a/k2/python/tests/pruned_ranges_to_lattice_test.py b/k2/python/tests/pruned_ranges_to_lattice_test.py new file mode 100644 index 000000000..bc6eb9a96 --- /dev/null +++ b/k2/python/tests/pruned_ranges_to_lattice_test.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (authors: Liyong Guo) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R pruned_ranges_to_lattice_test_py + +import unittest + +import k2 +import torch + + +class TestPrunedRangesToLattice(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.devices = [torch.device("cpu")] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device("cuda", 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device("cuda", 1)) + cls.float_dtypes = [torch.float32, torch.float64, torch.float16] + + def _common_test_part(self, ranges, frames, symbols, logits): + ofsa, arc_map = k2.pruned_ranges_to_lattice(ranges, frames, symbols, logits) + scores_tracked_by_autograd = k2.index_select( + logits.reshape(-1).to(torch.float32), arc_map + ) + lattice = k2.Fsa(ofsa) + assert torch.allclose( + lattice.scores.to(torch.float32), scores_tracked_by_autograd + ) + assert torch.allclose( + lattice.scores.to(torch.float32), + torch.tensor( + [ + 10.8000, + 11.7000, + 11.0500, + 12.6000, + 12.0500, + 13.5000, + 13.0500, + 14.0500, + 20.7000, + 21.6000, + 21.0500, + 22.5000, + 22.0500, + 23.4000, + 23.0500, + 24.0500, + 30.6000, + 31.5000, + 31.0500, + 32.4000, + 32.0500, + 33.3000, + 33.0500, + 34.0500, + 40.5000, + 41.4000, + 42.3000, + 43.2000, + 0.0000, + 50.2000, + 50.0500, + 51.3000, + 51.0500, + 52.4000, + 52.0500, + 53.5000, + 53.0500, + 54.0500, + 60.2000, + 61.3000, + 62.4000, + 63.5000, + 63.0500, + 64.0500, + 70.5000, + 71.6000, + 72.7000, + 73.8000, + 0.0000, + ] + ), + ) + + assert torch.allclose( + lattice.arcs.values()[:, :3], + torch.tensor( + [ + [0, 1, 8], + [1, 2, 7], + [1, 5, 0], + [2, 3, 6], + [2, 6, 0], + [3, 4, 5], + [3, 7, 0], + [4, 8, 0], + [5, 6, 7], + [6, 7, 6], + [6, 10, 0], + [7, 8, 5], + [7, 11, 0], + [8, 9, 4], + [8, 12, 0], + [9, 13, 0], + [10, 11, 6], + [11, 12, 5], + [11, 15, 0], + [12, 13, 4], + [12, 16, 0], + [13, 14, 3], + [13, 17, 0], + [14, 18, 0], + [15, 16, 5], + [16, 17, 4], + [17, 18, 3], + [18, 19, 2], + [19, 20, -1], + [0, 1, 2], + [0, 5, 0], + [1, 2, 3], + [1, 6, 0], + [2, 3, 4], + [2, 7, 0], + [3, 4, 5], + [3, 8, 0], + [4, 9, 0], + [5, 6, 2], + [6, 7, 3], + [7, 8, 4], + [8, 9, 5], + [8, 10, 0], + [9, 11, 0], + [10, 11, 5], + [11, 12, 6], + [12, 13, 7], + [13, 14, 8], + [14, 15, -1], + ], + dtype=torch.int32, + ), + ) + + def test(self): + ranges = torch.tensor( + [ + [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]], + ], + dtype=torch.int32, + ) + B, T, s_range = ranges.size() + C = 9 + + frames = torch.tensor([4, 3], dtype=torch.int32) + symbols = torch.tensor( + [[8, 7, 6, 5, 4, 3, 2], [2, 3, 4, 5, 6, 7, 8]], dtype=torch.long + ) + logits = torch.tensor( + [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], dtype=torch.float32 + ).expand(B, T, s_range, C) + logits = logits + torch.tensor( + [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0] + ).reshape(B, T, 1, 1) + logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(1, 1, s_range, 1) + for dtype in self.float_dtypes: + tmp_logits = logits.to(dtype) + self._common_test_part(ranges, frames, symbols, logits) + + +if __name__ == "__main__": + unittest.main() From 00a5cbdbb53faa1ddb948d4f59bc5e36182b7813 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 15:25:24 +0800 Subject: [PATCH 15/18] add pruned_ranges_to_lattice_test into CMakeLists --- k2/python/tests/CMakeLists.txt | 1 + k2/python/tests/pruned_ranges_to_lattice_test.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 429af80b5..400ab930f 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -60,6 +60,7 @@ set(py_test_files nbest_test.py numerical_gradient_check_test.py online_dense_intersecter_test.py + pruned_ranges_to_lattice_test.py ragged_ops_test.py ragged_shape_test.py ragged_tensor_test.py diff --git a/k2/python/tests/pruned_ranges_to_lattice_test.py b/k2/python/tests/pruned_ranges_to_lattice_test.py index bc6eb9a96..60874436f 100644 --- a/k2/python/tests/pruned_ranges_to_lattice_test.py +++ b/k2/python/tests/pruned_ranges_to_lattice_test.py @@ -38,7 +38,7 @@ def setUpClass(cls): cls.float_dtypes = [torch.float32, torch.float64, torch.float16] def _common_test_part(self, ranges, frames, symbols, logits): - ofsa, arc_map = k2.pruned_ranges_to_lattice(ranges, frames, symbols, logits) + ofsa, arc_map = k2.pruned_ranges_to_lattice(ranges, frames, symbols, logits) # noqa scores_tracked_by_autograd = k2.index_select( logits.reshape(-1).to(torch.float32), arc_map ) @@ -164,8 +164,8 @@ def _common_test_part(self, ranges, frames, symbols, logits): def test(self): ranges = torch.tensor( [ - [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]], - [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]], # noqa + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]], # noqa ], dtype=torch.int32, ) @@ -182,10 +182,10 @@ def test(self): logits = logits + torch.tensor( [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0] ).reshape(B, T, 1, 1) - logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(1, 1, s_range, 1) + logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(1, 1, s_range, 1) # noqa for dtype in self.float_dtypes: tmp_logits = logits.to(dtype) - self._common_test_part(ranges, frames, symbols, logits) + self._common_test_part(ranges, frames, symbols, tmp_logits) if __name__ == "__main__": From 482915cc2dc4d8a1507900e6edf2959ca981bf23 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 16:06:41 +0800 Subject: [PATCH 16/18] remove whitespace --- k2/python/k2/fsa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py index 8b0067998..965ed5ce2 100644 --- a/k2/python/k2/fsa.py +++ b/k2/python/k2/fsa.py @@ -453,7 +453,7 @@ def properties(self) -> int: " fsa.labels = labels" ) return properties # Return cached properties. - + self.labels_version = self.labels._version if self.arcs.num_axes() == 2: properties = _k2.get_fsa_basic_properties(self.arcs) From 5012a281de56a981e779bc53c51d0c7d194635c5 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 17:00:24 +0800 Subject: [PATCH 17/18] more comments and test arc_map --- .../csrc/torch/pruned_ranges_to_lattice.cu | 30 ++++ .../csrc/torch/pruned_ranges_to_lattice.h | 2 +- .../tests/pruned_ranges_to_lattice_test.py | 140 ++++++++++-------- 3 files changed, 109 insertions(+), 63 deletions(-) diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu index a33e781e7..415be116a 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu @@ -29,6 +29,34 @@ namespace k2 { +/* + Convert pruned ranges to lattice while also supporting autograd. + + The input pruned ranges is normally generated by `get_rnnt_prune_ranges`. + See k2/python/k2/rnnt_loss.py for the process of generating ranges and + the information it represents. + + When this is implemented, the lattice is used to generate force-alignment. + Perhaps you could find other uses for this function. + + @param ranges A tensor containing the symbol indexes for each frame. + Its shape is (B, T, s_range). See the docs in `get_rnnt_prune_ranges` + in k2/python/k2/rnnt_loss.py for more details of this tensor. + Its type is int32, consistent with that in rnnt_loss.py. + @param frames The number of frames per sample with shape (B). + Its type is int32. + @param symbols The symbol sequence, a LongTensor of shape (B, S), + and elements in {0..C-1}. + Its type is int64(Long), consistent with that in rnnt_loss.py. + @param logits The pruned joiner network (or am/lm) + of shape (B, T, s_range, C). + Its type can be float32, float64, float16. Though float32 is mainly + used, float64 and float16 are also supported for future use. + @param [out] arc_map A map from arcs in generated lattice to global index + of logits, or -1 if the arc had no corresponding score in logits, + e.g. arc pointing to super final state. + @return Return an FsaVec, which contains the generated lattice. +*/ FsaVec PrunedRangesToLattice( torch::Tensor ranges, // [B][T][s_range] torch::Tensor frames, // [B] @@ -147,6 +175,8 @@ FsaVec PrunedRangesToLattice( lg_stride_2 = logits.stride(2), lg_stride_3 = logits.stride(3); + // Type of logits can be float32, float64, float16. Though float32 is mainly + // used, float64 and float16 are also supported for future use. AT_DISPATCH_FLOATING_TYPES_AND_HALF( logits.scalar_type(), "pruned_ranges_to_lattice", ([&] { auto logits_a = logits.accessor(); diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.h b/k2/python/csrc/torch/pruned_ranges_to_lattice.h index c34ae0f13..0e02c3529 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.h +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.h @@ -47,7 +47,7 @@ namespace k2 { @param logits The pruned joiner network (or am/lm) of shape (B, T, s_range, C). Its type can be float32, float64, float16. Though float32 is mainly - used. float64 and float16 are also supported for future use. + used, float64 and float16 are also supported for future use. @param [out] arc_map A map from arcs in generated lattice to global index of logits, or -1 if the arc had no corresponding score in logits, e.g. arc pointing to super final state. diff --git a/k2/python/tests/pruned_ranges_to_lattice_test.py b/k2/python/tests/pruned_ranges_to_lattice_test.py index 60874436f..395c21e4d 100644 --- a/k2/python/tests/pruned_ranges_to_lattice_test.py +++ b/k2/python/tests/pruned_ranges_to_lattice_test.py @@ -38,72 +38,78 @@ def setUpClass(cls): cls.float_dtypes = [torch.float32, torch.float64, torch.float16] def _common_test_part(self, ranges, frames, symbols, logits): - ofsa, arc_map = k2.pruned_ranges_to_lattice(ranges, frames, symbols, logits) # noqa + ofsa, arc_map = k2.pruned_ranges_to_lattice( + ranges, frames, symbols, logits + ) + + assert torch.equal( + arc_map, + torch.tensor( + [ + 8, + 16, + 9, + 24, + 18, + 32, + 27, + 36, + 52, + 60, + 54, + 68, + 63, + 76, + 72, + 81, + 96, + 104, + 99, + 112, + 108, + 120, + 117, + 126, + 140, + 148, + 156, + 164, + -1, + 182, + 180, + 192, + 189, + 202, + 198, + 212, + 207, + 216, + 227, + 237, + 247, + 257, + 252, + 261, + 275, + 285, + 295, + 305, + -1, + ], + dtype=torch.int32, + ) + ) + lattice = k2.Fsa(ofsa) + scores_tracked_by_autograd = k2.index_select( logits.reshape(-1).to(torch.float32), arc_map ) - lattice = k2.Fsa(ofsa) + assert torch.allclose( lattice.scores.to(torch.float32), scores_tracked_by_autograd ) - assert torch.allclose( - lattice.scores.to(torch.float32), - torch.tensor( - [ - 10.8000, - 11.7000, - 11.0500, - 12.6000, - 12.0500, - 13.5000, - 13.0500, - 14.0500, - 20.7000, - 21.6000, - 21.0500, - 22.5000, - 22.0500, - 23.4000, - 23.0500, - 24.0500, - 30.6000, - 31.5000, - 31.0500, - 32.4000, - 32.0500, - 33.3000, - 33.0500, - 34.0500, - 40.5000, - 41.4000, - 42.3000, - 43.2000, - 0.0000, - 50.2000, - 50.0500, - 51.3000, - 51.0500, - 52.4000, - 52.0500, - 53.5000, - 53.0500, - 54.0500, - 60.2000, - 61.3000, - 62.4000, - 63.5000, - 63.0500, - 64.0500, - 70.5000, - 71.6000, - 72.7000, - 73.8000, - 0.0000, - ] - ), - ) - assert torch.allclose( + assert torch.equal( lattice.arcs.values()[:, :3], torch.tensor( [ @@ -164,8 +170,18 @@ def _common_test_part(self, ranges, frames, symbols, logits): def test(self): ranges = torch.tensor( [ - [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]], # noqa - [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]], # noqa + [ + [0, 1, 2, 3, 4], + [1, 2, 3, 4, 5], + [2, 3, 4, 5, 6], + [3, 4, 5, 6, 7], + ], + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [3, 4, 5, 6, 7], + [3, 4, 5, 6, 7], + ], ], dtype=torch.int32, ) @@ -182,7 +198,7 @@ def test(self): logits = logits + torch.tensor( [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0] ).reshape(B, T, 1, 1) - logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(1, 1, s_range, 1) # noqa + logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(1, 1, s_range, 1) for dtype in self.float_dtypes: tmp_logits = logits.to(dtype) self._common_test_part(ranges, frames, symbols, tmp_logits) From 4b9329c8f3fdf9f75d398d2d96422961734d2611 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 8 May 2023 17:36:10 +0800 Subject: [PATCH 18/18] fix style check --- k2/python/tests/pruned_ranges_to_lattice_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/k2/python/tests/pruned_ranges_to_lattice_test.py b/k2/python/tests/pruned_ranges_to_lattice_test.py index 395c21e4d..ee95f7e00 100644 --- a/k2/python/tests/pruned_ranges_to_lattice_test.py +++ b/k2/python/tests/pruned_ranges_to_lattice_test.py @@ -97,7 +97,7 @@ def _common_test_part(self, ranges, frames, symbols, logits): -1, ], dtype=torch.int32, - ) + ), ) lattice = k2.Fsa(ofsa) @@ -198,7 +198,9 @@ def test(self): logits = logits + torch.tensor( [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0] ).reshape(B, T, 1, 1) - logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape(1, 1, s_range, 1) + logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).reshape( + 1, 1, s_range, 1 + ) for dtype in self.float_dtypes: tmp_logits = logits.to(dtype) self._common_test_part(ranges, frames, symbols, tmp_logits)