diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py index a4bf1ddad..eff636430 100644 --- a/asteroid/models/base_models.py +++ b/asteroid/models/base_models.py @@ -5,7 +5,7 @@ from .. import separate from ..masknn import activations from ..utils.torch_utils import pad_x_to_y, script_if_tracing, jitable_shape -from ..utils.hub_utils import cached_download +from ..utils.hub_utils import cached_download, SR_HASHTABLE from ..utils.deprecation_utils import is_overridden, mark_deprecated, VisibleDeprecationWarning @@ -37,8 +37,16 @@ class BaseModel(torch.nn.Module): If None, no checks will be performed. """ - def __init__(self, sample_rate: float = 8000.0, n_channels: Optional[int] = 1): + def __init__(self, sample_rate: float = None, n_channels: Optional[int] = 1): super().__init__() + if sample_rate is None: + sample_rate = 8000.0 + warnings.warn( + "The argument `sample_rate` of `BaseModel` will be required in the future. " + "It is no longer a keyword argument. This will raise an error in future release. " + "Defaults to 8000.0", + VisibleDeprecationWarning, + ) self.__sample_rate = sample_rate self.n_channels = n_channels @@ -144,20 +152,12 @@ def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs): "model_args`. Found only: {}".format(conf.keys()) ) conf["model_args"].update(kwargs) # kwargs overwrite config. - if "sample_rate" not in conf["model_args"]: - # Try retrieving from pretrained models - from ..utils.hub_utils import SR_HASHTABLE - - sr = None - if isinstance(pretrained_model_conf_or_path, str): - sr = SR_HASHTABLE.get(pretrained_model_conf_or_path, None) - if sr is None: - raise RuntimeError( - "Couldn't load pretrained model without sampling rate. You can either pass " - "`sample_rate` to the `from_pretrained` method or edit your model to include " - "the `sample_rate` key, or use `asteroid-register-sr model sample_rate` CLI." - ) - conf["model_args"]["sample_rate"] = sr + if "sample_rate" not in conf["model_args"] and isinstance( + pretrained_model_conf_or_path, str + ): + conf["model_args"]["sample_rate"] = SR_HASHTABLE.get( + pretrained_model_conf_or_path, None + ) # Attempt to find the model and instantiate it. try: model_class = get(conf["model_name"]) diff --git a/tests/models/models_test.py b/tests/models/models_test.py index a5d4160bf..84f1243cc 100644 --- a/tests/models/models_test.py +++ b/tests/models/models_test.py @@ -20,7 +20,7 @@ SuDORMRFNet, ) from asteroid.models.base_models import BaseModel - +from asteroid.utils.deprecation_utils import VisibleDeprecationWarning HF_EXAMPLE_MODEL_IDENTIFER = "julien-c/DPRNNTasNet-ks16_WHAM_sepclean" # An actual model hosted on huggingface.co @@ -32,6 +32,11 @@ def test_set_sample_rate_raises_warning(): model.sample_rate = 16000.0 +def test_no_sample_rate_raises_warning(): + with pytest.warns(VisibleDeprecationWarning): + BaseModel() + + def test_multichannel_model_loading(): class MCModel(BaseModel): def __init__(self, sample_rate=8000.0, n_channels=2): @@ -216,10 +221,9 @@ def _default_test_model(model, input_samples=801): reconstructed_model = model.__class__.from_pretrained(model_conf) assert_allclose(model(test_input), reconstructed_model(test_input)) - # Make + # Load with and without SR sr = model_conf["model_args"].pop("sample_rate") - with pytest.raises(RuntimeError): - reconstructed_model = model.__class__.from_pretrained(model_conf) + reconstructed_model_nosr = model.__class__.from_pretrained(model_conf) reconstructed_model = model.__class__.from_pretrained(model_conf, sample_rate=sr)