Skip to content

Commit 0e77a6e

Browse files
authored
Merge pull request #83 from openspeech-team/uniform-length-batch
Add uniform-length batching (smart batching) [resolved #82] - Soohwan Kim
2 parents da427c2 + 54e3c8b commit 0e77a6e

File tree

6 files changed

+84
-36
lines changed

6 files changed

+84
-36
lines changed

openspeech/data/sampler.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,24 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222

23+
import os
2324
import numpy as np
24-
2525
from torch.utils.data import Sampler
2626

27+
from .audio.load import load_audio
28+
2729

28-
class BucketingSampler(Sampler):
30+
class RandomSampler(Sampler):
2931
r"""
30-
Samples batches assuming they are in order of size to batch similarly sized samples together.
32+
Implementation of a Random Sampler for sampling the dataset.
3133
3234
Args:
3335
data_source (torch.utils.data.Dataset): dataset to sample from
3436
batch_size (int): size of batch
3537
drop_last (bool): flag indicating whether to drop the last batch or not
3638
"""
3739
def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
38-
super(BucketingSampler, self).__init__(data_source)
40+
super(RandomSampler, self).__init__(data_source)
3941
self.batch_size = batch_size
4042
self.data_source = data_source
4143
ids = list(range(0, len(data_source)))
@@ -52,3 +54,42 @@ def __len__(self):
5254

5355
def shuffle(self, epoch):
5456
np.random.shuffle(self.bins)
57+
58+
59+
class SmartBatchingSampler(Sampler):
60+
"""
61+
Batching with similar sequence length.
62+
63+
Args:
64+
data_source (torch.utils.data.Dataset): dataset to sample from
65+
batch_size (int): size of batch
66+
drop_last (bool): flag indicating whether to drop the last batch or not
67+
"""
68+
def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
69+
super(SmartBatchingSampler, self).__init__(data_source)
70+
self.batch_size = batch_size
71+
self.data_source = data_source
72+
73+
audio_lengths = [self._get_audio_length(audio_path) for audio_path in data_source.audio_paths]
74+
audio_indices = [idx for idx in range(len(data_source.audio_paths))]
75+
76+
pack_by_length = list(zip(audio_lengths, audio_indices))
77+
sort_by_length = sorted(pack_by_length)
78+
audio_lengths, audio_indices = zip(*sort_by_length)
79+
80+
self.bins = [audio_indices[i:i + batch_size] for i in range(0, len(audio_indices), batch_size)]
81+
self.drop_last = drop_last
82+
83+
def __iter__(self):
84+
for ids in self.bins:
85+
np.random.shuffle(ids)
86+
yield ids
87+
88+
def _get_audio_length(self, audio_path):
89+
return len(load_audio(os.path.join(self.data_source.dataset_path, audio_path)))
90+
91+
def __len__(self):
92+
return len(self.bins)
93+
94+
def shuffle(self, epoch):
95+
np.random.shuffle(self.bins)

openspeech/dataclass/configurations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ class BaseTrainerConfigs(OpenspeechDataclass):
209209
default="binsearch", metadata={"help": "If set to True, will initially run a batch size finder trying to find "
210210
"the largest batch size that fits into memory."}
211211
)
212+
sampler: str = field(
213+
default="smart", metadata={"help": "smart: batching with similar sequence length. "
214+
"else: random batch"}
215+
)
212216

213217

214218
@dataclass

openspeech/datasets/aishell/lit_data_module.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,12 @@
2727
import logging
2828
from omegaconf import DictConfig
2929
from typing import Optional, Tuple
30-
from torch.utils.data import DataLoader
3130

3231
from openspeech.data.audio.dataset import SpeechToTextDataset
3332
from openspeech.datasets import register_data_module
34-
from openspeech.data.sampler import BucketingSampler
33+
from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
3534
from openspeech.data.audio.data_loader import AudioDataLoader
3635
from openspeech.tokenizers.tokenizer import Tokenizer
37-
from openspeech.tokenizers import TOKENIZER_REGISTRY
3836
from openspeech.datasets.aishell.preprocess import (
3937
generate_character_labels,
4038
generate_character_script,
@@ -158,24 +156,27 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
158156
del_silence=self.configs.audio.del_silence if stage == 'train' else False,
159157
)
160158

161-
def train_dataloader(self) -> DataLoader:
162-
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
159+
def train_dataloader(self) -> AudioDataLoader:
160+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
161+
train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
163162
return AudioDataLoader(
164163
dataset=self.dataset['train'],
165164
num_workers=self.configs.trainer.num_workers,
166165
batch_sampler=train_sampler,
167166
)
168167

169-
def val_dataloader(self) -> DataLoader:
170-
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
168+
def val_dataloader(self) -> AudioDataLoader:
169+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
170+
valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
171171
return AudioDataLoader(
172172
dataset=self.dataset['valid'],
173173
num_workers=self.configs.trainer.num_workers,
174174
batch_sampler=valid_sampler,
175175
)
176176

