2525import pytorch_lightning as pl
2626from typing import Optional
2727from omegaconf import DictConfig
28- from openspeech .data .audio .dataset import SpeechToTextDataset
2928
29+ from openspeech .data .audio .dataset import SpeechToTextDataset
3030from openspeech .datasets import register_data_module
31- from openspeech .data .sampler import BucketingSampler
31+ from openspeech .data .sampler import RandomSampler , SmartBatchingSampler
3232from openspeech .data .audio .data_loader import AudioDataLoader
3333from openspeech .datasets .ksponspeech .preprocess .preprocess import preprocess , preprocess_test_data
3434from openspeech .datasets .ksponspeech .preprocess .character import generate_character_script , generate_character_labels
3535from openspeech .datasets .ksponspeech .preprocess .grapheme import sentence_to_grapheme
3636from openspeech .datasets .ksponspeech .preprocess .subword import train_sentencepiece , sentence_to_subwords
37- from openspeech .tokenizers import TOKENIZER_REGISTRY
3837from openspeech .tokenizers .tokenizer import Tokenizer
3938
4039
@@ -49,6 +48,8 @@ class LightningKsponSpeechDataModule(pl.LightningDataModule):
4948
    Attributes:
        KSPONSPEECH_TRAIN_NUM (int): the number of KsponSpeech's train data.
        KSPONSPEECH_VALID_NUM (int): the number of KsponSpeech's validation data.
        KSPONSPEECH_TEST_NUM (int): the number of KsponSpeech's test data.
5253
5354 Args:
5455 configs (DictConfig): configuration set.
@@ -173,26 +174,26 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
173174 )
174175
175176 def train_dataloader (self ) -> AudioDataLoader :
176- r""" Return data loader for training. """
177- train_sampler = BucketingSampler ( self .dataset ['train' ], batch_size = self .configs .trainer .batch_size )
177+ sampler = SmartBatchingSampler if self . configs . trainer . sampler == 'smart' else RandomSampler
178+ train_sampler = sampler ( data_source = self .dataset ['train' ], batch_size = self .configs .trainer .batch_size )
178179 return AudioDataLoader (
179180 dataset = self .dataset ['train' ],
180181 num_workers = self .configs .trainer .num_workers ,
181182 batch_sampler = train_sampler ,
182183 )
183184
184185 def val_dataloader (self ) -> AudioDataLoader :
185- r""" Return data loader for validation. """
186- valid_sampler = BucketingSampler (self .dataset ['valid' ], batch_size = self .configs .trainer .batch_size )
186+ sampler = SmartBatchingSampler if self . configs . trainer . sampler == 'smart' else RandomSampler
187+ valid_sampler = sampler (self .dataset ['valid' ], batch_size = self .configs .trainer .batch_size )
187188 return AudioDataLoader (
188189 dataset = self .dataset ['valid' ],
189190 num_workers = self .configs .trainer .num_workers ,
190191 batch_sampler = valid_sampler ,
191192 )
192193
193194 def test_dataloader (self ) -> AudioDataLoader :
194- r""" Return data loader for training. """
195- test_sampler = BucketingSampler (self .dataset ['test' ], batch_size = self .configs .trainer .batch_size )
195+ sampler = SmartBatchingSampler if self . configs . trainer . sampler == 'smart' else RandomSampler
196+ test_sampler = sampler (self .dataset ['test' ], batch_size = self .configs .trainer .batch_size )
196197 return AudioDataLoader (
197198 dataset = self .dataset ['test' ],
198199 num_workers = self .configs .trainer .num_workers ,
0 commit comments