Skip to content

Commit b13741e

Browse files
julien-c and mpariente authored
[hub] Support for huggingface model hub 🎉 (#377)
Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
1 parent 77d1b74 commit b13741e

File tree

4 files changed

+295
-7
lines changed

4 files changed

+295
-7
lines changed

asteroid/utils/hub_utils.py

Lines changed: 248 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
1+
import fnmatch
2+
import io
3+
import json
14
import os
2-
from torch import hub
5+
import sys
6+
import tempfile
7+
from contextlib import contextmanager
8+
from functools import partial
39
from hashlib import sha256
10+
from typing import BinaryIO, Dict, Optional, Union
11+
from urllib.parse import urlparse
12+
13+
import requests
14+
import torch
15+
from filelock import FileLock
16+
from torch import hub
417

518

619
CACHE_DIR = os.getenv(
@@ -23,14 +36,18 @@
2336

2437
SR_HASHTABLE = {k: 8000.0 if not "DeMask" in k else 16000.0 for k in MODELS_URLS_HASHTABLE}
2538

# Name of the weights file inside a model repository on the Hugging Face model hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"
# URL template resolving a (model_id, revision, filename) triple on huggingface.co.
# The placeholders match the keyword arguments passed by ``hf_bucket_url``.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
2642

2743
def cached_download(filename_or_url):
28-
"""Download from URL with torch.hub and cache the result in ASTEROID_CACHE.
44+
"""Download from URL and cache the result in ASTEROID_CACHE.
2945
3046
Args:
3147
filename_or_url (str): Name of a model as named on the Zenodo Community
32-
page (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or an URL to a model
33-
file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename
48+
page (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or model id from
49+
the Hugging Face model hub (ex: ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``),
50+
or a URL to a model file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename
3451
that exists locally (ex: ``"local/tmp_model.pth"``)
3552
3653
Returns:
@@ -39,11 +56,22 @@ def cached_download(filename_or_url):
3956
if os.path.isfile(filename_or_url):
4057
return filename_or_url
4158

42-
if filename_or_url in MODELS_URLS_HASHTABLE:
59+
if urlparse(filename_or_url).scheme in ("http", "https"):
60+
url = filename_or_url
61+
elif filename_or_url in MODELS_URLS_HASHTABLE:
4362
url = MODELS_URLS_HASHTABLE[filename_or_url]
4463
else:
45-
# Give a chance to direct URL, torch.hub will handle exceptions
46-
url = filename_or_url
64+
# Finally, let's try to find it on Hugging Face model hub
65+
# e.g. julien-c/DPRNNTasNet-ks16_WHAM_sepclean is a valid model id
66+
# and julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main supports specifying a commit/branch/tag.
67+
if "@" in filename_or_url:
68+
model_id = filename_or_url.split("@")[0]
69+
revision = filename_or_url.split("@")[1]
70+
else:
71+
model_id = filename_or_url
72+
revision = None
73+
url = hf_bucket_url(model_id=model_id, filename=HF_WEIGHTS_NAME, revision=revision)
74+
return hf_get_from_cache(url, cache_dir=get_cache_dir())
4775
cached_filename = url_to_filename(url)
4876
cached_dir = os.path.join(get_cache_dir(), cached_filename)
4977
cached_path = os.path.join(cached_dir, "model.pth")
@@ -68,3 +96,216 @@ def url_to_filename(url):
6896
def get_cache_dir():
    """Return the Asteroid cache directory, creating it on first use."""
    cache_dir = CACHE_DIR
    # exist_ok makes repeated calls (and concurrent callers) harmless.
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
99+
100+
101+
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None
) -> str:
    """
    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
    to Cloudfront (a Content Delivery Network, or CDN) for large files.

    Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
    bandwidth costs).

    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
    can't ever be stale.

    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
    its sha1 if stored in git, or its sha256 if stored in git-lfs.

    Args:
        model_id (str): Model id on the hub (e.g. ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``).
        filename (str): Name of the file to resolve inside the repo.
        subfolder (Optional[str]): Repo subfolder containing the file, if any.
        revision (Optional[str]): Commit sha / branch / tag; defaults to ``"main"``.

    Returns:
        str: Fully-resolved huggingface.co URL for the requested file.
    """
    if subfolder is not None:
        # Files in a subfolder are addressed as "<subfolder>/<filename>" in the repo.
        filename = f"{subfolder}/{filename}"

    if revision is None:
        revision = "main"
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
125+
126+
127+
def hf_url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Map ``url`` (and optionally ``etag``) to a deterministic cache filename.

    The name is the sha256 hex digest of the URL; when an etag is given, a
    period and the sha256 hex digest of the etag are appended, so different
    versions of the same URL get distinct cache entries.
    """
    name_parts = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        name_parts.append(sha256(etag.encode("utf-8")).hexdigest())
    return ".".join(name_parts)
140+
141+
142+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Build the user-agent header string sent with hub requests.

    Always includes the asteroid, python and torch versions; ``user_agent``
    (a string, or a dict rendered as ``key/value`` pairs) is appended when given.
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports

    parts = [
        "asteroid/{}".format(asteroid_version),
        "python/{}".format(sys.version.split()[0]),
        "torch/{}".format(torch.__version__),
    ]
    if isinstance(user_agent, dict):
        parts.extend("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        parts.append(user_agent)
    return "; ".join(parts)
155+
156+
157+
def http_get(
    url: str,
    temp_file: BinaryIO,
    proxies=None,
    resume_size=0,
    user_agent: Union[Dict, str, None] = None,
):
    """
    Download a remote file into ``temp_file``. Do not gobble up errors:
    any HTTP error status raises via ``raise_for_status``.
    """
    headers = {"user-agent": http_user_agent(user_agent)}
    if resume_size > 0:
        # Ask the server for only the remainder of a partially-downloaded file.
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=1024):
        # Empty chunks are keep-alive heartbeats, not payload data.
        if not chunk:
            continue
        temp_file.write(chunk)
175+
176+
177+
def hf_get_from_cache(
    url: str,
    cache_dir: str,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    local_files_only=False,
) -> Optional[str]:  # pragma: no cover
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Args:
        url (str): URL of the file to fetch.
        cache_dir (str): Directory in which downloads are cached.
        force_download (bool): If True, re-download even when a cached copy exists.
        proxies: Optional proxies mapping forwarded to ``requests``.
        etag_timeout (int): Timeout (seconds) for the HEAD request fetching the ETag.
        resume_download (bool): If True, resume a previously interrupted download
            from a ``*.incomplete`` file instead of starting over.
        user_agent (Union[Dict, str, None]): Extra info appended to the user-agent header.
        local_files_only (bool): If True, never hit the network; serve only from cache.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """

    os.makedirs(cache_dir, exist_ok=True)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            # HEAD request only: we need the ETag for cache addressing, not the payload.
            headers = {"user-agent": http_user_agent(user_agent)}
            r = requests.head(
                url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout
            )
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = hf_url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            # Fall back to any earlier download of this URL: same URL-hash prefix,
            # any ETag suffix — skipping metadata (.json) and lock (.lock) files.
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                # Append mode so a partial download keeps growing across attempts.
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(
                tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
            )
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            http_get(
                url_to_download,
                temp_file,
                proxies=proxies,
                resume_size=resume_size,
                user_agent=user_agent,
            )

        # Single rename into place so readers never observe a partial file.
        os.replace(temp_file.name, cache_path)

        # Record provenance (url + etag) next to the cached file.
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
"torch_stoi",
3030
"asteroid-filterbanks",
3131
"librosa",
32+
"filelock",
33+
"requests",
3234
],
3335
extras_require={
3436
"tests": ["pytest"],

tests/models/models_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
from asteroid.models.base_models import BaseModel
2323

2424

HF_EXAMPLE_MODEL_IDENTIFER = "julien-c/DPRNNTasNet-ks16_WHAM_sepclean"
# An actual model hosted on huggingface.co; downloading it exercises the hub code path.
28+
2529
def test_set_sample_rate_raises_warning():
2630
model = BaseModel(sample_rate=8000.0)
2731
with pytest.warns(UserWarning):
@@ -90,6 +94,11 @@ def test_dprnntasnet_sep():
9094
assert isinstance(out, np.ndarray)
9195

9296

97+
def test_dprnntasnet_sep_from_hf():
    # Integration test: downloads a real checkpoint from the Hugging Face hub
    # (requires network access) and checks it loads as the right model class.
    model = DPRNNTasNet.from_pretrained(HF_EXAMPLE_MODEL_IDENTIFER)
    assert isinstance(model, DPRNNTasNet)
100+
101+
93102
@pytest.mark.parametrize("fb", ["free", "stft", "analytic_free", "param_sinc"])
94103
def test_save_and_load_dprnn(fb):
95104
_default_test_model(

tests/utils/hub_utils_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,46 @@
22
from asteroid.utils import hub_utils
33

44

5+
HF_EXAMPLE_MODEL_IDENTIFER = "julien-c/DPRNNTasNet-ks16_WHAM_sepclean"
# An actual model hosted on huggingface.co

REVISION_ID_ONE_SPECIFIC_COMMIT = "8ab5ef18ef2eda141dd11a5d037a8bede7804ce4"
# One particular commit (not the top of `main`), used to test the `model_id@revision` syntax.
10+
11+
512
def test_download():
    # Integration test (requires network): first call downloads from Zenodo...
    path1 = hub_utils.cached_download("mpariente/ConvTasNet_WHAM!_sepclean")
    assert os.path.isfile(path1)
    # ...second call must hit the cache and return the same path.
    path2 = hub_utils.cached_download("mpariente/ConvTasNet_WHAM!_sepclean")
    assert path1 == path2
19+
20+
21+
def test_hf_download():
    # Integration test (requires network): first call downloads from the HF hub...
    path1 = hub_utils.cached_download(HF_EXAMPLE_MODEL_IDENTIFER)
    assert os.path.isfile(path1)
    # ...second call must hit the cache and return the same path.
    path2 = hub_utils.cached_download(HF_EXAMPLE_MODEL_IDENTIFER)
    assert path1 == path2
    # However if specifying a particular commit,
    # file will be different (distinct ETag => distinct cache entry).
    path3 = hub_utils.cached_download(
        f"{HF_EXAMPLE_MODEL_IDENTIFER}@{REVISION_ID_ONE_SPECIFIC_COMMIT}"
    )
    assert path3 != path1
34+
35+
36+
def test_http_user_agent():
    """String and dict forms of the extra user-agent info must render identically."""
    from_string = hub_utils.http_user_agent("foobar/1")
    assert "foobar/1" in from_string
    from_dict = hub_utils.http_user_agent({"foobar": 1})
    assert from_string == from_dict
41+
42+
43+
def test_hf_bucket_url():
    """hf_bucket_url must accept all optional arguments and produce a string URL."""
    resolved = hub_utils.hf_bucket_url(
        model_id="julien-c/foo",
        filename="model.bin",
        subfolder="folder",
        revision="v1.0.0",
    )
    assert isinstance(resolved, str)

0 commit comments

Comments
 (0)