Skip to content

Fix shape mismatching in SuDORMRF's masknn#618

Merged
mpariente merged 1 commit intoasteroid-team:masterfrom
z-wony:sudormrf_upsample
Jun 7, 2022
Merged

Fix shape mismatching in SuDORMRF's masknn#618
mpariente merged 1 commit intoasteroid-team:masterfrom
z-wony:sudormrf_upsample

Conversation

@z-wony
Copy link
Contributor

@z-wony z-wony commented Jun 6, 2022

As the length of input audio tensor, sometimes U-ConvBlock's tensor shape becomes odd number.
In this case, shape miss-matching is occurred in process of downsampling and upsampling.

This commit trims additional last data in upsampling sequence.

@z-wony
Copy link
Contributor Author

z-wony commented Jun 6, 2022

Test Environment.

Test Model

  • Model: SuDORMRFImprovedNet
  • hparams: default settings
    • egs/librimix/SuDORMRFImprovedNet/run.sh
    • egs/librimix/SuDORMRFImprovedNet/local/conf.yml
    • masknet.upsampling_depth: 4

Test code

import sys 
sys.path.append('/home/jwkim/github/asteroid')
from asteroid.models import SuDORMRFImprovedNet
import torch

model = SuDORMRFImprovedNet.from_pretrained('best_model.pth')

mix = torch.rand(1, 32000 + 200)
ret = model.separate(mix)

Exception print

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 192, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jwkim/github/asteroid/separate.py", line 9, in <module>
    ret = model.separate(mix)
  File "/home/jwkim/github/asteroid/asteroid/models/base_models.py", line 64, in separate
    return separate.separate(self, *args, **kwargs)
  File "/home/jwkim/github/asteroid/asteroid/separate.py", line 84, in separate
    return torch_separate(model, wav, **kwargs)
  File "/home/jwkim/pTest/envs/ast/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/jwkim/github/asteroid/asteroid/separate.py", line 105, in torch_separate
    out_wavs = separate_func(wav, **kwargs)
  File "/home/jwkim/github/asteroid/asteroid/models/base_models.py", line 89, in forward_wav
    return self(wav, *args, **kwargs)
  File "/home/jwkim/pTest/envs/ast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jwkim/github/asteroid/asteroid/models/base_models.py", line 219, in forward
    est_masks = self.forward_masker(tf_rep)
  File "/home/jwkim/github/asteroid/asteroid/models/base_models.py", line 248, in forward_masker
    return self.masker(tf_rep)
  File "/home/jwkim/pTest/envs/ast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jwkim/github/asteroid/asteroid/masknn/convolutional.py", line 745, in forward
    x = self.sm(x)
  File "/home/jwkim/pTest/envs/ast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jwkim/pTest/envs/ast/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/jwkim/pTest/envs/ast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jwkim/github/asteroid/asteroid/masknn/convolutional.py", line 879, in forward
    output[-1] = output[-1] + resampled_out_k
RuntimeError: The size of tensor a (403) must match the size of tensor b (404) at non-singleton dimension 2

Debugging print

  • I printed shape of UConvBlock between downsampling and upsampling.
  • asteroid/masknn/convolutional.py:875
        # Do the downsampling process from the previous level
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)

        #  DEBUG PRINT
        for d in range(self.depth):
            print(f'depth: {d}, shape: {output[d].shape}')

        # Gather them now in reverse order
        for _ in range(self.depth - 1): 
            resampled_out_k = self.upsampler(output.pop(-1))
            output[-1] = output[-1] + resampled_out_k
  • print result
depth: 0, shape: torch.Size([1, 512, 1612])
depth: 1, shape: torch.Size([1, 512, 806])
depth: 2, shape: torch.Size([1, 512, 403])
depth: 3, shape: torch.Size([1, 512, 202])

Cause of exception

In below code, 'depth 3' ([1, 512, 202]) is upsampled to [1, 512, 404] (assigned to resampled_out_k).
But output[-1]'s shape is [1, 512, 403].
So, size miss matching exception is occurred.

        # Gather them now in reverse order
        for _ in range(self.depth - 1): 
            resampled_out_k = self.upsampler(output.pop(-1))
            output[-1] = output[-1] + resampled_out_k

@mpariente mpariente merged commit 6dc1c6a into asteroid-team:master Jun 7, 2022
@mpariente
Copy link
Collaborator

Thanks for the PR and the great detail in the explanation of the problem !

The CI failures are not due to your code, so it's merged

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants