Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions asteroid/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .. import separate
from ..masknn import activations
from ..utils.torch_utils import pad_x_to_y, script_if_tracing, jitable_shape
from ..utils.hub_utils import cached_download
from ..utils.hub_utils import cached_download, SR_HASHTABLE
from ..utils.deprecation_utils import is_overridden, mark_deprecated, VisibleDeprecationWarning


Expand Down Expand Up @@ -37,8 +37,16 @@ class BaseModel(torch.nn.Module):
If None, no checks will be performed.
"""

def __init__(self, sample_rate: float = 8000.0, n_channels: Optional[int] = 1):
def __init__(self, sample_rate: float = None, n_channels: Optional[int] = 1):
super().__init__()
if sample_rate is None:
sample_rate = 8000.0
warnings.warn(
"The argument `sample_rate` of `BaseModel` will be required in the future. "
"It is no longer a keyword argument. This will raise an error in future release. "
"Defaults to 8000.0",
VisibleDeprecationWarning,
)
self.__sample_rate = sample_rate
self.n_channels = n_channels

Expand Down Expand Up @@ -144,20 +152,12 @@ def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
"model_args`. Found only: {}".format(conf.keys())
)
conf["model_args"].update(kwargs) # kwargs overwrite config.
if "sample_rate" not in conf["model_args"]:
# Try retrieving from pretrained models
from ..utils.hub_utils import SR_HASHTABLE

sr = None
if isinstance(pretrained_model_conf_or_path, str):
sr = SR_HASHTABLE.get(pretrained_model_conf_or_path, None)
if sr is None:
raise RuntimeError(
"Couldn't load pretrained model without sampling rate. You can either pass "
"`sample_rate` to the `from_pretrained` method or edit your model to include "
"the `sample_rate` key, or use `asteroid-register-sr model sample_rate` CLI."
)
conf["model_args"]["sample_rate"] = sr
if "sample_rate" not in conf["model_args"] and isinstance(
pretrained_model_conf_or_path, str
):
conf["model_args"]["sample_rate"] = SR_HASHTABLE.get(
pretrained_model_conf_or_path, None
)
# Attempt to find the model and instantiate it.
try:
model_class = get(conf["model_name"])
Expand Down
12 changes: 8 additions & 4 deletions tests/models/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SuDORMRFNet,
)
from asteroid.models.base_models import BaseModel

from asteroid.utils.deprecation_utils import VisibleDeprecationWarning

HF_EXAMPLE_MODEL_IDENTIFER = "julien-c/DPRNNTasNet-ks16_WHAM_sepclean"
# An actual model hosted on huggingface.co
Expand All @@ -32,6 +32,11 @@ def test_set_sample_rate_raises_warning():
model.sample_rate = 16000.0


def test_no_sample_rate_raises_warning():
with pytest.warns(VisibleDeprecationWarning):
BaseModel()


def test_multichannel_model_loading():
class MCModel(BaseModel):
def __init__(self, sample_rate=8000.0, n_channels=2):
Expand Down Expand Up @@ -216,10 +221,9 @@ def _default_test_model(model, input_samples=801):
reconstructed_model = model.__class__.from_pretrained(model_conf)
assert_allclose(model(test_input), reconstructed_model(test_input))

# Make
# Load with and without SR
sr = model_conf["model_args"].pop("sample_rate")
with pytest.raises(RuntimeError):
reconstructed_model = model.__class__.from_pretrained(model_conf)
reconstructed_model_nosr = model.__class__.from_pretrained(model_conf)
reconstructed_model = model.__class__.from_pretrained(model_conf, sample_rate=sr)


Expand Down