177-
def test_dataloader(self) -> DataLoader:
178-
test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
177+
def test_dataloader(self) -> AudioDataLoader:
178+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
179+
test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
179180
return AudioDataLoader(
180181
dataset=self.dataset['test'],
181182
num_workers=self.configs.trainer.num_workers,

openspeech/datasets/ksponspeech/lit_data_module.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,15 @@
2525
import pytorch_lightning as pl
2626
from typing import Optional
2727
from omegaconf import DictConfig
28-
from openspeech.data.audio.dataset import SpeechToTextDataset
2928

29+
from openspeech.data.audio.dataset import SpeechToTextDataset
3030
from openspeech.datasets import register_data_module
31-
from openspeech.data.sampler import BucketingSampler
31+
from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
3232
from openspeech.data.audio.data_loader import AudioDataLoader
3333
from openspeech.datasets.ksponspeech.preprocess.preprocess import preprocess, preprocess_test_data
3434
from openspeech.datasets.ksponspeech.preprocess.character import generate_character_script, generate_character_labels
3535
from openspeech.datasets.ksponspeech.preprocess.grapheme import sentence_to_grapheme
3636
from openspeech.datasets.ksponspeech.preprocess.subword import train_sentencepiece, sentence_to_subwords
37-
from openspeech.tokenizers import TOKENIZER_REGISTRY
3837
from openspeech.tokenizers.tokenizer import Tokenizer
3938

4039

@@ -49,6 +48,8 @@ class LightningKsponSpeechDataModule(pl.LightningDataModule):
4948
5049
Attributes:
5150
KSPONSPEECH_TRAIN_NUM (int): the number of KsponSpeech's train data.
51+
KSPONSPEECH_VALID_NUM (int): the number of KsponSpeech's validation data.
52+
KSPONSPEECH_TEST_NUM (int): the number of KsponSpeech's test data.
5253
5354
Args:
5455
configs (DictConfig): configuration set.
@@ -173,26 +174,26 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
173174
)
174175

175176
def train_dataloader(self) -> AudioDataLoader:
176-
r""" Return data loader for training. """
177-
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
177+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
178+
train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
178179
return AudioDataLoader(
179180
dataset=self.dataset['train'],
180181
num_workers=self.configs.trainer.num_workers,
181182
batch_sampler=train_sampler,
182183
)
183184

184185
def val_dataloader(self) -> AudioDataLoader:
185-
r""" Return data loader for validation. """
186-
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
186+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
187+
valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
187188
return AudioDataLoader(
188189
dataset=self.dataset['valid'],
189190
num_workers=self.configs.trainer.num_workers,
190191
batch_sampler=valid_sampler,
191192
)
192193

193194
def test_dataloader(self) -> AudioDataLoader:
194-
r""" Return data loader for training. """
195-
test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
195+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
196+
test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
196197
return AudioDataLoader(
197198
dataset=self.dataset['test'],
198199
num_workers=self.configs.trainer.num_workers,

openspeech/datasets/language_model/lit_data_module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from omegaconf import DictConfig
2828
from typing import Optional
2929

30-
from openspeech.data.sampler import BucketingSampler
30+
from openspeech.data.sampler import RandomSampler
3131
from openspeech.data.text.data_loader import TextDataLoader
3232
from openspeech.data.text.dataset import TextDataset
3333
from openspeech.datasets import register_data_module
@@ -78,7 +78,7 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
7878
)
7979

8080
def train_dataloader(self) -> TextDataLoader:
81-
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
81+
train_sampler = RandomSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
8282
return TextDataLoader(
8383
dataset=self.dataset['train'],
8484
num_workers=self.configs.trainer.num_workers,
@@ -87,7 +87,7 @@ def train_dataloader(self) -> TextDataLoader:
8787

8888
def val_dataloader(self) -> TextDataLoader:
8989
r""" Return data loader for validation. """
90-
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
90+
valid_sampler = RandomSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
9191
return TextDataLoader(
9292
dataset=self.dataset['valid'],
9393
num_workers=self.configs.trainer.num_workers,
@@ -96,7 +96,7 @@ def val_dataloader(self) -> TextDataLoader:
9696

9797
def test_dataloader(self) -> TextDataLoader:
9898
r""" Return data loader for testing. """
99-
train_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
99+
train_sampler = RandomSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
100100
return TextDataLoader(
101101
dataset=self.dataset['test'],
102102
num_workers=self.configs.trainer.num_workers,

openspeech/datasets/librispeech/lit_data_module.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@
2828
import pytorch_lightning as pl
2929
from typing import Tuple, Optional
3030
from omegaconf import DictConfig
31-
from openspeech.data.audio.dataset import SpeechToTextDataset
32-
from torch.utils.data import DataLoader
3331

32+
from openspeech.data.audio.dataset import SpeechToTextDataset
3433
from openspeech.datasets import register_data_module
35-
from openspeech.tokenizers import TOKENIZER_REGISTRY
3634
from openspeech.tokenizers.tokenizer import Tokenizer
37-
from openspeech.data.sampler import BucketingSampler
35+
from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
3836
from openspeech.data.audio.data_loader import AudioDataLoader
3937

4038

@@ -188,24 +186,27 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None) -> Non
188186
del_silence=self.configs.audio.del_silence if stage == 'train' else False,
189187
)
190188

191-
def train_dataloader(self) -> DataLoader:
192-
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
189+
def train_dataloader(self) -> AudioDataLoader:
190+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
191+
train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
193192
return AudioDataLoader(
194193
dataset=self.dataset['train'],
195194
num_workers=self.configs.trainer.num_workers,
196195
batch_sampler=train_sampler,
197196
)
198197

199-
def val_dataloader(self) -> DataLoader:
200-
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
198+
def val_dataloader(self) -> AudioDataLoader:
199+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
200+
valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
201201
return AudioDataLoader(
202202
dataset=self.dataset['valid'],
203203
num_workers=self.configs.trainer.num_workers,
204204
batch_sampler=valid_sampler,
205205
)
206206

207-
def test_dataloader(self) -> DataLoader:
208-
test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
207+
def test_dataloader(self) -> AudioDataLoader:
208+
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
209+
test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
209210
return AudioDataLoader(
210211
dataset=self.dataset['test'],
211212
num_workers=self.configs.trainer.num_workers,

0 commit comments

Comments
 (0)