diff --git a/src/datasets/load.py b/src/datasets/load.py index 40b4dfcf754..d15cf219e04 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -71,9 +71,6 @@ def init_dynamic_modules(name: str, hf_modules_cache: Optional[Union[Path, str]] hf_modules_cache = init_hf_modules(hf_modules_cache) dynamic_modules_path = os.path.join(hf_modules_cache, name) os.makedirs(dynamic_modules_path, exist_ok=True) - if not os.path.exists(os.path.join(dynamic_modules_path, "__init__.py")): - with open(os.path.join(dynamic_modules_path, "__init__.py"), "w"): - pass return dynamic_modules_path @@ -619,6 +616,7 @@ def load_metric( ) # Download and prepare resources for the metric + import pdb;pdb.set_trace() metric.download_and_prepare(download_config=download_config) return metric diff --git a/src/datasets/metric.py b/src/datasets/metric.py index 9950718cef8..5f96c6f7471 100644 --- a/src/datasets/metric.py +++ b/src/datasets/metric.py @@ -179,7 +179,7 @@ def __init__( self.keep_in_memory = keep_in_memory self._data_dir_root = os.path.expanduser(cache_dir or config.HF_METRICS_CACHE) - self.data_dir = self._build_data_dir() + self.data_dir = "\\\\?\\" + self._build_data_dir() self.seed: int = seed or np.random.get_state()[1][0] self.timeout: Union[int, float] = timeout diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 1c37a469659..c20bb53fc45 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -41,7 +41,7 @@ INCOMPLETE_SUFFIX = ".incomplete" -def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str: +def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None, original_sys_path=sys.path) -> str: """ Add hf_modules_cache to the python path. By default hf_modules_cache='~/.cache/huggingface/modules'. @@ -51,12 +51,9 @@ def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str: hf_modules_cache = hf_modules_cache if hf_modules_cache is not None else config.HF_MODULES_CACHE hf_modules_cache = str(hf_modules_cache) if hf_modules_cache not in sys.path: + sys.path = original_sys_path[:] sys.path.append(hf_modules_cache) - os.makedirs(hf_modules_cache, exist_ok=True) - if not os.path.exists(os.path.join(hf_modules_cache, "__init__.py")): - with open(os.path.join(hf_modules_cache, "__init__.py"), "w"): - pass return hf_modules_cache @@ -633,7 +630,7 @@ def get_from_cache( # Prevent parallel downloads of the same file with a lock. lock_path = cache_path + ".lock" with FileLock(lock_path): - + cache_path = "\\\\?\\" + cache_path if resume_download: incomplete_path = cache_path + ".incomplete" diff --git a/src/datasets/utils/filelock.py b/src/datasets/utils/filelock.py index 5d4061d2e24..00e20a70039 100644 --- a/src/datasets/utils/filelock.py +++ b/src/datasets/utils/filelock.py @@ -332,6 +332,9 @@ class WindowsFileLock(BaseFileLock): Uses the :func:`msvcrt.locking` function to hard lock the lock file on windows systems. """ + def __init__(self, lock_file, timeout=-1): + lock_file = "\\\\?\\" + lock_file if os.path.isabs(lock_file) else lock_file + super().__init__(lock_file, timeout=timeout) def _acquire(self): open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC diff --git a/tests/conftest.py b/tests/conftest.py index 4f110012a69..cd4a1f030e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,10 +10,14 @@ from .s3_fixtures import * # noqa: load s3 fixtures +import shutil + @pytest.fixture(autouse=True) def set_test_cache_config(tmp_path_factory, monkeypatch): - # test_hf_cache_home = tmp_path_factory.mktemp("cache") # TODO: why a cache dir per test function does not work? + # A cache dir per test function + # test_hf_cache_home = tmp_path_factory.mktemp("cache") + # test_hf_cache_home = tmp_path / "cache" test_hf_cache_home = tmp_path_factory.getbasetemp() / "cache" test_hf_datasets_cache = str(test_hf_cache_home / "datasets") test_hf_metrics_cache = str(test_hf_cache_home / "metrics") @@ -21,6 +25,8 @@ def set_test_cache_config(tmp_path_factory, monkeypatch): monkeypatch.setattr("datasets.config.HF_DATASETS_CACHE", test_hf_datasets_cache) monkeypatch.setattr("datasets.config.HF_METRICS_CACHE", test_hf_metrics_cache) monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", test_hf_modules_cache) + # yield + # shutil.rmtree(tmp_path) FILE_CONTENT = """\ diff --git a/tests/test_load.py b/tests/test_load.py index 7e5215f7846..5cd8f782a4f 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,5 @@ import importlib import os -import shutil import tempfile import time from hashlib import sha256 @@ -73,13 +72,6 @@ class LoadTest(TestCase): def inject_fixtures(self, caplog): self._caplog = caplog - def setUp(self): - self.hf_modules_cache = tempfile.mkdtemp() - self.dynamic_modules_path = datasets.load.init_dynamic_modules("test_datasets_modules", self.hf_modules_cache) - - def tearDown(self): - shutil.rmtree(self.hf_modules_cache) - def _dummy_module_dir(self, modules_dir, dummy_module_name, dummy_code): assert dummy_module_name.startswith("__") module_dir = os.path.join(modules_dir, dummy_module_name) @@ -94,9 +86,7 @@ def test_prepare_module(self): # prepare module from directory path dummy_code = "MY_DUMMY_VARIABLE = 'hello there'" module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code) - importable_module_path, module_hash = datasets.load.prepare_module( - module_dir, dynamic_modules_path=self.dynamic_modules_path - ) + importable_module_path, module_hash = datasets.load.prepare_module(module_dir) dummy_module = importlib.import_module(importable_module_path) self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "hello there") self.assertEqual(module_hash, sha256(dummy_code.encode("utf-8")).hexdigest()) @@ -105,7 +95,7 @@ def test_prepare_module(self): module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code) module_path = os.path.join(module_dir, "__dummy_module_name1__.py") importable_module_path, module_hash, resolved_file_path = datasets.load.prepare_module( - module_path, dynamic_modules_path=self.dynamic_modules_path, return_resolved_file_path=True + module_path, return_resolved_file_path=True ) self.assertEqual(resolved_file_path, module_path) dummy_module = importlib.import_module(importable_module_path) @@ -115,30 +105,22 @@ def test_prepare_module(self): for offline_simulation_mode in list(OfflineSimulationMode): with offline(offline_simulation_mode): with self.assertRaises((FileNotFoundError, ConnectionError, requests.exceptions.ConnectionError)): - datasets.load.prepare_module( - "__missing_dummy_module_name__", dynamic_modules_path=self.dynamic_modules_path - ) + datasets.load.prepare_module("__missing_dummy_module_name__") def test_offline_prepare_module(self): with tempfile.TemporaryDirectory() as tmp_dir: dummy_code = "MY_DUMMY_VARIABLE = 'hello there'" module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code) - importable_module_path1, _ = datasets.load.prepare_module( - module_dir, dynamic_modules_path=self.dynamic_modules_path - ) + importable_module_path1, _ = datasets.load.prepare_module(module_dir) time.sleep(0.1) # make sure there's a difference in the OS update time of the python file dummy_code = "MY_DUMMY_VARIABLE = 'general kenobi'" module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code) - importable_module_path2, _ = datasets.load.prepare_module( - module_dir, dynamic_modules_path=self.dynamic_modules_path - ) + importable_module_path2, _ = datasets.load.prepare_module(module_dir) for offline_simulation_mode in list(OfflineSimulationMode): with offline(offline_simulation_mode): self._caplog.clear() # allow provide the module name without an explicit path to remote or local actual file - importable_module_path3, _ = datasets.load.prepare_module( - "__dummy_module_name2__", dynamic_modules_path=self.dynamic_modules_path - ) + importable_module_path3, _ = datasets.load.prepare_module("__dummy_module_name2__") # it loads the most recent version of the module self.assertEqual(importable_module_path2, importable_module_path3) self.assertNotEqual(importable_module_path1, importable_module_path3) diff --git a/tests/test_metric_common.py b/tests/test_metric_common.py index 31313f52a0b..b761d263bed 100644 --- a/tests/test_metric_common.py +++ b/tests/test_metric_common.py @@ -53,7 +53,11 @@ def get_local_metric_names(): return [{"testcase_name": x, "metric_name": x} for x in metrics if x != "gleu"] # gleu is unfinished -@parameterized.named_parameters(get_local_metric_names()) +# @parameterized.named_parameters(get_local_metric_names()) +# @parameterized.named_parameters([{"testcase_name": "sari", "metric_name": "sari"}]) # Bug: Anaconda in Windows +# @parameterized.named_parameters([{"testcase_name": "glue", "metric_name": "glue"}]) +@parameterized.named_parameters([{"testcase_name": "bleurt", "metric_name": "bleurt"}]) +# @parameterized.named_parameters([{"testcase_name": "comet", "metric_name": "comet"}]) # comment next line: @for_all_test_methods(skip_if_dataset_requires_fairseq) @for_all_test_methods(skip_if_dataset_requires_fairseq) @local class LocalMetricTest(parameterized.TestCase): @@ -61,8 +65,13 @@ class LocalMetricTest(parameterized.TestCase): metric_name = None def test_load_metric(self, metric_name): + import pdb;pdb.set_trace() doctest.ELLIPSIS_MARKER = "[...]" - metric_module = importlib.import_module(datasets.load.prepare_module(os.path.join("metrics", metric_name))[0]) + # metric_module = importlib.import_module( + # datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0] + # ) + metric_module_name = datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0] + metric_module = importlib.import_module(metric_module_name) metric = datasets.load.import_main_class(metric_module.__name__, dataset=False) # check parameters parameters = inspect.signature(metric._compute).parameters @@ -79,7 +88,9 @@ def test_load_metric(self, metric_name): @slow def test_load_real_metric(self, metric_name): doctest.ELLIPSIS_MARKER = "[...]" - metric_module = importlib.import_module(datasets.load.prepare_module(os.path.join("metrics", metric_name))[0]) + metric_module = importlib.import_module( + datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0] + ) # run doctest with self.use_local_metrics(): results = doctest.testmod(metric_module, verbose=True, raise_on_error=True) @@ -177,3 +188,53 @@ def test_seqeval_raises_when_incorrect_scheme(): error_message = f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {wrong_scheme}" with pytest.raises(ValueError, match=re.escape(error_message)): metric.compute(predictions=[], references=[], scheme=wrong_scheme) + + +@pytest.mark.parametrize("metric_name", ["bleurt"]) +def test_albert(metric_name, monkeypatch): + doctest.ELLIPSIS_MARKER = "[...]" + # metric_module = importlib.import_module( + # datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0] + # ) + metric_module_name = datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0] + metric_module = importlib.import_module(metric_module_name) + metric = datasets.load.import_main_class(metric_module.__name__, dataset=False) + # check parameters + parameters = inspect.signature(metric._compute).parameters + assert "predictions" in parameters + assert "references" in parameters + assert all([p.kind != p.VAR_KEYWORD for p in parameters.values()]) # no **kwargs + # run doctest + # with self.patch_intensive_calls(metric_name, metric_module.__name__): + # with self.use_local_metrics(): + # results = doctest.testmod(metric_module, verbose=True, raise_on_error=True) + + original_load_metric = datasets.load_metric + def mock_load_metric(metric_name, *args, **kwargs): + print("MOCK load_metric") + return original_load_metric(os.path.join("metrics", metric_name), *args, **kwargs) + monkeypatch.setattr("datasets.load_metric", mock_load_metric) + + tmp = datasets.load_metric("bleurt") + assert False + + # def mock_compute(*args, **kwrags): + # return {"scores": np.array([1.03, 1.04])} + # monkeypatch.setattr("bleurt.compute", mock_compute) # AttributeError: 'module' object at bleurt has no attribute 'compute' + import tensorflow.compat.v1 as tf + from bleurt.score import Predictor + + tf.flags.DEFINE_string("sv", "", "") # handle pytest cli flags + + class MockedPredictor(Predictor): + def predict(self, input_dict): + assert len(input_dict["input_ids"]) == 2 + return np.array([1.03, 1.04]) + + # mock predict_fn which is supposed to do a forward pass with a bleurt model + monkeypatch.setattr("bleurt.score._create_predictor", MockedPredictor()) + + results = doctest.testmod(metric_module, verbose=True, raise_on_error=True) + + assert results.failed == 0 + assert results.attempted > 1