1+ import fnmatch
2+ import io
3+ import json
14import os
2- from torch import hub
5+ import sys
6+ import tempfile
7+ from contextlib import contextmanager
8+ from functools import partial
39from hashlib import sha256
10+ from typing import BinaryIO , Dict , Optional , Union
11+ from urllib .parse import urlparse
12+
13+ import requests
14+ import torch
15+ from filelock import FileLock
16+ from torch import hub
417
518
619CACHE_DIR = os .getenv (
2336
2437SR_HASHTABLE = {k : 8000.0 if not "DeMask" in k else 16000.0 for k in MODELS_URLS_HASHTABLE }
2538
# Default name of the weights file inside a model repo on the Hugging Face hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"
# URL template resolved by `hf_bucket_url` via str.format; all three placeholders
# ({model_id}, {revision}, {filename}) must be present, otherwise the `filename=`
# keyword passed to .format() is silently ignored and the URL is wrong.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
41+
2642
2743def cached_download (filename_or_url ):
28- """Download from URL with torch.hub and cache the result in ASTEROID_CACHE.
44+ """Download from URL and cache the result in ASTEROID_CACHE.
2945
3046 Args:
3147 filename_or_url (str): Name of a model as named on the Zenodo Community
32- page (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or an URL to a model
33- file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename
48+ page (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or model id from
49+ the Hugging Face model hub (ex: ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``),
50+ or a URL to a model file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename
3451 that exists locally (ex: ``"local/tmp_model.pth"``)
3552
3653 Returns:
@@ -39,11 +56,22 @@ def cached_download(filename_or_url):
3956 if os .path .isfile (filename_or_url ):
4057 return filename_or_url
4158
42- if filename_or_url in MODELS_URLS_HASHTABLE :
59+ if urlparse (filename_or_url ).scheme in ("http" , "https" ):
60+ url = filename_or_url
61+ elif filename_or_url in MODELS_URLS_HASHTABLE :
4362 url = MODELS_URLS_HASHTABLE [filename_or_url ]
4463 else :
45- # Give a chance to direct URL, torch.hub will handle exceptions
46- url = filename_or_url
64+ # Finally, let's try to find it on Hugging Face model hub
65+ # e.g. julien-c/DPRNNTasNet-ks16_WHAM_sepclean is a valid model id
66+ # and julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main supports specifying a commit/branch/tag.
67+ if "@" in filename_or_url :
68+ model_id = filename_or_url .split ("@" )[0 ]
69+ revision = filename_or_url .split ("@" )[1 ]
70+ else :
71+ model_id = filename_or_url
72+ revision = None
73+ url = hf_bucket_url (model_id = model_id , filename = HF_WEIGHTS_NAME , revision = revision )
74+ return hf_get_from_cache (url , cache_dir = get_cache_dir ())
4775 cached_filename = url_to_filename (url )
4876 cached_dir = os .path .join (get_cache_dir (), cached_filename )
4977 cached_path = os .path .join (cached_dir , "model.pth" )
@@ -68,3 +96,216 @@ def url_to_filename(url):
def get_cache_dir():
    """Ensure the Asteroid cache directory exists and return its path."""
    if not os.path.isdir(CACHE_DIR):
        os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR
99+
100+
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None
) -> str:
    """
    Build the huggingface.co URL at which ``filename`` of model ``model_id`` can be fetched.

    The hub redirects large files to Cloudfront (a Content Delivery Network, or CDN),
    which is replicated over the globe, so downloads are much faster for end users
    (and bandwidth costs are lower).

    Cloudfront caches aggressively (default TTL is 24 hours), which is safe here
    because huggingface.co uses git-based versioning and stores files on S3/Cloudfront
    in a content-addressable way (the file name is its hash), so the cache can never
    be stale.

    Client-side caching in this library is keyed on the object's ETag: the file's
    sha1 if stored in git, or its sha256 if stored in git-lfs.

    Args:
        model_id (str): Model identifier, e.g. ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``.
        filename (str): Name of the file to resolve within the repo.
        subfolder (str, optional): Subfolder containing the file, prepended to ``filename``.
        revision (str, optional): Git commit/branch/tag; defaults to ``"main"``.

    Returns:
        str: The fully-resolved URL on huggingface.co.
    """
    if subfolder is not None:
        filename = "/".join((subfolder, filename))
    effective_revision = "main" if revision is None else revision
    return HUGGINGFACE_CO_PREFIX.format(
        model_id=model_id, revision=effective_revision, filename=filename
    )
125+
126+
def hf_url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Hash ``url`` into a repeatable cache filename.

    If ``etag`` is given (non-empty), its own sha256 hex digest is appended to the
    url's digest, separated by a period, so distinct remote versions of the same
    URL map to distinct cache entries.
    """
    parts = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode("utf-8")).hexdigest())
    return ".".join(parts)
140+
141+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Format a user-agent string carrying basic version info about a request.

    Args:
        user_agent (dict or str, optional): Extra key/value pairs (or a raw string)
            appended to the default asteroid/python/torch identification.

    Returns:
        str: Semicolon-separated user-agent components.
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports

    parts = [
        f"asteroid/{asteroid_version}",
        f"python/{sys.version.split()[0]}",
        f"torch/{torch.__version__}",
    ]
    if isinstance(user_agent, dict):
        parts.extend(f"{key}/{value}" for key, value in user_agent.items())
    elif isinstance(user_agent, str):
        parts.append(user_agent)
    return "; ".join(parts)
155+
156+
def http_get(
    url: str,
    temp_file: BinaryIO,
    proxies=None,
    resume_size=0,
    user_agent: Union[Dict, str, None] = None,
):
    """
    Download a remote file into ``temp_file``, streaming it chunk by chunk.

    Errors are not swallowed: any HTTP error status raises via ``raise_for_status``.

    Args:
        url (str): URL to download.
        temp_file (BinaryIO): Open binary file object the payload is written to.
        proxies: Optional proxies mapping forwarded to ``requests``.
        resume_size (int): If > 0, request only bytes from this offset onwards
            (HTTP ``Range`` header) to resume a partial download.
        user_agent (dict or str, optional): Extra user-agent info, see ``http_user_agent``.
    """
    headers = {"user-agent": http_user_agent(user_agent)}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=1024):
        # Keep-alive chunks arrive as empty byte strings; skip them.
        if not chunk:
            continue
        temp_file.write(chunk)
175+
176+
def hf_get_from_cache(
    url: str,
    cache_dir: str,
    force_download: bool = False,
    proxies=None,
    etag_timeout: float = 10,
    resume_download: bool = False,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: bool = False,
) -> Optional[str]:  # pragma: no cover
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Args:
        url (str): URL of the file to fetch.
        cache_dir (str): Directory where downloads are cached.
        force_download (bool): If True, re-download even when a cached copy exists.
        proxies: Optional proxies mapping forwarded to ``requests``.
        etag_timeout (float): Timeout in seconds for the HEAD request that fetches the ETag.
        resume_download (bool): If True, resume a previously-interrupted download from a
            ``*.incomplete`` file instead of starting over.
        user_agent (dict or str, optional): Extra user-agent info, see ``http_user_agent``.
        local_files_only (bool): If True, never hit the network; only look in the cache.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        OSError: if the remote resource exposes no ETag (reproducibility cannot be ensured).
        ValueError: if the file is neither cached nor downloadable (connection error,
            or ``local_files_only=True`` with an empty cache).
    """

    os.makedirs(cache_dir, exist_ok=True)

    url_to_download = url
    etag = None
    if not local_files_only:
        # HEAD the resource to learn its ETag (used to key the cache entry) and
        # any redirect target. Connection failures leave etag=None, which routes
        # us to the offline/cache-only fallback below.
        try:
            headers = {"user-agent": http_user_agent(user_agent)}
            r = requests.head(
                url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout
            )
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    # Cache filename is sha256(url) plus, when known, ".{sha256(etag)}".
    filename = hf_url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            # Fall back to any cached entry for the same URL (the part of the
            # filename before the first ".") regardless of which ETag it carried,
            # excluding metadata (.json) and lock (.lock) files.
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            # Append to the partial file so an interrupted download continues
            # where it left off.
            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            # Fresh download into a named temp file in the same directory, so the
            # final os.replace below stays on one filesystem (atomic rename).
            temp_file_manager = partial(
                tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
            )
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            http_get(
                url_to_download,
                temp_file,
                proxies=proxies,
                resume_size=resume_size,
                user_agent=user_agent,
            )

        os.replace(temp_file.name, cache_path)

        # Record the original url/etag next to the blob for later inspection.
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
0 commit comments