Skip to content
4 changes: 1 addition & 3 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def init_dynamic_modules(name: str, hf_modules_cache: Optional[Union[Path, str]]
hf_modules_cache = init_hf_modules(hf_modules_cache)
dynamic_modules_path = os.path.join(hf_modules_cache, name)
os.makedirs(dynamic_modules_path, exist_ok=True)
if not os.path.exists(os.path.join(dynamic_modules_path, "__init__.py")):
with open(os.path.join(dynamic_modules_path, "__init__.py"), "w"):
pass
return dynamic_modules_path


Expand Down Expand Up @@ -619,6 +616,7 @@ def load_metric(
)

# Download and prepare resources for the metric
import pdb;pdb.set_trace()
metric.download_and_prepare(download_config=download_config)

return metric
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(

self.keep_in_memory = keep_in_memory
self._data_dir_root = os.path.expanduser(cache_dir or config.HF_METRICS_CACHE)
self.data_dir = self._build_data_dir()
self.data_dir = "\\\\?\\" + self._build_data_dir()
self.seed: int = seed or np.random.get_state()[1][0]
self.timeout: Union[int, float] = timeout

Expand Down
9 changes: 3 additions & 6 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
INCOMPLETE_SUFFIX = ".incomplete"


def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str:
def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None, original_sys_path=sys.path) -> str:
"""
Add hf_modules_cache to the python path.
By default hf_modules_cache='~/.cache/huggingface/modules'.
Expand All @@ -51,12 +51,9 @@ def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str:
hf_modules_cache = hf_modules_cache if hf_modules_cache is not None else config.HF_MODULES_CACHE
hf_modules_cache = str(hf_modules_cache)
if hf_modules_cache not in sys.path:
sys.path = original_sys_path[:]
sys.path.append(hf_modules_cache)

os.makedirs(hf_modules_cache, exist_ok=True)
if not os.path.exists(os.path.join(hf_modules_cache, "__init__.py")):
with open(os.path.join(hf_modules_cache, "__init__.py"), "w"):
pass
return hf_modules_cache


Expand Down Expand Up @@ -633,7 +630,7 @@ def get_from_cache(
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
with FileLock(lock_path):

cache_path = "\\\\?\\" + cache_path
if resume_download:
incomplete_path = cache_path + ".incomplete"

Expand Down
3 changes: 3 additions & 0 deletions src/datasets/utils/filelock.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ class WindowsFileLock(BaseFileLock):
Uses the :func:`msvcrt.locking` function to hard lock the lock file on
windows systems.
"""
def __init__(self, lock_file, timeout=-1):
    """Create a Windows file lock whose path opts into extended-length path support.

    Absolute lock paths are prefixed with ``\\\\?\\`` so lock files whose absolute
    path exceeds the legacy Windows MAX_PATH (260 char) limit can still be created.
    Relative paths are left untouched, because the extended-length prefix is only
    valid on absolute paths.

    Args:
        lock_file: Path of the lock file.
        timeout: Default acquire timeout in seconds; -1 means wait indefinitely.
    """
    # Guard against double-prefixing when the caller already passed an
    # extended-length path (r"\\?\C:\...") — a doubled prefix is invalid.
    if os.path.isabs(lock_file) and not lock_file.startswith("\\\\?\\"):
        lock_file = "\\\\?\\" + lock_file
    super().__init__(lock_file, timeout=timeout)

def _acquire(self):
open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@

from .s3_fixtures import * # noqa: load s3 fixtures

import shutil


@pytest.fixture(autouse=True)
def set_test_cache_config(tmp_path_factory, monkeypatch):
    """Redirect all HF cache directories to a per-session temporary directory.

    Autouse fixture: patches the datasets/metrics/modules cache locations so the
    test suite never reads from or writes to the user's real ``~/.cache``. Uses
    pytest's session base temp dir (a single shared "cache" dir for the whole
    session); pytest prunes old base temp dirs itself, so no explicit cleanup is
    performed here.
    """
    # NOTE(review): a cache dir per test function (tmp_path_factory.mktemp)
    # reportedly did not work — confirm why before switching granularity.
    test_hf_cache_home = tmp_path_factory.getbasetemp() / "cache"
    monkeypatch.setattr("datasets.config.HF_DATASETS_CACHE", str(test_hf_cache_home / "datasets"))
    monkeypatch.setattr("datasets.config.HF_METRICS_CACHE", str(test_hf_cache_home / "metrics"))
    monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", str(test_hf_cache_home / "modules"))


FILE_CONTENT = """\
Expand Down
30 changes: 6 additions & 24 deletions tests/test_load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import os
import shutil
import tempfile
import time
from hashlib import sha256
Expand Down Expand Up @@ -73,13 +72,6 @@ class LoadTest(TestCase):
def inject_fixtures(self, caplog):
self._caplog = caplog

def setUp(self):
self.hf_modules_cache = tempfile.mkdtemp()
self.dynamic_modules_path = datasets.load.init_dynamic_modules("test_datasets_modules", self.hf_modules_cache)

def tearDown(self):
shutil.rmtree(self.hf_modules_cache)

def _dummy_module_dir(self, modules_dir, dummy_module_name, dummy_code):
assert dummy_module_name.startswith("__")
module_dir = os.path.join(modules_dir, dummy_module_name)
Expand All @@ -94,9 +86,7 @@ def test_prepare_module(self):
# prepare module from directory path
dummy_code = "MY_DUMMY_VARIABLE = 'hello there'"
module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code)
importable_module_path, module_hash = datasets.load.prepare_module(
module_dir, dynamic_modules_path=self.dynamic_modules_path
)
importable_module_path, module_hash = datasets.load.prepare_module(module_dir)
dummy_module = importlib.import_module(importable_module_path)
self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "hello there")
self.assertEqual(module_hash, sha256(dummy_code.encode("utf-8")).hexdigest())
Expand All @@ -105,7 +95,7 @@ def test_prepare_module(self):
module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code)
module_path = os.path.join(module_dir, "__dummy_module_name1__.py")
importable_module_path, module_hash, resolved_file_path = datasets.load.prepare_module(
module_path, dynamic_modules_path=self.dynamic_modules_path, return_resolved_file_path=True
module_path, return_resolved_file_path=True
)
self.assertEqual(resolved_file_path, module_path)
dummy_module = importlib.import_module(importable_module_path)
Expand All @@ -115,30 +105,22 @@ def test_prepare_module(self):
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
with self.assertRaises((FileNotFoundError, ConnectionError, requests.exceptions.ConnectionError)):
datasets.load.prepare_module(
"__missing_dummy_module_name__", dynamic_modules_path=self.dynamic_modules_path
)
datasets.load.prepare_module("__missing_dummy_module_name__")

def test_offline_prepare_module(self):
with tempfile.TemporaryDirectory() as tmp_dir:
dummy_code = "MY_DUMMY_VARIABLE = 'hello there'"
module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code)
importable_module_path1, _ = datasets.load.prepare_module(
module_dir, dynamic_modules_path=self.dynamic_modules_path
)
importable_module_path1, _ = datasets.load.prepare_module(module_dir)
time.sleep(0.1) # make sure there's a difference in the OS update time of the python file
dummy_code = "MY_DUMMY_VARIABLE = 'general kenobi'"
module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code)
importable_module_path2, _ = datasets.load.prepare_module(
module_dir, dynamic_modules_path=self.dynamic_modules_path
)
importable_module_path2, _ = datasets.load.prepare_module(module_dir)
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
self._caplog.clear()
# allow provide the module name without an explicit path to remote or local actual file
importable_module_path3, _ = datasets.load.prepare_module(
"__dummy_module_name2__", dynamic_modules_path=self.dynamic_modules_path
)
importable_module_path3, _ = datasets.load.prepare_module("__dummy_module_name2__")
# it loads the most recent version of the module
self.assertEqual(importable_module_path2, importable_module_path3)
self.assertNotEqual(importable_module_path1, importable_module_path3)
Expand Down
67 changes: 64 additions & 3 deletions tests/test_metric_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,25 @@ def get_local_metric_names():
return [{"testcase_name": x, "metric_name": x} for x in metrics if x != "gleu"] # gleu is unfinished


@parameterized.named_parameters(get_local_metric_names())
# @parameterized.named_parameters(get_local_metric_names())
# @parameterized.named_parameters([{"testcase_name": "sari", "metric_name": "sari"}]) # Bug: Anaconda in Windows
# @parameterized.named_parameters([{"testcase_name": "glue", "metric_name": "glue"}])
@parameterized.named_parameters([{"testcase_name": "bleurt", "metric_name": "bleurt"}])
# @parameterized.named_parameters([{"testcase_name": "comet", "metric_name": "comet"}]) # comment next line: @for_all_test_methods(skip_if_dataset_requires_fairseq)
@for_all_test_methods(skip_if_dataset_requires_fairseq)
@local
class LocalMetricTest(parameterized.TestCase):
INTENSIVE_CALLS_PATCHER = {}
metric_name = None

def test_load_metric(self, metric_name):
import pdb;pdb.set_trace()
doctest.ELLIPSIS_MARKER = "[...]"
metric_module = importlib.import_module(datasets.load.prepare_module(os.path.join("metrics", metric_name))[0])
# metric_module = importlib.import_module(
# datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0]
# )
metric_module_name = datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0]
metric_module = importlib.import_module(metric_module_name)
metric = datasets.load.import_main_class(metric_module.__name__, dataset=False)
# check parameters
parameters = inspect.signature(metric._compute).parameters
Expand All @@ -79,7 +88,9 @@ def test_load_metric(self, metric_name):
@slow
def test_load_real_metric(self, metric_name):
doctest.ELLIPSIS_MARKER = "[...]"
metric_module = importlib.import_module(datasets.load.prepare_module(os.path.join("metrics", metric_name))[0])
metric_module = importlib.import_module(
datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0]
)
# run doctest
with self.use_local_metrics():
results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
Expand Down Expand Up @@ -177,3 +188,53 @@ def test_seqeval_raises_when_incorrect_scheme():
error_message = f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {wrong_scheme}"
with pytest.raises(ValueError, match=re.escape(error_message)):
metric.compute(predictions=[], references=[], scheme=wrong_scheme)


@pytest.mark.parametrize("metric_name", ["bleurt"])
def test_albert(metric_name, monkeypatch):
    """Load the bleurt metric and run its doctests without downloading the real
    BLEURT checkpoint, by mocking the model-creation entry point.

    NOTE(review): depends on optional third-party packages (tensorflow, bleurt)
    and on ``prepare_module`` resolving the local metric script — confirm this
    test is gated/skipped appropriately in environments lacking those deps.
    """
    doctest.ELLIPSIS_MARKER = "[...]"
    metric_module_name = datasets.load.prepare_module(os.path.join("metrics", metric_name), dataset=False)[0]
    metric_module = importlib.import_module(metric_module_name)
    metric = datasets.load.import_main_class(metric_module.__name__, dataset=False)

    # Sanity-check the metric's compute signature.
    parameters = inspect.signature(metric._compute).parameters
    assert "predictions" in parameters
    assert "references" in parameters
    assert all(p.kind != p.VAR_KEYWORD for p in parameters.values())  # no **kwargs

    # Redirect load_metric("bleurt") calls made inside the doctests to the
    # local copy under metrics/ instead of resolving the name remotely.
    original_load_metric = datasets.load_metric

    def mock_load_metric(name, *args, **kwargs):
        return original_load_metric(os.path.join("metrics", name), *args, **kwargs)

    monkeypatch.setattr("datasets.load_metric", mock_load_metric)

    import tensorflow.compat.v1 as tf
    from bleurt.score import Predictor

    tf.flags.DEFINE_string("sv", "", "")  # absorb pytest's own CLI flags

    class MockedPredictor(Predictor):
        # Fake forward pass: return fixed scores instead of running the model.
        def predict(self, input_dict):
            assert len(input_dict["input_ids"]) == 2
            return np.array([1.03, 1.04])

    # _create_predictor is *called* by bleurt to build the model, so it must be
    # patched with a callable factory, not a Predictor instance.
    monkeypatch.setattr("bleurt.score._create_predictor", lambda *args, **kwargs: MockedPredictor())

    results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
    assert results.failed == 0
    assert results.attempted > 1