Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import warnings
from dataclasses import dataclass
from functools import partial
Expand Down Expand Up @@ -319,6 +320,7 @@ def get_lhotse_dataloader_from_config(
ReverbWithImpulseResponse(
rir_recordings=RecordingSet.from_file(config.rir_path) if config.rir_path is not None else None,
p=config.rir_prob,
randgen=random.Random(seed),
)
)

Expand Down
27 changes: 10 additions & 17 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
from nemo.collections.common.data.lhotse.text_adapters import TextExample
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model

requires_torchaudio = pytest.mark.skipif(
not lhotse.utils.is_torchaudio_available(), reason="Lhotse Shar format support requires torchaudio."
)


@pytest.fixture(scope="session")
def cutset_path(tmp_path_factory) -> Path:
Expand Down Expand Up @@ -348,7 +344,6 @@ def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
assert torch.equal(b_cs["audio"], batches[n]["audio"][:, channel_selector, :])


@requires_torchaudio
def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
config = OmegaConf.create(
{
Expand Down Expand Up @@ -682,7 +677,6 @@ def test_dataloader_from_tarred_nemo_manifest_concat(nemo_tarred_manifest_path:
torch.testing.assert_close(b["audio_lens"], expected_audio_lens)


@requires_torchaudio
def test_dataloader_from_lhotse_shar_cuts_combine_datasets_unweighted(
cutset_shar_path: Path, cutset_shar_path_other: Path
):
Expand Down Expand Up @@ -723,19 +717,18 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_unweighted(
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2

b = batches[1]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 0 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 3 # dataset 2

b = batches[2]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2

b = batches[3]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2


@requires_torchaudio
def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted(
cutset_shar_path: Path, cutset_shar_path_other: Path
):
Expand Down Expand Up @@ -776,12 +769,12 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted(
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2

b = batches[1]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2

b = batches[2]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2

b = batches[3]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1
Expand All @@ -792,8 +785,8 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted(
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2

b = batches[5]
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2
assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1
assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2


class TextDataset(torch.utils.data.Dataset):
Expand Down