Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions asteroid/models/dccrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ class DCCRNet(BaseDCUNet):

masknet_class = DCCRMaskNet

def __init__(self, *args, stft_kernel_size=512, **masknet_kwargs):
masknet_kwargs.setdefault("n_freqs", stft_kernel_size // 2)
def __init__(
self, *args, stft_n_filters=512, stft_kernel_size=400, stft_stride=100, **masknet_kwargs
):
masknet_kwargs.setdefault("n_freqs", stft_n_filters // 2)
super().__init__(
*args,
stft_n_filters=stft_n_filters,
stft_kernel_size=stft_kernel_size,
stft_stride=stft_stride,
**masknet_kwargs,
)

Expand Down
15 changes: 10 additions & 5 deletions asteroid/models/dcunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class BaseDCUNet(BaseEncoderMaskerDecoder):

Args:
architecture (str): The architecture to use. Overridden by subclasses.
stft_kernel_size (int): STFT frame length to use
stft_n_filters (int): Number of filters for the STFT.
stft_kernel_size (int): STFT frame length to use.
stft_stride (int, optional): STFT hop length to use.
sample_rate (float): Sampling rate of the model.
masknet_kwargs (optional): Passed to the masknet constructor.
Expand All @@ -20,20 +21,22 @@ class BaseDCUNet(BaseEncoderMaskerDecoder):
def __init__(
self,
architecture,
stft_kernel_size=512,
stft_stride=None,
stft_n_filters=1024,
stft_kernel_size=1024,
stft_stride=256,
sample_rate=16000.0,
**masknet_kwargs,
):
self.architecture = architecture
self.stft_n_filters = stft_n_filters
self.stft_kernel_size = stft_kernel_size
self.stft_stride = stft_stride
self.masknet_kwargs = masknet_kwargs

encoder, decoder = make_enc_dec(
"stft",
n_filters=stft_n_filters,
kernel_size=stft_kernel_size,
n_filters=stft_kernel_size,
stride=stft_stride,
sample_rate=sample_rate,
)
Expand All @@ -52,6 +55,7 @@ def get_model_args(self):
"""Arguments needed to re-instantiate the model."""
model_args = {
"architecture": self.architecture,
"stft_n_filters": self.stft_n_filters,
"stft_kernel_size": self.stft_kernel_size,
"stft_stride": self.stft_stride,
"sample_rate": self.sample_rate,
Expand All @@ -66,7 +70,8 @@ class DCUNet(BaseDCUNet):
Args:
architecture (str): The architecture to use, any of
"DCUNet-10", "DCUNet-16", "DCUNet-20", "Large-DCUNet-20".
stft_kernel_size (int): STFT frame length to use
stft_n_filters (int): Number of filters for the STFT.
stft_kernel_size (int): STFT frame length to use.
stft_stride (int, optional): STFT hop length to use.
sample_rate (float): Sampling rate of the model.
masknet_kwargs (optional): Passed to :class:`DCUMaskNet`
Expand Down
3 changes: 2 additions & 1 deletion tests/jit/jit_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def small_model_params():
"stride": 16,
},
DCCRNet.__name__: {
"stft_n_filters": 512,
"stft_kernel_size": 256,
"stft_stride": 100,
"architecture": "mini",
},
DeMask.__name__: {
Expand Down Expand Up @@ -123,7 +125,6 @@ def test_enhancement_model(small_model_params, model_def, test_data):
# Random input uniformly distributed in [-1, 1]
inputs = ((torch.rand(1, 2500, device=device) - 0.5) * 2,)
traced = torch.jit.trace(model, inputs)

assert_consistency(model=model, traced=traced, tensor=test_data.to(device))


Expand Down
12 changes: 8 additions & 4 deletions tests/models/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,11 @@ def test_dptnet(fb):


def test_dcunet():
_, istft = make_enc_dec("stft", 512, 512)
input_samples = istft(torch.zeros((514, 17))).shape[0]
n_fft = 1024
_, istft = make_enc_dec(
"stft", n_filters=n_fft, kernel_size=1024, stride=256, sample_rate=16000
)
input_samples = istft(torch.zeros((n_fft + 2, 17))).shape[0]
_default_test_model(DCUNet("DCUNet-10"), input_samples=input_samples)
_default_test_model(DCUNet("DCUNet-10", n_src=2), input_samples=input_samples)

Expand All @@ -173,8 +176,9 @@ def test_dcunet():


def test_dccrnet():
_, istft = make_enc_dec("stft", 512, 512)
input_samples = istft(torch.zeros((514, 16))).shape[0]
n_fft = 512
_, istft = make_enc_dec("stft", n_filters=n_fft, kernel_size=400, stride=100, sample_rate=16000)
input_samples = istft(torch.zeros((n_fft + 2, 16))).shape[0]
_default_test_model(DCCRNet("DCCRN-CL"), input_samples=input_samples)
_default_test_model(DCCRNet("DCCRN-CL", n_src=2), input_samples=input_samples)

Expand Down